refactor TCP logic in to ITCPTransport (we can mock this for unittests)

This commit is contained in:
jesopo 2020-04-06 13:22:17 +01:00
parent 750fbd8acc
commit 179a2ca93a
6 changed files with 87 additions and 25 deletions

View file

@ -2,9 +2,9 @@ import asyncio
import anyio
from typing import Dict
from irctokens import Line
from .server import ConnectionParams, Server
from .server import ConnectionParams, Server
from .transport import TCPTransport
RECONNECT_DELAY = 10.0 # ten seconds reconnect
@ -23,7 +23,7 @@ class Bot(object):
async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name)
self.servers[name] = server
await server.connect(params)
await server.connect(TCPTransport() ,params)
await self._server_queue.put(server)
return server

View file

@ -7,6 +7,25 @@ from irctokens import Line
from .params import ConnectionParams, SASLParams
class ITCPReader(object):
async def read(self, byte_count: int):
pass
class ITCPWriter(object):
def write(self, data: bytes):
pass
async def drain(self):
pass
class ITCPTransport(object):
async def connect(self,
hostname: str,
port: int,
tls: bool,
tls_verify: bool=True,
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
pass
class SendPriority(IntEnum):
HIGH = 0
MEDIUM = 10
@ -53,7 +72,9 @@ class IServer(Server):
def set_throttle(self, rate: int, time: float):
pass
async def connect(self, params: ConnectionParams):
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
pass
async def line_read(self, line: Line):

View file

@ -1,9 +1,8 @@
from typing import Dict, Iterable, List, Optional
from irctokens import build
from .contexts import ServerContext
from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
ParamLiteral)
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
from .interface import ICapability
class Capability(ICapability):

View file

@ -1,6 +1,6 @@
import ssl
def ssl_context(verify: bool=True) -> ssl.SSLContext:
def tls_context(verify: bool=True) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3

View file

@ -1,25 +1,26 @@
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue
from typing import Deque, Dict, List, Optional, Set, Tuple
from asyncio import Future, PriorityQueue
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
from collections import deque
from asyncio_throttle import Throttler
from ircstates import Emit
from ircstates import Emit, Channel
from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_SASL
from .sasl import SASLContext, SASLResult
from .matching import Numeric, ParamAny, ParamFolded
from .asyncs import MaybeAwait
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams, IMatchResponse)
from .sasl import SASLContext, SASLResult
from .security import ssl_context
from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
class Server(IServer):
_reader: asyncio.StreamReader
_writer: asyncio.StreamWriter
_reader: ITCPReader
_writer: ITCPWriter
params: ConnectionParams
def __init__(self, name: str):
@ -49,16 +50,15 @@ class Server(IServer):
self.throttle.rate_limit = rate
self.throttle.period = time
async def connect(self, params: ConnectionParams):
cur_ssl: Optional[SSLContext] = None
if params.tls:
cur_ssl = ssl_context(params.tls_verify)
reader, writer = await asyncio.open_connection(
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
reader, writer = await transport.connect(
params.host,
params.port,
ssl=cur_ssl,
local_addr=(params.bindhost, 0))
tls =params.tls,
tls_verify=params.tls_verify,
bindhost =params.bindhost)
self._reader = reader
self._writer = writer

42
ircrobots/transport.py Normal file
View file

@ -0,0 +1,42 @@
from ssl import SSLContext
from typing import Optional, Tuple
from asyncio import open_connection, StreamReader, StreamWriter
from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import tls_context
class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader):
self._reader = reader
async def read(self, byte_count: int) -> bytes:
return await self._reader.read(byte_count)
class TCPWriter(ITCPWriter):
def __init__(self, writer: StreamWriter):
self._writer = writer
def write(self, data: bytes):
self._writer.write(data)
async def drain(self):
await self._writer.drain()
class TCPTransport(ITCPTransport):
async def connect(self,
hostname: str,
port: int,
tls: bool,
tls_verify: bool=True,
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None
if tls:
cur_ssl = tls_context(tls_verify)
reader, writer = await open_connection(
hostname,
port,
ssl=cur_ssl,
local_addr=(bindhost, 0))
return (TCPReader(reader), TCPWriter(writer))