Viewing file: test_protocol.py (18.09 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details.
""" Tests for L{twisted.internet.protocol}. """
from io import BytesIO
from zope.interface import implementer from zope.interface.verify import verifyObject
from twisted.internet.defer import CancelledError from twisted.internet.interfaces import ( IConsumer, ILoggingContext, IProtocol, IProtocolFactory, ) from twisted.internet.protocol import ( ClientCreator, ConsumerToProtocolAdapter, Factory, FileWrapper, Protocol, ProtocolToConsumerAdapter, ) from twisted.logger import LogLevel, globalLogPublisher from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock, StringTransport from twisted.trial.unittest import TestCase
class ClientCreatorTests(TestCase): """ Tests for L{twisted.internet.protocol.ClientCreator}. """
def _basicConnectTest(self, check): """ Helper for implementing a test to verify that one of the I{connect} methods of L{ClientCreator} passes the right arguments to the right reactor method.
@param check: A function which will be invoked with a reactor and a L{ClientCreator} instance and which should call one of the L{ClientCreator}'s I{connect} methods and assert that all of its arguments except for the factory are passed on as expected to the reactor. The factory should be returned. """
class SomeProtocol(Protocol): pass
reactor = MemoryReactorClock() cc = ClientCreator(reactor, SomeProtocol) factory = check(reactor, cc) protocol = factory.buildProtocol(None) self.assertIsInstance(protocol, SomeProtocol)
def test_connectTCP(self): """ L{ClientCreator.connectTCP} calls C{reactor.connectTCP} with the host and port information passed to it, and with a factory which will construct the protocol passed to L{ClientCreator.__init__}. """
def check(reactor, cc): cc.connectTCP("example.com", 1234, 4321, ("1.2.3.4", 9876)) host, port, factory, timeout, bindAddress = reactor.tcpClients.pop() self.assertEqual(host, "example.com") self.assertEqual(port, 1234) self.assertEqual(timeout, 4321) self.assertEqual(bindAddress, ("1.2.3.4", 9876)) return factory
self._basicConnectTest(check)
def test_connectUNIX(self): """ L{ClientCreator.connectUNIX} calls C{reactor.connectUNIX} with the filename passed to it, and with a factory which will construct the protocol passed to L{ClientCreator.__init__}. """
def check(reactor, cc): cc.connectUNIX("/foo/bar", 123, True) address, factory, timeout, checkPID = reactor.unixClients.pop() self.assertEqual(address, "/foo/bar") self.assertEqual(timeout, 123) self.assertTrue(checkPID) return factory
self._basicConnectTest(check)
def test_connectSSL(self): """ L{ClientCreator.connectSSL} calls C{reactor.connectSSL} with the host, port, and context factory passed to it, and with a factory which will construct the protocol passed to L{ClientCreator.__init__}. """
def check(reactor, cc): expectedContextFactory = object() cc.connectSSL( "example.com", 1234, expectedContextFactory, 4321, ("4.3.2.1", 5678) ) ( host, port, factory, contextFactory, timeout, bindAddress, ) = reactor.sslClients.pop() self.assertEqual(host, "example.com") self.assertEqual(port, 1234) self.assertIs(contextFactory, expectedContextFactory) self.assertEqual(timeout, 4321) self.assertEqual(bindAddress, ("4.3.2.1", 5678)) return factory
self._basicConnectTest(check)
def _cancelConnectTest(self, connect): """ Helper for implementing a test to verify that cancellation of the L{Deferred} returned by one of L{ClientCreator}'s I{connect} methods is implemented to cancel the underlying connector.
@param connect: A function which will be invoked with a L{ClientCreator} instance as an argument and which should call one its I{connect} methods and return the result.
@return: A L{Deferred} which fires when the test is complete or fails if there is a problem. """ reactor = MemoryReactorClock() cc = ClientCreator(reactor, Protocol) d = connect(cc) connector = reactor.connectors.pop() self.assertFalse(connector._disconnected) d.cancel() self.assertTrue(connector._disconnected) return self.assertFailure(d, CancelledError)
def test_cancelConnectTCP(self): """ The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled to abort the connection attempt before it completes. """
def connect(cc): return cc.connectTCP("example.com", 1234)
return self._cancelConnectTest(connect)
def test_cancelConnectUNIX(self): """ The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled to abort the connection attempt before it completes. """
def connect(cc): return cc.connectUNIX("/foo/bar")
return self._cancelConnectTest(connect)
def test_cancelConnectSSL(self): """ The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled to abort the connection attempt before it completes. """
def connect(cc): return cc.connectSSL("example.com", 1234, object())
return self._cancelConnectTest(connect)
def _cancelConnectTimeoutTest(self, connect): """ Like L{_cancelConnectTest}, but for the case where the L{Deferred} is cancelled after the connection is set up but before it is fired with the resulting protocol instance. """ reactor = MemoryReactorClock() cc = ClientCreator(reactor, Protocol) d = connect(reactor, cc) connector = reactor.connectors.pop() # Sanity check - there is an outstanding delayed call to fire the # Deferred. self.assertEqual(len(reactor.getDelayedCalls()), 1)
# Cancel the Deferred, disconnecting the transport just set up and # cancelling the delayed call. d.cancel()
self.assertEqual(reactor.getDelayedCalls(), [])
# A real connector implementation is responsible for disconnecting the # transport as well. For our purposes, just check that someone told the # connector to disconnect. self.assertTrue(connector._disconnected)
return self.assertFailure(d, CancelledError)
def test_cancelConnectTCPTimeout(self): """ L{ClientCreator.connectTCP} inserts a very short delayed call between the time the connection is established and the time the L{Deferred} returned from one of its connect methods actually fires. If the L{Deferred} is cancelled in this interval, the established connection is closed, the timeout is cancelled, and the L{Deferred} fails with L{CancelledError}. """
def connect(reactor, cc): d = cc.connectTCP("example.com", 1234) host, port, factory, timeout, bindAddress = reactor.tcpClients.pop() protocol = factory.buildProtocol(None) transport = StringTransport() protocol.makeConnection(transport) return d
return self._cancelConnectTimeoutTest(connect)
def test_cancelConnectUNIXTimeout(self): """ L{ClientCreator.connectUNIX} inserts a very short delayed call between the time the connection is established and the time the L{Deferred} returned from one of its connect methods actually fires. If the L{Deferred} is cancelled in this interval, the established connection is closed, the timeout is cancelled, and the L{Deferred} fails with L{CancelledError}. """
def connect(reactor, cc): d = cc.connectUNIX("/foo/bar") address, factory, timeout, bindAddress = reactor.unixClients.pop() protocol = factory.buildProtocol(None) transport = StringTransport() protocol.makeConnection(transport) return d
return self._cancelConnectTimeoutTest(connect)
def test_cancelConnectSSLTimeout(self): """ L{ClientCreator.connectSSL} inserts a very short delayed call between the time the connection is established and the time the L{Deferred} returned from one of its connect methods actually fires. If the L{Deferred} is cancelled in this interval, the established connection is closed, the timeout is cancelled, and the L{Deferred} fails with L{CancelledError}. """
def connect(reactor, cc): d = cc.connectSSL("example.com", 1234, object()) ( host, port, factory, contextFactory, timeout, bindADdress, ) = reactor.sslClients.pop() protocol = factory.buildProtocol(None) transport = StringTransport() protocol.makeConnection(transport) return d
return self._cancelConnectTimeoutTest(connect)
def _cancelConnectFailedTimeoutTest(self, connect): """ Like L{_cancelConnectTest}, but for the case where the L{Deferred} is cancelled after the connection attempt has failed but before it is fired with the resulting failure. """ reactor = MemoryReactorClock() cc = ClientCreator(reactor, Protocol) d, factory = connect(reactor, cc) connector = reactor.connectors.pop() factory.clientConnectionFailed( connector, Failure(Exception("Simulated failure")) )
# Sanity check - there is an outstanding delayed call to fire the # Deferred. self.assertEqual(len(reactor.getDelayedCalls()), 1)
# Cancel the Deferred, cancelling the delayed call. d.cancel()
self.assertEqual(reactor.getDelayedCalls(), [])
return self.assertFailure(d, CancelledError)
def test_cancelConnectTCPFailedTimeout(self): """ Similar to L{test_cancelConnectTCPTimeout}, but for the case where the connection attempt fails. """
def connect(reactor, cc): d = cc.connectTCP("example.com", 1234) host, port, factory, timeout, bindAddress = reactor.tcpClients.pop() return d, factory
return self._cancelConnectFailedTimeoutTest(connect)
def test_cancelConnectUNIXFailedTimeout(self): """ Similar to L{test_cancelConnectUNIXTimeout}, but for the case where the connection attempt fails. """
def connect(reactor, cc): d = cc.connectUNIX("/foo/bar") address, factory, timeout, bindAddress = reactor.unixClients.pop() return d, factory
return self._cancelConnectFailedTimeoutTest(connect)
def test_cancelConnectSSLFailedTimeout(self): """ Similar to L{test_cancelConnectSSLTimeout}, but for the case where the connection attempt fails. """
def connect(reactor, cc): d = cc.connectSSL("example.com", 1234, object()) ( host, port, factory, contextFactory, timeout, bindADdress, ) = reactor.sslClients.pop() return d, factory
return self._cancelConnectFailedTimeoutTest(connect)
class ProtocolTests(TestCase): """ Tests for L{twisted.internet.protocol.Protocol}. """
def test_interfaces(self): """ L{Protocol} instances provide L{IProtocol} and L{ILoggingContext}. """ proto = Protocol() self.assertTrue(verifyObject(IProtocol, proto)) self.assertTrue(verifyObject(ILoggingContext, proto))
def test_logPrefix(self): """ L{Protocol.logPrefix} returns the protocol class's name. """
class SomeThing(Protocol): pass
self.assertEqual("SomeThing", SomeThing().logPrefix())
def test_makeConnection(self): """ L{Protocol.makeConnection} sets the given transport on itself, and then calls C{connectionMade}. """ result = []
class SomeProtocol(Protocol): def connectionMade(self): result.append(self.transport)
transport = object() protocol = SomeProtocol() protocol.makeConnection(transport) self.assertEqual(result, [transport])
class FactoryTests(TestCase): """ Tests for L{protocol.Factory}. """
def test_interfaces(self): """ L{Factory} instances provide both L{IProtocolFactory} and L{ILoggingContext}. """ factory = Factory() self.assertTrue(verifyObject(IProtocolFactory, factory)) self.assertTrue(verifyObject(ILoggingContext, factory))
def test_logPrefix(self): """ L{Factory.logPrefix} returns the name of the factory class. """
class SomeKindOfFactory(Factory): pass
self.assertEqual("SomeKindOfFactory", SomeKindOfFactory().logPrefix())
def test_defaultBuildProtocol(self): """ L{Factory.buildProtocol} by default constructs a protocol by calling its C{protocol} attribute, and attaches the factory to the result. """
class SomeProtocol(Protocol): pass
f = Factory() f.protocol = SomeProtocol protocol = f.buildProtocol(None) self.assertIsInstance(protocol, SomeProtocol) self.assertIs(protocol.factory, f)
def test_forProtocol(self): """ L{Factory.forProtocol} constructs a Factory, passing along any additional arguments, and sets its C{protocol} attribute to the given Protocol subclass. """
class ArgTakingFactory(Factory): def __init__(self, *args, **kwargs): self.args, self.kwargs = args, kwargs
factory = ArgTakingFactory.forProtocol(Protocol, 1, 2, foo=12) self.assertEqual(factory.protocol, Protocol) self.assertEqual(factory.args, (1, 2)) self.assertEqual(factory.kwargs, {"foo": 12})
def test_doStartLoggingStatement(self): """ L{Factory.doStart} logs that it is starting a factory, followed by the L{repr} of the L{Factory} instance that is being started. """ events = [] globalLogPublisher.addObserver(events.append) self.addCleanup(lambda: globalLogPublisher.removeObserver(events.append))
f = Factory() f.doStart()
self.assertIs(events[0]["factory"], f) self.assertEqual(events[0]["log_level"], LogLevel.info) self.assertEqual(events[0]["log_format"], "Starting factory {factory!r}")
def test_doStopLoggingStatement(self): """ L{Factory.doStop} logs that it is stopping a factory, followed by the L{repr} of the L{Factory} instance that is being stopped. """ events = [] globalLogPublisher.addObserver(events.append) self.addCleanup(lambda: globalLogPublisher.removeObserver(events.append))
class MyFactory(Factory): numPorts = 1
f = MyFactory() f.doStop()
self.assertIs(events[0]["factory"], f) self.assertEqual(events[0]["log_level"], LogLevel.info) self.assertEqual(events[0]["log_format"], "Stopping factory {factory!r}")
class AdapterTests(TestCase): """ Tests for L{ProtocolToConsumerAdapter} and L{ConsumerToProtocolAdapter}. """
def test_protocolToConsumer(self): """ L{IProtocol} providers can be adapted to L{IConsumer} providers using L{ProtocolToConsumerAdapter}. """ result = [] p = Protocol() p.dataReceived = result.append consumer = IConsumer(p) consumer.write(b"hello") self.assertEqual(result, [b"hello"]) self.assertIsInstance(consumer, ProtocolToConsumerAdapter)
def test_consumerToProtocol(self): """ L{IConsumer} providers can be adapted to L{IProtocol} providers using L{ProtocolToConsumerAdapter}. """ result = []
@implementer(IConsumer) class Consumer: def write(self, d): result.append(d)
c = Consumer() protocol = IProtocol(c) protocol.dataReceived(b"hello") self.assertEqual(result, [b"hello"]) self.assertIsInstance(protocol, ConsumerToProtocolAdapter)
class FileWrapperTests(TestCase): """ L{twisted.internet.protocol.FileWrapper} """
def test_write(self): """ L{twisted.internet.protocol.FileWrapper.write} """ wrapper = FileWrapper(BytesIO()) wrapper.write(b"test1") self.assertEqual(wrapper.file.getvalue(), b"test1")
wrapper = FileWrapper(BytesIO()) # BytesIO() cannot accept unicode, so this will # cause an exception to be thrown which will be # handled by FileWrapper.handle_exception(). wrapper.write("stuff") self.assertNotEqual(wrapper.file.getvalue(), "stuff")
def test_writeSequence(self): """ L{twisted.internet.protocol.FileWrapper.writeSequence} """ wrapper = FileWrapper(BytesIO()) wrapper.writeSequence([b"test1", b"test2"]) self.assertEqual(wrapper.file.getvalue(), b"test1test2")
wrapper = FileWrapper(BytesIO()) # In Python 3, b"".join([u"a", u"b"]) will raise a TypeError self.assertRaises(TypeError, wrapper.writeSequence, ["test3", "test4"])
|