from __future__ import absolute_import, division, print_function
from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError, _StreamBuffer
from tornado.httputil import HTTPHeaders
from tornado.locks import Condition, Event
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test # noqa: E501
from tornado.test.util import (unittest, skipIfNonUnix, refusing_port, skipPypy3V58,
ignore_deprecation)
from tornado.web import RequestHandler, Application
import errno
import hashlib
import os
import platform
import random
import socket
import ssl
import sys
try:
from unittest import mock # type: ignore
except ImportError:
try:
import mock # type: ignore
except ImportError:
mock = None
def _server_ssl_options():
    """Return keyword options naming the test server's certificate and key.

    Both files (``test.crt`` and ``test.key``) live in the same directory
    as this module.
    """
    here = os.path.dirname(__file__)
    return {
        'certfile': os.path.join(here, 'test.crt'),
        'keyfile': os.path.join(here, 'test.key'),
    }
class HelloHandler(RequestHandler):
    """Minimal handler whose GET response body is the string ``Hello``."""

    def get(self):
        greeting = "Hello"
        self.write(greeting)
class TestIOStreamWebMixin(object):
    """Client-side IOStream tests run against a live HTTP application.

    Subclasses combine this mixin with an HTTP(S) test case and implement
    ``_make_client_iostream`` to return an unconnected client stream. Every
    test closes its stream in a ``finally`` clause so that a failing
    assertion does not leak the client socket.
    """

    def _make_client_iostream(self):
        raise NotImplementedError()

    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself. Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    @gen_test
    def test_read_until_close(self):
        stream = self._make_client_iostream()
        try:
            yield stream.connect(('127.0.0.1', self.get_http_port()))
            stream.write(b"GET / HTTP/1.0\r\n\r\n")
            data = yield stream.read_until_close()
            self.assertTrue(data.startswith(b"HTTP/1.1 200"))
            self.assertTrue(data.endswith(b"Hello"))
        finally:
            stream.close()

    @gen_test
    def test_read_zero_bytes(self):
        # Use a local stream rather than an attribute on the test case;
        # nothing outside this test needs it.
        stream = self._make_client_iostream()
        try:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            stream.write(b"GET / HTTP/1.0\r\n\r\n")
            # normal read
            data = yield stream.read_bytes(9)
            self.assertEqual(data, b"HTTP/1.1 ")
            # zero bytes
            data = yield stream.read_bytes(0)
            self.assertEqual(data, b"")
            # another normal read
            data = yield stream.read_bytes(3)
            self.assertEqual(data, b"200")
        finally:
            stream.close()

    @gen_test
    def test_write_while_connecting(self):
        stream = self._make_client_iostream()
        try:
            connect_fut = stream.connect(("127.0.0.1", self.get_http_port()))
            # Unlike the previous tests, try to write before the connection
            # is complete.
            write_fut = stream.write(
                b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
            self.assertFalse(connect_fut.done())
            # connect will always complete before write.
            it = gen.WaitIterator(connect_fut, write_fut)
            resolved_order = []
            while not it.done():
                yield it.next()
                resolved_order.append(it.current_future)
            self.assertEqual(resolved_order, [connect_fut, write_fut])
            data = yield stream.read_until_close()
            self.assertTrue(data.endswith(b"Hello"))
        finally:
            stream.close()

    @gen_test
    def test_future_interface(self):
        """Basic test of IOStream's ability to return Futures."""
        stream = self._make_client_iostream()
        try:
            connect_result = yield stream.connect(
                ("127.0.0.1", self.get_http_port()))
            self.assertIs(connect_result, stream)
            yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
            first_line = yield stream.read_until(b"\r\n")
            self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
            # callback=None is equivalent to no callback.
            header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
            headers = HTTPHeaders.parse(header_data.decode('latin1'))
            content_length = int(headers['Content-Length'])
            body = yield stream.read_bytes(content_length)
            self.assertEqual(body, b'Hello')
        finally:
            stream.close()

    @gen_test
    def test_future_close_while_reading(self):
        stream = self._make_client_iostream()
        try:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
            # The response is far shorter than 1MB, so the read can only
            # end with the server closing the connection.
            with self.assertRaises(StreamClosedError):
                yield stream.read_bytes(1024 * 1024)
        finally:
            stream.close()

    @gen_test
    def test_future_read_until_close(self):
        # Ensure that the data comes through before the StreamClosedError.
        stream = self._make_client_iostream()
        try:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            yield stream.write(
                b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
            yield stream.read_until(b"\r\n\r\n")
            body = yield stream.read_until_close()
            self.assertEqual(body, b"Hello")
            # Nothing else to read; the error comes immediately without
            # waiting for yield.
            with self.assertRaises(StreamClosedError):
                stream.read_bytes(1)
        finally:
            stream.close()
class TestReadWriteMixin(object):
    """Tests where one stream reads and the other writes.

    These should work for BaseIOStream implementations. Subclasses
    implement ``make_iostream_pair`` as a coroutine that resolves to a
    connected pair of streams; the tests treat the first stream as the
    reader (``rs``) and the second as the writer (``ws``). Keyword
    arguments such as ``read_chunk_size`` and ``max_buffer_size`` are
    passed through to the subclass.
    """

    def make_iostream_pair(self, **kwargs):
        raise NotImplementedError

    @gen_test
    def test_write_zero_bytes(self):
        # Attempting to write zero bytes should run the callback without
        # going into an infinite loop.
        rs, ws = yield self.make_iostream_pair()
        try:
            yield ws.write(b'')
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_streaming_callback(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            chunks = []
            cond = Condition()

            def streaming_callback(data):
                chunks.append(data)
                cond.notify()
            with ignore_deprecation():
                fut = rs.read_bytes(6, streaming_callback=streaming_callback)
            ws.write(b"1234")
            while not chunks:
                yield cond.wait()
            ws.write(b"5678")
            final_data = yield fut
            self.assertFalse(final_data)
            self.assertEqual(chunks, [b"1234", b"56"])
            # the rest of the last chunk is still in the buffer
            data = yield rs.read_bytes(2)
            self.assertEqual(data, b"78")
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_streaming_callback_with_final_callback(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            chunks = []
            final_called = []
            cond = Condition()

            def streaming_callback(data):
                chunks.append(data)
                cond.notify()

            def final_callback(data):
                self.assertFalse(data)
                final_called.append(True)
                cond.notify()
            with ignore_deprecation():
                rs.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            ws.write(b"1234")
            while not chunks:
                yield cond.wait()
            ws.write(b"5678")
            while not final_called:
                yield cond.wait()
            self.assertEqual(chunks, [b"1234", b"56"])
            # the rest of the last chunk is still in the buffer
            data = yield rs.read_bytes(2)
            self.assertEqual(data, b"78")
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_streaming_callback_with_data_in_buffer(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            ws.write(b"abcd\r\nefgh")
            data = yield rs.read_until(b"\r\n")
            self.assertEqual(data, b"abcd\r\n")
            # "efgh" is already buffered, so the streaming callback should
            # receive it without any further writes.
            streaming_fut = Future()
            with ignore_deprecation():
                rs.read_until_close(streaming_callback=streaming_fut.set_result)
            data = yield streaming_fut
            self.assertEqual(data, b"efgh")
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_streaming_until_close(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            chunks = []
            closed = [False]
            cond = Condition()

            def streaming_callback(data):
                chunks.append(data)
                cond.notify()

            def close_callback(data):
                assert not data, data
                closed[0] = True
                cond.notify()
            with ignore_deprecation():
                rs.read_until_close(callback=close_callback,
                                    streaming_callback=streaming_callback)
            ws.write(b"1234")
            while len(chunks) != 1:
                yield cond.wait()
            yield ws.write(b"5678")
            ws.close()
            while not closed[0]:
                yield cond.wait()
            self.assertEqual(chunks, [b"1234", b"5678"])
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_streaming_until_close_future(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            chunks = []

            @gen.coroutine
            def rs_task():
                with ignore_deprecation():
                    yield rs.read_until_close(streaming_callback=chunks.append)

            @gen.coroutine
            def ws_task():
                yield ws.write(b"1234")
                yield gen.sleep(0.01)
                yield ws.write(b"5678")
                ws.close()

            yield [rs_task(), ws_task()]
            self.assertEqual(chunks, [b"1234", b"5678"])
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_delayed_close_callback(self):
        # The scenario: Server closes the connection while there is a pending
        # read that can be served out of buffered data. The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        rs, ws = yield self.make_iostream_pair()
        try:
            event = Event()
            rs.set_close_callback(event.set)
            ws.write(b"12")
            chunks = []

            def callback1(data):
                chunks.append(data)
                with ignore_deprecation():
                    rs.read_bytes(1, callback2)
                ws.close()

            def callback2(data):
                chunks.append(data)
            with ignore_deprecation():
                rs.read_bytes(1, callback1)
            yield event.wait()  # resolved by the close callback
            self.assertEqual(chunks, [b"1", b"2"])
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_future_delayed_close_callback(self):
        # Same as test_delayed_close_callback, but with the future interface.
        rs, ws = yield self.make_iostream_pair()
        try:
            ws.write(b"12")
            chunks = []
            chunks.append((yield rs.read_bytes(1)))
            ws.close()
            chunks.append((yield rs.read_bytes(1)))
            self.assertEqual(chunks, [b"1", b"2"])
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer. Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        rs, ws = yield self.make_iostream_pair(read_chunk_size=256)
        try:
            ws.write(b"A" * 512)
            data = yield rs.read_bytes(256)
            self.assertEqual(b"A" * 256, data)
            ws.close()
            # Allow the close to propagate to the `rs` side of the
            # connection. A short sleep is used because a single
            # add_callback iteration (even several) doesn't seem to be
            # enough.
            yield gen.sleep(0.01)
            data = yield rs.read_bytes(256)
            self.assertEqual(b"A" * 256, data)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_close_after_close(self):
        # Similar to test_delayed_close_callback, but read_until_close takes
        # a separate code path so test it separately.
        rs, ws = yield self.make_iostream_pair()
        try:
            ws.write(b"1234")
            ws.close()
            # Read one byte to make sure the client has received the data.
            # It won't run the close callback as long as there is more buffered
            # data that could satisfy a later read.
            data = yield rs.read_bytes(1)
            self.assertEqual(data, b"1")
            data = yield rs.read_until_close()
            self.assertEqual(data, b"234")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_streaming_read_until_close_after_close(self):
        # Same as the preceding test but with a streaming_callback.
        # All data should go through the streaming callback,
        # and the final read callback just gets an empty string.
        rs, ws = yield self.make_iostream_pair()
        try:
            ws.write(b"1234")
            ws.close()
            data = yield rs.read_bytes(1)
            self.assertEqual(data, b"1")
            streaming_data = []
            final_future = Future()
            with ignore_deprecation():
                rs.read_until_close(final_future.set_result,
                                    streaming_callback=streaming_data.append)
            final_data = yield final_future
            self.assertEqual(b'', final_data)
            self.assertEqual(b''.join(streaming_data), b"234")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_large_read_until(self):
        # Performance test: read_until used to have a quadratic component
        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
        # seconds.
        rs, ws = yield self.make_iostream_pair()
        try:
            # This test fails on pypy with ssl. The suspected cause is that
            # pypy's gc moves objects, breaking the "frozen write buffer"
            # assumption.
            if (isinstance(rs, SSLIOStream) and
                    platform.python_implementation() == 'PyPy'):
                raise unittest.SkipTest(
                    "pypy gc causes problems with openssl")
            NUM_KB = 4096
            for i in range(NUM_KB):
                ws.write(b"A" * 1024)
            ws.write(b"\r\n")
            data = yield rs.read_until(b"\r\n")
            self.assertEqual(len(data), NUM_KB * 1024 + 2)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_close_callback_with_pending_read(self):
        # Regression test for a bug that was introduced in 2.3
        # where the IOStream._close_callback would never be called
        # if there were pending reads.
        OK = b"OK\r\n"
        rs, ws = yield self.make_iostream_pair()
        event = Event()
        rs.set_close_callback(event.set)
        try:
            ws.write(OK)
            res = yield rs.read_until(b"\r\n")
            self.assertEqual(res, OK)
            ws.close()
            rs.read_until(b"\r\n")
            # If the close callback (event.set) is not called, gen_test
            # fails with a timeout.
            yield event.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_future_close_callback(self):
        # Regression test for interaction between the Future read interfaces
        # and IOStream._maybe_add_error_listener.
        rs, ws = yield self.make_iostream_pair()
        closed = [False]
        cond = Condition()

        def close_callback():
            closed[0] = True
            cond.notify()
        rs.set_close_callback(close_callback)
        try:
            ws.write(b'a')
            res = yield rs.read_bytes(1)
            self.assertEqual(res, b'a')
            self.assertFalse(closed[0])
            ws.close()
            yield cond.wait()
            self.assertTrue(closed[0])
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_write_memoryview(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            fut = rs.read_bytes(4)
            ws.write(memoryview(b"hello"))
            data = yield fut
            self.assertEqual(data, b"hell")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_bytes_partial(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            # Ask for more than is available with partial=True
            fut = rs.read_bytes(50, partial=True)
            ws.write(b"hello")
            data = yield fut
            self.assertEqual(data, b"hello")
            # Ask for less than what is available; num_bytes is still
            # respected.
            fut = rs.read_bytes(3, partial=True)
            ws.write(b"world")
            data = yield fut
            self.assertEqual(data, b"wor")
            # Partial reads won't return an empty string, but read_bytes(0)
            # will.
            data = yield rs.read_bytes(0, partial=True)
            self.assertEqual(data, b'')
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Extra room under the limit
            fut = rs.read_until(b"def", max_bytes=50)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")
            # Just enough space
            fut = rs.read_until(b"def", max_bytes=6)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")
            # Not enough space: once the limit is exceeded, all we can do
            # is log a warning and close the connection.
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until(b"def", max_bytes=5)
                ws.write(b"123456")
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes_inline_legacy(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Similar to the error case in the previous test, but the
            # ws writes first so rs reads are satisfied
            # inline. For consistency with the out-of-line case, we
            # do not raise the error synchronously.
            ws.write(b"123456")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                with ignore_deprecation():
                    rs.read_until(b"def", callback=lambda x: self.fail(), max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes_inline(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Similar to the error case in the previous test, but the
            # ws writes first so rs reads are satisfied
            # inline. For consistency with the out-of-line case, we
            # do not raise the error synchronously.
            ws.write(b"123456")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                with self.assertRaises(StreamClosedError):
                    yield rs.read_until(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes_ignores_extra(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Even though matching data arrives in the same packet that
            # puts us over the limit, we fail the request because it was not
            # found within the limit.
            ws.write(b"abcdef")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Extra room under the limit
            fut = rs.read_until_regex(b"def", max_bytes=50)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")
            # Just enough space
            fut = rs.read_until_regex(b"def", max_bytes=6)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")
            # Not enough space: once the limit is exceeded, all we can do
            # is log a warning and close the connection.
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                ws.write(b"123456")
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes_inline(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Similar to the error case in the previous test, but the
            # ws writes first so rs reads are satisfied
            # inline. For consistency with the out-of-line case, we
            # do not raise the error synchronously.
            ws.write(b"123456")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes_ignores_extra(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Even though matching data arrives in the same packet that
            # puts us over the limit, we fail the request because it was not
            # found within the limit.
            ws.write(b"abcdef")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_small_reads_from_large_buffer(self):
        # 10KB buffer size, 100KB available to read.
        # Read 1KB at a time and make sure that the buffer is not eagerly
        # filled.
        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
        try:
            ws.write(b"a" * 1024 * 100)
            for i in range(100):
                data = yield rs.read_bytes(1024)
                self.assertEqual(data, b"a" * 1024)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_small_read_untils_from_large_buffer(self):
        # 10KB buffer size, 100KB available to read.
        # Read 1KB at a time and make sure that the buffer is not eagerly
        # filled.
        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
        try:
            ws.write((b"a" * 1023 + b"\n") * 100)
            for i in range(100):
                data = yield rs.read_until(b"\n", max_bytes=4096)
                self.assertEqual(data, b"a" * 1023 + b"\n")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_flow_control(self):
        MB = 1024 * 1024
        rs, ws = yield self.make_iostream_pair(max_buffer_size=5 * MB)
        try:
            # Client writes more than the rs will accept.
            ws.write(b"a" * 10 * MB)
            # The rs pauses while reading.
            yield rs.read_bytes(MB)
            yield gen.sleep(0.1)
            # The ws's writes have been blocked; the rs can
            # continue to read gradually.
            for i in range(9):
                yield rs.read_bytes(MB)
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_read_into(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            buf = bytearray(10)
            fut = rs.read_into(buf)
            ws.write(b"hello")
            yield gen.sleep(0.05)
            self.assertTrue(rs.reading())
            ws.write(b"world!!")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"helloworld")

            # Existing buffer is fed into user buffer
            fut = rs.read_into(buf)
            yield gen.sleep(0.05)
            self.assertTrue(rs.reading())
            ws.write(b"1234567890")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"!!12345678")

            # Existing buffer can satisfy read immediately
            buf = bytearray(4)
            ws.write(b"abcdefghi")
            data = yield rs.read_into(buf)
            self.assertEqual(data, 4)
            self.assertEqual(bytes(buf), b"90ab")

            data = yield rs.read_bytes(7)
            self.assertEqual(data, b"cdefghi")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_into_partial(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            # Partial read
            buf = bytearray(10)
            fut = rs.read_into(buf, partial=True)
            ws.write(b"hello")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 5)
            self.assertEqual(bytes(buf), b"hello\0\0\0\0\0")

            # Full read despite partial=True
            ws.write(b"world!1234567890")
            data = yield rs.read_into(buf, partial=True)
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"world!1234")

            # Existing buffer can satisfy read immediately
            data = yield rs.read_into(buf, partial=True)
            self.assertEqual(data, 6)
            self.assertEqual(bytes(buf), b"5678901234")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_into_zero_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            buf = bytearray()
            fut = rs.read_into(buf)
            self.assertEqual(fut.result(), 0)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_many_mixed_reads(self):
        # Stress buffer handling when going back and forth between
        # read_bytes() (using an internal buffer) and read_into()
        # (using a user-allocated buffer).
        r = random.Random(42)
        nbytes = 1000000
        rs, ws = yield self.make_iostream_pair()

        produce_hash = hashlib.sha1()
        consume_hash = hashlib.sha1()

        @gen.coroutine
        def produce():
            remaining = nbytes
            while remaining > 0:
                size = r.randint(1, min(1000, remaining))
                data = os.urandom(size)
                produce_hash.update(data)
                yield ws.write(data)
                remaining -= size
            self.assertEqual(remaining, 0)

        @gen.coroutine
        def consume():
            remaining = nbytes
            while remaining > 0:
                size = r.randint(1, min(1000, remaining))
                if r.random() > 0.5:
                    # read_bytes()
                    data = yield rs.read_bytes(size)
                    consume_hash.update(data)
                else:
                    # read_into()
                    buf = bytearray(size)
                    n = yield rs.read_into(buf)
                    self.assertEqual(n, size)
                    consume_hash.update(buf)
                remaining -= size
            self.assertEqual(remaining, 0)

        try:
            yield [produce(), consume()]
            # Use assertEqual rather than a bare assert so the check is
            # not stripped when Python runs with -O.
            self.assertEqual(produce_hash.hexdigest(),
                             consume_hash.hexdigest())
        finally:
            ws.close()
            rs.close()
class TestIOStreamMixin(TestReadWriteMixin):
    """Stream-pair tests over a real loopback TCP connection.

    Subclasses implement ``_make_server_iostream`` (wraps the accepted
    socket) and ``_make_client_iostream`` (wraps the connecting socket).
    """

    def _make_server_iostream(self, connection, **kwargs):
        raise NotImplementedError()

    def _make_client_iostream(self, connection, **kwargs):
        raise NotImplementedError()

    @gen.coroutine
    def make_iostream_pair(self, **kwargs):
        """Connect a client stream to a freshly accepted server stream.

        ``kwargs`` are passed to both stream factories. Resolves to
        ``(server_stream, client_stream)``. The temporary listening socket
        is unregistered and closed in a ``finally`` clause, so it does not
        leak if creating or connecting the client stream raises.
        """
        listener, port = bind_unused_port()
        server_stream_fut = Future()

        def accept_callback(connection, address):
            server_stream_fut.set_result(
                self._make_server_iostream(connection, **kwargs))
        netutil.add_accept_handler(listener, accept_callback)
        try:
            client_stream = self._make_client_iostream(socket.socket(), **kwargs)
            connect_fut = client_stream.connect(('127.0.0.1', port))
            server_stream, client_stream = yield [server_stream_fut, connect_fut]
        finally:
            self.io_loop.remove_handler(listener.fileno())
            listener.close()
        raise gen.Return((server_stream, client_stream))

    def _assert_connection_refused(self, stream):
        """Assert that ``stream.error`` is a connection-refused socket error."""
        self.assertIsInstance(stream.error, socket.error)
        # cygwin's errnos don't match those used on native windows python,
        # so the errno check is skipped there.
        if sys.platform != 'cygwin':
            refused_errnos = (errno.ECONNREFUSED,)
            if hasattr(errno, "WSAECONNREFUSED"):
                refused_errnos += (errno.WSAECONNREFUSED,)
            self.assertIn(stream.error.args[0], refused_errnos)

    def test_connection_refused_legacy(self):
        # When a connection is refused, the connect callback should not
        # be run. (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        cleanup_func, port = refusing_port()
        self.addCleanup(cleanup_func)
        stream = IOStream(socket.socket())
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
            self.stop()
        stream.set_close_callback(self.stop)
        # log messages vary by platform and ioloop implementation
        with ExpectLog(gen_log, ".*", required=False):
            with ignore_deprecation():
                stream.connect(("127.0.0.1", port), connect_callback)
            self.wait()
        self.assertFalse(self.connect_called)
        self._assert_connection_refused(stream)

    @gen_test
    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run. (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        cleanup_func, port = refusing_port()
        self.addCleanup(cleanup_func)
        stream = IOStream(socket.socket())

        stream.set_close_callback(self.stop)
        # log messages vary by platform and ioloop implementation
        with ExpectLog(gen_log, ".*", required=False):
            with self.assertRaises(StreamClosedError):
                yield stream.connect(("127.0.0.1", port))
        self._assert_connection_refused(stream)

    @unittest.skipIf(mock is None, 'mock package not present')
    @gen_test
    def test_gaierror(self):
        # Test that IOStream sets its exc_info on getaddrinfo error.
        # It's difficult to reliably trigger a getaddrinfo error;
        # some resolvers won't even return errors for malformed names,
        # so we mock it instead. If IOStream changes to call a Resolver
        # before sock.connect, the mock target will need to change too.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(s)
        stream.set_close_callback(self.stop)
        with mock.patch('socket.socket.connect',
                        side_effect=socket.gaierror(errno.EIO, 'boom')):
            with self.assertRaises(StreamClosedError):
                yield stream.connect(('localhost', 80))
            self.assertTrue(isinstance(stream.error, socket.gaierror))

    @gen_test
    def test_read_callback_error(self):
        # Test that IOStream sets its exc_info when a read callback throws
        server, client = yield self.make_iostream_pair()
        try:
            closed = Event()
            server.set_close_callback(closed.set)
            with ExpectLog(
                app_log, "(Uncaught exception|Exception in callback)"
            ):
                # Clear ExceptionStackContext so IOStream catches error
                with NullContext():
                    with ignore_deprecation():
                        server.read_bytes(1, callback=lambda data: 1 / 0)
                client.write(b"1")
                yield closed.wait()
            self.assertTrue(isinstance(server.error, ZeroDivisionError))
        finally:
            server.close()
            client.close()

    @unittest.skipIf(mock is None, 'mock package not present')
    @gen_test
    def test_read_until_close_with_error(self):
        server, client = yield self.make_iostream_pair()
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
        # prefer assertRaisesRegex and fall back for Python 2.7.
        assert_raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                               self.assertRaisesRegexp)
        try:
            with mock.patch('tornado.iostream.BaseIOStream._try_inline_read',
                            side_effect=IOError('boom')):
                with assert_raises_regex(IOError, 'boom'):
                    client.read_until_close()
        finally:
            server.close()
            client.close()

    @skipIfNonUnix
    @skipPypy3V58
    @gen_test
    def test_inline_read_error(self):
        # An error on an inline read is raised without logging (on the
        # assumption that it will eventually be noticed or logged further
        # up the stack).
        #
        # This test is posix-only because windows os.close() doesn't work
        # on socket FDs, but we can't close the socket object normally
        # because we won't get the error we want if the socket knows
        # it's closed.
        server, client = yield self.make_iostream_pair()
        try:
            os.close(server.socket.fileno())
            with self.assertRaises(socket.error):
                server.read_bytes(1)
        finally:
            server.close()
            client.close()

    @skipPypy3V58
    @gen_test
    def test_async_read_error_logging(self):
        # Socket errors on asynchronous reads should be logged (but only
        # once).
        server, client = yield self.make_iostream_pair()
        closed = Event()
        server.set_close_callback(closed.set)
        try:
            # Start a read that will be fulfilled asynchronously.
            server.read_bytes(1)
            client.write(b'a')
            # Stub out read_from_fd to make it fail.
            # NOTE(review): if IOStream calls read_from_fd with a buffer
            # argument (the read_from_fd(buf) interface), this zero-argument
            # stub fails with a TypeError before its body runs. The test
            # would then pass only because _handle_read logs any exception
            # as "error on read", not because of a real socket error.
            # TODO confirm the intended failure mode.

            def fake_read_from_fd():
                os.close(server.socket.fileno())
                server.__class__.read_from_fd(server)
            server.read_from_fd = fake_read_from_fd
            # This log message is from _handle_read (not read_from_fd).
            with ExpectLog(gen_log, "error on read"):
                yield closed.wait()
        finally:
            server.close()
            client.close()

    @gen_test
    def test_future_write(self):
        """
        Test that write() Futures are never orphaned.
        """
        # Run concurrent writers that will write enough bytes so as to
        # clog the socket buffer and accumulate bytes in our write buffer.
        m, n = 5000, 1000
        nproducers = 10
        total_bytes = m * n * nproducers
        server, client = yield self.make_iostream_pair(max_buffer_size=total_bytes)

        @gen.coroutine
        def produce():
            data = b'x' * m
            for i in range(n):
                yield server.write(data)

        @gen.coroutine
        def consume():
            nread = 0
            while nread < total_bytes:
                res = yield client.read_bytes(m)
                nread += len(res)

        try:
            yield [produce() for i in range(nproducers)] + [consume()]
        finally:
            server.close()
            client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
    """Run the web mixin tests over plain-text HTTP."""

    def _make_client_iostream(self):
        raw_sock = socket.socket()
        return IOStream(raw_sock)
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
    """Run the web mixin tests over HTTPS with certificate checks disabled."""

    def _make_client_iostream(self):
        no_verify = {'cert_reqs': ssl.CERT_NONE}
        raw_sock = socket.socket()
        return SSLIOStream(raw_sock, ssl_options=no_verify)
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
    """Run the stream-pair tests with plain IOStreams on both ends."""

    def _make_server_iostream(self, connection, **kwargs):
        # The accepted socket is used as-is; no TLS wrapping.
        server_stream = IOStream(connection, **kwargs)
        return server_stream

    def _make_client_iostream(self, connection, **kwargs):
        # The connecting socket is likewise used without wrapping.
        client_stream = IOStream(connection, **kwargs)
        return client_stream
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
    """Run the stream-pair tests over SSL, configured with ssl_options dicts."""

    def _make_server_iostream(self, connection, **kwargs):
        # Wrap with tornado's ssl_wrap_socket, which accepts the options
        # dict directly, instead of the module-level ssl.wrap_socket.
        # ssl.wrap_socket is deprecated since Python 3.7 and removed in 3.12.
        # The handshake is deferred so SSLIOStream can drive it.
        connection = ssl_wrap_socket(connection, _server_ssl_options(),
                                     server_side=True,
                                     do_handshake_on_connect=False)
        return SSLIOStream(connection, **kwargs)

    def _make_client_iostream(self, connection, **kwargs):
        # The client does not verify the test server's certificate.
        return SSLIOStream(connection,
                           ssl_options=dict(cert_reqs=ssl.CERT_NONE),
                           **kwargs)
# This will run some tests that are basically redundant but it's the
# simplest way to make sure that it works to pass an SSLContext
# instead of an ssl_options dict to the SSLIOStream constructor.
class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
    """Stream-pair tests over SSL, configured with ``ssl.SSLContext`` objects."""

    def _make_server_iostream(self, connection, **kwargs):
        test_dir = os.path.dirname(__file__)
        server_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        server_ctx.load_cert_chain(os.path.join(test_dir, 'test.crt'),
                                   os.path.join(test_dir, 'test.key'))
        wrapped = ssl_wrap_socket(connection, server_ctx,
                                  server_side=True,
                                  do_handshake_on_connect=False)
        return SSLIOStream(wrapped, **kwargs)

    def _make_client_iostream(self, connection, **kwargs):
        client_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        return SSLIOStream(connection, ssl_options=client_ctx, **kwargs)
class TestIOStreamStartTLS(AsyncTestCase):
def setUp(self):
    """Connect a plain client IOStream to a freshly accepted server IOStream.

    Afterwards ``self.client_stream`` and ``self.server_stream`` are the two
    ends of one loopback connection. Exceptions propagate unchanged; the
    former print-and-re-raise wrapper only duplicated the traceback that
    unittest already reports for setUp failures.
    """
    super(TestIOStreamStartTLS, self).setUp()
    self.listener, self.port = bind_unused_port()
    self.server_stream = None
    self.server_accepted = Future()
    netutil.add_accept_handler(self.listener, self.accept)
    self.client_stream = IOStream(socket.socket())
    # Run the loop until the client's connect() future resolves...
    self.io_loop.add_future(self.client_stream.connect(
        ('127.0.0.1', self.port)), self.stop)
    self.wait()
    # ...and then until accept() has created the server stream.
    self.io_loop.add_future(self.server_accepted, self.stop)
    self.wait()
def tearDown(self):
    """Close any streams still owned by the test, then the listener."""
    for owned_stream in (self.server_stream, self.client_stream):
        if owned_stream is not None:
            owned_stream.close()
    self.listener.close()
    super(TestIOStreamStartTLS, self).tearDown()
def accept(self, connection, address):
    """Accept callback: wrap the single expected connection in an IOStream."""
    if self.server_stream is None:
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)
        return
    self.fail("should only get one connection")
@gen.coroutine
def client_send_line(self, line):
    """Send ``line`` client->server and check that it arrives unchanged."""
    sender, receiver = self.client_stream, self.server_stream
    sender.write(line)
    received = yield receiver.read_until(b"\r\n")
    self.assertEqual(line, received)
@gen.coroutine
def server_send_line(self, line):
    """Send ``line`` server->client and check that it arrives unchanged."""
    sender, receiver = self.server_stream, self.client_stream
    sender.write(line)
    received = yield receiver.read_until(b"\r\n")
    self.assertEqual(line, received)
def client_start_tls(self, ssl_options=None, server_hostname=None):
    """Call start_tls on the client stream and return the resulting Future.

    ``self.client_stream`` is cleared first, presumably so tearDown does not
    close the stream that start_tls has taken over.
    """
    stream, self.client_stream = self.client_stream, None
    return stream.start_tls(False, ssl_options, server_hostname)
def server_start_tls(self, ssl_options=None):
    """Call start_tls on the server stream and return the resulting Future.

    ``self.server_stream`` is cleared first, presumably so tearDown does not
    close the stream that start_tls has taken over.
    """
    stream, self.server_stream = self.server_stream, None
    return stream.start_tls(True, ssl_options)
@gen_test
def test_start_tls_smtp(self):
    # This flow is simplified from RFC 3207 section 5. It is more than
    # strictly needed, but realistic back-and-forth traffic helps make sure
    # the buffers end up in a sane state.
    server = self.server_send_line
    client = self.client_send_line
    plaintext_exchange = (
        (server, b"220 mail.example.com ready\r\n"),
        (client, b"EHLO mail.example.com\r\n"),
        (server, b"250-mail.example.com welcome\r\n"),
        (server, b"250 STARTTLS\r\n"),
        (client, b"STARTTLS\r\n"),
        (server, b"220 Go ahead\r\n"),
    )
    for send, line in plaintext_exchange:
        yield send(line)
    # Upgrade both ends; the client skips certificate verification.
    client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
    server_future = self.server_start_tls(_server_ssl_options())
    self.client_stream = yield client_future
    self.server_stream = yield server_future
    self.assertTrue(isinstance(self.client_stream, SSLIOStream))
    self.assertTrue(isinstance(self.server_stream, SSLIOStream))
    # Traffic keeps flowing over the upgraded streams.
    yield client(b"EHLO mail.example.com\r\n")
    yield server(b"250 mail.example.com welcome\r\n")
@gen_test
def test_handshake_fail(self):
server_future = self.server_start_tls(_server_ssl_options())
# Certificates are verified with the default configuration.
client_future = self.client_start_tls(server_hostname="localhost")
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
with self.assertRaises((ssl.SSLError, socket.error)):
yield server_future
@gen_test
def test_check_hostname(self):
# Test that server_hostname parameter to start_tls is being used.
# The check_hostname functionality is only available in python 2.7 and
# up and in python 3.4 and up.
server_future = self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
ssl.create_default_context(),
server_hostname='127.0.0.1')
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
# The client fails to connect with an SSL error.
yield client_future
with self.assertRaises(Exception):
# The server fails to connect, but the exact error is unspecified.
yield server_future
class WaitForHandshakeTest(AsyncTestCase):
    """Tests for ``wait_for_handshake`` on server-side SSL streams."""

    @gen.coroutine
    def connect_to_server(self, server_cls):
        """Start ``server_cls`` on an unused port and connect an SSL client.

        The client is created with ``cert_reqs=ssl.CERT_NONE``. After the
        connection completes, asserts that the client has a negotiated
        cipher. The server is stopped and the client closed on every exit
        path.
        """
        tcp_server = None
        ssl_client = None
        try:
            listen_sock, listen_port = bind_unused_port()
            tcp_server = server_cls(ssl_options=_server_ssl_options())
            tcp_server.add_socket(listen_sock)
            client_opts = dict(cert_reqs=ssl.CERT_NONE)
            ssl_client = SSLIOStream(socket.socket(), ssl_options=client_opts)
            yield ssl_client.connect(('127.0.0.1', listen_port))
            self.assertIsNotNone(ssl_client.socket.cipher())
        finally:
            if tcp_server is not None:
                tcp_server.stop()
            if ssl_client is not None:
                ssl_client.close()

    @gen_test
    def test_wait_for_handshake_callback(self):
        outer = self
        done = Future()

        class CallbackServer(TCPServer):
            def handle_stream(self, stream, address):
                # No cipher yet: the handshake is still pending here.
                outer.assertIsNone(stream.socket.cipher())
                self.stream = stream
                with ignore_deprecation():
                    stream.wait_for_handshake(self.handshake_done)

            def handshake_done(self):
                # Once the callback fires, SSL details must be available.
                outer.assertIsNotNone(self.stream.socket.cipher())
                done.set_result(None)

        yield self.connect_to_server(CallbackServer)
        yield done

    @gen_test
    def test_wait_for_handshake_future(self):
        outer = self
        done = Future()

        class FutureServer(TCPServer):
            def handle_stream(self, stream, address):
                outer.assertIsNone(stream.socket.cipher())
                outer.io_loop.spawn_callback(self.handle_connection, stream)

            @gen.coroutine
            def handle_connection(self, stream):
                yield stream.wait_for_handshake()
                done.set_result(None)

        yield self.connect_to_server(FutureServer)
        yield done

    @gen_test
    def test_wait_for_handshake_already_waiting_error(self):
        outer = self
        done = Future()

        class DoubleWaitServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                pending = stream.wait_for_handshake()
                # A second wait while the first is outstanding must fail.
                with outer.assertRaises(RuntimeError):
                    stream.wait_for_handshake()
                yield pending
                done.set_result(None)

        yield self.connect_to_server(DoubleWaitServer)
        yield done

    @gen_test
    def test_wait_for_handshake_already_connected(self):
        done = Future()

        class RepeatWaitServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                # Waiting again after the handshake finished must succeed.
                for _ in range(2):
                    yield stream.wait_for_handshake()
                done.set_result(None)

        yield self.connect_to_server(RepeatWaitServer)
        yield done
@skipIfNonUnix
class TestPipeIOStream(TestReadWriteMixin, AsyncTestCase):
    """Tests for ``PipeIOStream`` over the two ends of an ``os.pipe()``.

    ``make_iostream_pair`` is the stream factory hook, presumably consumed
    by the shared ``TestReadWriteMixin`` tests.
    """

    @gen.coroutine
    def make_iostream_pair(self, **kwargs):
        """Return ``(read_stream, write_stream)`` for a freshly created pipe.

        ``kwargs`` are forwarded to both ``PipeIOStream`` constructors.
        """
        read_fd, write_fd = os.pipe()
        read_stream = PipeIOStream(read_fd, **kwargs)
        write_stream = PipeIOStream(write_fd, **kwargs)
        return read_stream, write_stream

    @gen_test
    def test_pipe_iostream(self):
        reader, writer = yield self.make_iostream_pair()

        # Two separate writes are read back as one contiguous byte stream.
        for piece in (b"hel", b"lo world"):
            writer.write(piece)

        first = yield reader.read_until(b' ')
        self.assertEqual(first, b"hello ")
        middle = yield reader.read_bytes(3)
        self.assertEqual(middle, b"wor")

        # Closing the write end lets read_until_close return the remainder.
        writer.close()
        rest = yield reader.read_until_close()
        self.assertEqual(rest, b"ld")
        reader.close()

    @gen_test
    def test_pipe_iostream_big_write(self):
        reader, writer = yield self.make_iostream_pair()
        payload_size = 1 << 20  # 1MB of data, which should fill the buffer
        writer.write(b"1" * payload_size)
        received = yield reader.read_bytes(payload_size)
        self.assertEqual(received, b"1" * payload_size)
        writer.close()
        reader.close()
class TestStreamBuffer(unittest.TestCase):
    """
    Unit tests for the private _StreamBuffer class.
    """

    def setUp(self):
        # Seeded RNG so the advance() step sizes are identical on every run.
        self.random = random.Random(42)

    def to_bytes(self, b):
        """Convert a bytes-like ``b`` to ``bytes``; raise TypeError otherwise."""
        if isinstance(b, memoryview):
            return b.tobytes()  # For py2
        if isinstance(b, (bytes, bytearray)):
            return bytes(b)
        raise TypeError(b)

    def make_streambuffer(self, large_buf_threshold=10):
        """Build a ``_StreamBuffer`` whose large-buffer threshold is
        overridden with ``large_buf_threshold``."""
        stream_buf = _StreamBuffer()
        # Fails loudly if the private attribute is missing or falsy.
        assert stream_buf._large_buf_threshold
        stream_buf._large_buf_threshold = large_buf_threshold
        return stream_buf

    def check_peek(self, buf, expected):
        """Peek at geometrically growing sizes below ``2 * len(expected)``.

        Each peek must be non-empty, no longer than the requested size, and
        a prefix of ``expected``.
        """
        limit = 2 * len(expected)
        size = 1
        while size < limit:
            peeked = self.to_bytes(buf.peek(size))
            self.assertTrue(peeked)  # Not empty
            self.assertLessEqual(len(peeked), size)
            self.assertTrue(expected.startswith(peeked), (expected, peeked))
            size = (3 * size + 1) // 2

    def check_append_all_then_skip_all(self, buf, objs, input_type):
        """Append every chunk of ``objs`` (converted with ``input_type``),
        then advance by random amounts until empty, checking ``len()`` and
        ``peek()`` against a reference ``bytes`` value after each step."""
        self.assertEqual(len(buf), 0)

        expected = b''
        for chunk in objs:
            expected += chunk
            buf.append(input_type(chunk))
            self.assertEqual(len(buf), len(expected))
            self.check_peek(buf, expected)

        while expected:
            step = self.random.randrange(1, len(expected) + 1)
            expected = expected[step:]
            buf.advance(step)
            self.assertEqual(len(buf), len(expected))
            self.check_peek(buf, expected)

        self.assertEqual(len(buf), 0)

    def test_small(self):
        objs = [b'12', b'345', b'67', b'89a', b'bcde', b'fgh', b'ijklmn']
        for input_type in (bytes, bytearray, memoryview):
            self.check_append_all_then_skip_all(
                self.make_streambuffer(), objs, input_type)

        # Test internal algorithm (threshold of 10 bytes)
        buf = self.make_streambuffer(10)
        for _ in range(9):
            buf.append(b'x')
        # Nine 1-byte appends are expected to share one internal buffer.
        self.assertEqual(len(buf._buffers), 1)
        for _ in range(9):
            buf.append(b'x')
        # Eighteen bytes in total are expected to span two internal buffers.
        self.assertEqual(len(buf._buffers), 2)
        buf.advance(10)
        self.assertEqual(len(buf._buffers), 1)
        buf.advance(8)
        self.assertEqual(len(buf._buffers), 0)
        self.assertEqual(len(buf), 0)

    def test_large(self):
        objs = [
            b'12' * 5,
            b'345' * 2,
            b'67' * 20,
            b'89a' * 12,
            b'bcde' * 1,
            b'fgh' * 7,
            b'ijklmn' * 2,
        ]
        for input_type in (bytes, bytearray, memoryview):
            self.check_append_all_then_skip_all(
                self.make_streambuffer(), objs, input_type)

        # Test internal algorithm (threshold of 10 bytes)
        buf = self.make_streambuffer(10)
        for _ in range(3):
            buf.append(b'x' * 11)
        # Each 11-byte chunk exceeds the threshold; one buffer per chunk.
        self.assertEqual(len(buf._buffers), 3)
        buf.append(b'y')
        self.assertEqual(len(buf._buffers), 4)
        # A second small append is expected to join the previous small one.
        buf.append(b'z')
        self.assertEqual(len(buf._buffers), 4)
        buf.advance(33)
        self.assertEqual(len(buf._buffers), 1)
        buf.advance(2)
        self.assertEqual(len(buf._buffers), 0)
        self.assertEqual(len(buf), 0)