From 179a2ca93aa8999b3fa8cf49f0a807ab552546ef Mon Sep 17 00:00:00 2001 From: jesopo Date: Mon, 6 Apr 2020 13:22:17 +0100 Subject: [PATCH] refactor TCP logic in to ITCPTransport (we can mock this for unittests) --- ircrobots/bot.py | 6 +++--- ircrobots/interface.py | 23 ++++++++++++++++++++++- ircrobots/ircv3.py | 5 ++--- ircrobots/security.py | 2 +- ircrobots/server.py | 34 +++++++++++++++++----------------- ircrobots/transport.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 25 deletions(-) create mode 100644 ircrobots/transport.py diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 3bbb75a..a19599e 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -2,9 +2,9 @@ import asyncio import anyio from typing import Dict -from irctokens import Line -from .server import ConnectionParams, Server +from .server import ConnectionParams, Server +from .transport import TCPTransport RECONNECT_DELAY = 10.0 # ten seconds reconnect @@ -23,7 +23,7 @@ class Bot(object): async def add_server(self, name: str, params: ConnectionParams) -> Server: server = self.create_server(name) self.servers[name] = server - await server.connect(params) + await server.connect(TCPTransport() ,params) await self._server_queue.put(server) return server diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 755f0b0..2039ec0 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -7,6 +7,25 @@ from irctokens import Line from .params import ConnectionParams, SASLParams +class ITCPReader(object): + async def read(self, byte_count: int): + pass +class ITCPWriter(object): + def write(self, data: bytes): + pass + async def drain(self): + pass + +class ITCPTransport(object): + async def connect(self, + hostname: str, + port: int, + tls: bool, + tls_verify: bool=True, + bindhost: Optional[str]=None + ) -> Tuple[ITCPReader, ITCPWriter]: + pass + class SendPriority(IntEnum): HIGH = 0 MEDIUM = 10 @@ -53,7 +72,9 @@ class IServer(Server): def set_throttle(self, rate: int, time: float): pass - async def connect(self, params: ConnectionParams): + async def connect(self, + transport: ITCPTransport, + params: ConnectionParams): pass async def line_read(self, line: Line): diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index c4aa495..27cb72b 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -1,9 +1,8 @@ from typing import Dict, Iterable, List, Optional from irctokens import build -from .contexts import ServerContext -from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot, - ParamLiteral) +from .contexts import ServerContext +from .matching import Response, ResponseOr, ParamAny, ParamLiteral from .interface import ICapability class Capability(ICapability): diff --git a/ircrobots/security.py b/ircrobots/security.py index 0282866..17d1b78 100644 --- a/ircrobots/security.py +++ b/ircrobots/security.py @@ -1,6 +1,6 @@ import ssl -def ssl_context(verify: bool=True) -> ssl.SSLContext: +def tls_context(verify: bool=True) -> ssl.SSLContext: context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv3 diff --git a/ircrobots/server.py b/ircrobots/server.py index df09646..68eb39d 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,25 +1,26 @@ -import asyncio -from ssl import SSLContext -from asyncio import Future, PriorityQueue, Queue -from typing import Deque, Dict, List, Optional, Set, Tuple +from asyncio import Future, PriorityQueue +from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple from collections import deque from asyncio_throttle import Throttler -from ircstates import Emit +from ircstates import Emit, Channel from irctokens import build, Line, tokenise from .ircv3 import CAPContext, CAP_SASL +from .sasl import SASLContext, SASLResult +from .matching import Numeric, ParamAny, ParamFolded +from .asyncs import MaybeAwait + from .interface import (ConnectionParams, ICapability, IServer, SentLine, SendPriority, SASLParams, IMatchResponse) -from .sasl import SASLContext, SASLResult -from .security import ssl_context +from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds class Server(IServer): - _reader: asyncio.StreamReader - _writer: asyncio.StreamWriter + _reader: ITCPReader + _writer: ITCPWriter params: ConnectionParams def __init__(self, name: str): @@ -49,16 +50,15 @@ class Server(IServer): self.throttle.rate_limit = rate self.throttle.period = time - async def connect(self, params: ConnectionParams): - cur_ssl: Optional[SSLContext] = None - if params.tls: - cur_ssl = ssl_context(params.tls_verify) - - reader, writer = await asyncio.open_connection( + async def connect(self, + transport: ITCPTransport, + params: ConnectionParams): + reader, writer = await transport.connect( params.host, params.port, - ssl=cur_ssl, - local_addr=(params.bindhost, 0)) + tls =params.tls, + tls_verify=params.tls_verify, + bindhost =params.bindhost) self._reader = reader self._writer = writer diff --git a/ircrobots/transport.py b/ircrobots/transport.py new file mode 100644 index 0000000..9e20e7d --- /dev/null +++ b/ircrobots/transport.py @@ -0,0 +1,42 @@ +from ssl import SSLContext +from typing import Optional, Tuple +from asyncio import open_connection, StreamReader, StreamWriter + +from .interface import ITCPTransport, ITCPReader, ITCPWriter +from .security import tls_context + +class TCPReader(ITCPReader): + def __init__(self, reader: StreamReader): + self._reader = reader + + async def read(self, byte_count: int) -> bytes: + return await self._reader.read(byte_count) +class TCPWriter(ITCPWriter): + def __init__(self, writer: StreamWriter): + self._writer = writer + + def write(self, data: bytes): + self._writer.write(data) + + async def drain(self): + await self._writer.drain() + +class TCPTransport(ITCPTransport): + async def connect(self, + hostname: str, + port: int, + tls: bool, + tls_verify: bool=True, + bindhost: Optional[str]=None + ) -> Tuple[ITCPReader, ITCPWriter]: + + cur_ssl: Optional[SSLContext] = None + if tls: + cur_ssl = tls_context(tls_verify) + + reader, writer = await open_connection( + hostname, + port, + ssl=cur_ssl, + local_addr=(bindhost, 0)) + return (TCPReader(reader), TCPWriter(writer))