diff options
| -rw-r--r-- | src/mailman/testing/mta.py | 245 |
1 files changed, 118 insertions, 127 deletions
diff --git a/src/mailman/testing/mta.py b/src/mailman/testing/mta.py index b1e62648c..a4ae01763 100644 --- a/src/mailman/testing/mta.py +++ b/src/mailman/testing/mta.py @@ -17,11 +17,13 @@ """Fake MTA for testing purposes.""" +import asyncio import logging +import smtplib from aiosmtpd.controller import Controller -from lazr.smtptest.controller import QueueController -from lazr.smtptest.server import Channel, QueueServer +from aiosmtpd.handlers import Message as MessageHandler +from aiosmtpd.smtp import SMTP from mailman import public from mailman.interfaces.mta import IMailTransportAgentLifecycle from queue import Empty, Queue @@ -46,113 +48,65 @@ class FakeMTA: pass -class StatisticsChannel(Channel): - """A channel that can answers to the fake STAT command.""" +class ConnectionCountingHandler(MessageHandler): + def __init__(self, msg_queue): + super().__init__() + self._msg_queue = msg_queue - def __init__(self, server, connection, address): - super().__init__(server, connection, address) - self._auth_response = None - self._waiting_for_auth_response = False - - def smtp_EHLO(self, arg): - if not arg: - self.push('501 Syntax: HELO hostname') - return - if self._SMTPChannel__greeting: - self.push('503 Duplicate HELO/EHLO') - else: - self._SMTPChannel__greeting = arg - self.push('250-%s' % self._SMTPChannel__fqdn) - self.push('250 AUTH PLAIN') + def handle_message(self, message): + self._msg_queue.put(message) - def smtp_STAT(self, arg): - """Cause the server to send statistics to its controller.""" - self._server.send_statistics() - self.push('250 Ok') - def _check_auth(self, response): - # Base 64 for "testuser:testpass" - if response == 'AHRlc3R1c2VyAHRlc3RwYXNz': - self.push('235 Ok') - self._server.send_auth(response) - else: - self.push('571 Bad authentication') +class ConnectionCountingSMTP(SMTP): + def __init__(self, handler, oob_queue, err_queue, *args, **kws): + super().__init__(handler, *args, **kws) + self._auth_response = None + self._waiting_for_auth_response = False + self._connection_count = 0 + self._oob_queue = oob_queue + self._err_queue = err_queue + self._last_error = None + @asyncio.coroutine def smtp_AUTH(self, arg): """Record that the AUTH occurred.""" args = arg.split() if args[0].lower() == 'plain': if len(args) == 2: + response = args[1] # The second argument is the AUTH PLAIN <initial-response> # which must be equal to the base 64 equivalent of the # expected login string "testuser:testpass". - self._check_auth(args[1]) + if response == 'AHRlc3R1c2VyAHRlc3RwYXNz': + yield from self.push('235 Ok') + self._oob_queue.put(response) + else: + yield from self.push('571 Bad authentication') else: assert len(args) == 1, args # Send a challenge and set us up to wait for the response. - self.push('334 ') + yield from self.push('334 ') self._waiting_for_auth_response = True else: - self.push('571 Bad authentication') - - def smtp_RCPT(self, arg): - """For testing, sometimes cause a non-25x response.""" - code = self._server.next_error('rcpt') - if code is None: - # Everything's cool. - Channel.smtp_RCPT(self, arg) - else: - # The test suite wants this to fail. The message corresponds to - # the exception we expect smtplib.SMTP to raise. - self.push('%d Error: SMTPRecipientsRefused' % code) - - def smtp_MAIL(self, arg): - """For testing, sometimes cause a non-25x response.""" - code = self._server.next_error('mail') - if code is None: - # Everything's cool. - Channel.smtp_MAIL(self, arg) - else: - # The test suite wants this to fail. The message corresponds to - # the exception we expect smtplib.SMTP to raise. - self.push('%d Error: SMTPResponseException' % code) + yield from self.push('571 Bad authentication') - def found_terminator(self): - # Are we're waiting for the AUTH challenge response? - if self._waiting_for_auth_response: - line = self._emptystring.join(self.received_lines) - self._auth_response = line - self._waiting_for_auth_response = False - self.received_lines = [] - # Now check to see if they authenticated correctly. - self._check_auth(line) - else: - super().found_terminator() - - -class ConnectionCountingServer(QueueServer): - """Count the number of SMTP connections opened.""" - - def __init__(self, host, port, queue, oob_queue, err_queue): - """See `lazr.smtptest.server.QueueServer`. + @asyncio.coroutine + def smtp_EHLO(self, arg): + yield from super().smtp_EHLO(arg) + # If the upcall succeeded, this flag will be set. In that case, also + # push an AUTH PLAIN response, which the superclass doesn't do. + ## if self.extended_smtp: + ## yield from self.push('250 AUTH PLAIN') - :param oob_queue: A queue for communicating information back to the - controller, e.g. statistics. - :type oob_queue: `Queue.Queue` - :param err_queue: A queue for allowing the controller to request SMTP - errors from the server. - :type err_queue: `Queue.Queue` - """ - QueueServer.__init__(self, host, port, queue) - self._connection_count = 0 - self.last_auth = None - # The out-of-band queue is where the server sends statistics to the - # controller upon request. - self._oob_queue = oob_queue - self._err_queue = err_queue - self._last_error = None + @asyncio.coroutine + def smtp_STAT(self, arg): + """Cause the server to send statistics to its controller.""" + # Do not count the connection caused by the STAT connect. + self._connection_count -= 1 + self._oob_queue.put(self._connection_count) + yield from self.push('250 Ok') - def next_error(self, command): + def _next_error(self, command): """Return the next error for the SMTP command, if there is one. :param command: The SMTP command for which an error might be @@ -177,70 +131,107 @@ class ConnectionCountingServer(QueueServer): return code return None - def handle_accept(self): - """See `lazr.smtp.server.Server`.""" - connection, address = self.accept() - self._connection_count += 1 - log.info('[ConnectionCountingServer] accepted: %s', address) - StatisticsChannel(self, connection, address) + @asyncio.coroutine + def smtp_RCPT(self, arg): + """For testing, sometimes cause a non-25x response.""" + code = self._next_error('rcpt') + if code is None: + # Everything's cool. + yield from super().smtp_RCPT(arg) + else: + # The test suite wants this to fail. The message corresponds to + # the exception we expect smtplib.SMTP to raise. + yield from self.push('%d Error: SMTPRecipientsRefused' % code) - def process_message(self, peer, mailfrom, rcpttos, data): - # Provide a guaranteed order to recpttos. - QueueServer.process_message( - self, peer, mailfrom, sorted(rcpttos), data) + @asyncio.coroutine + def smtp_MAIL(self, arg): + """For testing, sometimes cause a non-25x response.""" + code = self._next_error('mail') + if code is None: + # Everything's cool. + yield from super().smtp_MAIL(arg) + else: + # The test suite wants this to fail. The message corresponds to + # the exception we expect smtplib.SMTP to raise. + yield from self.push('%d Error: SMTPResponseException' % code) - def reset(self): - """See `lazr.smtp.server.Server`.""" - QueueServer.reset(self) - self._connection_count = 0 + def found_terminator(self): + # Are we're waiting for the AUTH challenge response? + if self._waiting_for_auth_response: + line = self._emptystring.join(self.received_lines) + self._auth_response = line + self._waiting_for_auth_response = False + self.received_lines = [] + # Now check to see if they authenticated correctly. + self._check_auth(line) + else: + super().found_terminator() - def send_statistics(self): - """Send the current connection statistics to the controller.""" - # Do not count the connection caused by the STAT connect. - self._connection_count -= 1 - self._oob_queue.put(self._connection_count) - def send_auth(self, arg): - """Echo back the authentication data.""" - self._oob_queue.put(arg) +import socket +import asyncio -class ConnectionCountingController(QueueController): +class ConnectionCountingController(Controller): """Count the number of SMTP connections opened.""" def __init__(self, host, port): - """See `lazr.smtptest.controller.QueueController`.""" - self.oob_queue = Queue() + self._msg_queue = Queue() + self._oob_queue = Queue() self.err_queue = Queue() - QueueController.__init__(self, host, port) + handler = ConnectionCountingHandler(self._msg_queue) + super().__init__(handler, hostname=host, port=port) - def _make_server(self, host, port): - """See `lazr.smtptest.controller.QueueController`.""" - self.server = ConnectionCountingServer( - host, port, self.queue, self.oob_queue, self.err_queue) + def factory(self): + return ConnectionCountingSMTP( + self.handler, self._oob_queue, self.err_queue) + + def _run(self, ready_event): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) + sock.bind((self.hostname, self.port)) + asyncio.set_event_loop(self.loop) + server = self.loop.run_until_complete( + self.loop.create_server(self.factory, sock=sock)) + self.loop.call_soon(ready_event.set) + self.loop.run_forever() + server.close() + self.loop.run_until_complete(server.wait_closed()) + self.loop.close() def start(self): - """See `lazr.smtptest.controller.QueueController`.""" - QueueController.start(self) + super().start() # Reset the connection statistics, since the base class's start() # method causes a connection to occur. self.reset() + def _connect(self): + client = smtplib.SMTP() + client.connect(self.hostname, self.port) + return client + def get_connection_count(self): """Retrieve the number of connections. :return: The number of connections to the server that have been made. :rtype: integer """ - smtpd = self._connect() - smtpd.docmd('STAT') + client = self._connect() + client.docmd('STAT') # An Empty exception will occur if the data isn't available in 10 # seconds. Let that propagate. - return self.oob_queue.get(block=True, timeout=10) + return self._oob_queue.get(block=True, timeout=10) def get_authentication_credentials(self): """Retrieve the last authentication credentials.""" - return self.oob_queue.get(block=True, timeout=10) + return self._oob_queue.get(block=True, timeout=10) + + def __iter__(self): + while True: + try: + yield self._msg_queue.get_nowait() + except Empty: + raise StopIteration @property def messages(self): @@ -252,5 +243,5 @@ class ConnectionCountingController(QueueController): list(self) def reset(self): - smtpd = self._connect() - smtpd.docmd('RSET') + client = self._connect() + client.docmd('RSET') |
