refactor TCP logic in to ITCPTransport (we can mock this for unittests)
This commit is contained in:
parent
750fbd8acc
commit
179a2ca93a
6 changed files with 87 additions and 25 deletions
|
@ -2,9 +2,9 @@ import asyncio
|
|||
import anyio
|
||||
|
||||
from typing import Dict
|
||||
from irctokens import Line
|
||||
|
||||
from .server import ConnectionParams, Server
|
||||
from .server import ConnectionParams, Server
|
||||
from .transport import TCPTransport
|
||||
|
||||
RECONNECT_DELAY = 10.0 # ten seconds reconnect
|
||||
|
||||
|
@ -23,7 +23,7 @@ class Bot(object):
|
|||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||
server = self.create_server(name)
|
||||
self.servers[name] = server
|
||||
await server.connect(params)
|
||||
await server.connect(TCPTransport() ,params)
|
||||
await self._server_queue.put(server)
|
||||
return server
|
||||
|
||||
|
|
|
@ -7,6 +7,25 @@ from irctokens import Line
|
|||
|
||||
from .params import ConnectionParams, SASLParams
|
||||
|
||||
class ITCPReader(object):
|
||||
async def read(self, byte_count: int):
|
||||
pass
|
||||
class ITCPWriter(object):
|
||||
def write(self, data: bytes):
|
||||
pass
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
class ITCPTransport(object):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: bool,
|
||||
tls_verify: bool=True,
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
pass
|
||||
|
||||
class SendPriority(IntEnum):
|
||||
HIGH = 0
|
||||
MEDIUM = 10
|
||||
|
@ -53,7 +72,9 @@ class IServer(Server):
|
|||
def set_throttle(self, rate: int, time: float):
|
||||
pass
|
||||
|
||||
async def connect(self, params: ConnectionParams):
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
pass
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Dict, Iterable, List, Optional
|
||||
from irctokens import build
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
|
||||
ParamLiteral)
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
||||
from .interface import ICapability
|
||||
|
||||
class Capability(ICapability):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import ssl
|
||||
|
||||
def ssl_context(verify: bool=True) -> ssl.SSLContext:
|
||||
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from asyncio import Future, PriorityQueue, Queue
|
||||
from typing import Deque, Dict, List, Optional, Set, Tuple
|
||||
from asyncio import Future, PriorityQueue
|
||||
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
|
||||
from collections import deque
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit
|
||||
from ircstates import Emit, Channel
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import CAPContext, CAP_SASL
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .matching import Numeric, ParamAny, ParamFolded
|
||||
from .asyncs import MaybeAwait
|
||||
|
||||
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
||||
SendPriority, SASLParams, IMatchResponse)
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .security import ssl_context
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
|
||||
THROTTLE_RATE = 4 # lines
|
||||
THROTTLE_TIME = 2 # seconds
|
||||
|
||||
class Server(IServer):
|
||||
_reader: asyncio.StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
_reader: ITCPReader
|
||||
_writer: ITCPWriter
|
||||
params: ConnectionParams
|
||||
|
||||
def __init__(self, name: str):
|
||||
|
@ -49,16 +50,15 @@ class Server(IServer):
|
|||
self.throttle.rate_limit = rate
|
||||
self.throttle.period = time
|
||||
|
||||
async def connect(self, params: ConnectionParams):
|
||||
cur_ssl: Optional[SSLContext] = None
|
||||
if params.tls:
|
||||
cur_ssl = ssl_context(params.tls_verify)
|
||||
|
||||
reader, writer = await asyncio.open_connection(
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
reader, writer = await transport.connect(
|
||||
params.host,
|
||||
params.port,
|
||||
ssl=cur_ssl,
|
||||
local_addr=(params.bindhost, 0))
|
||||
tls =params.tls,
|
||||
tls_verify=params.tls_verify,
|
||||
bindhost =params.bindhost)
|
||||
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
|
42
ircrobots/transport.py
Normal file
42
ircrobots/transport.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from ssl import SSLContext
|
||||
from typing import Optional, Tuple
|
||||
from asyncio import open_connection, StreamReader, StreamWriter
|
||||
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
from .security import tls_context
|
||||
|
||||
class TCPReader(ITCPReader):
|
||||
def __init__(self, reader: StreamReader):
|
||||
self._reader = reader
|
||||
|
||||
async def read(self, byte_count: int) -> bytes:
|
||||
return await self._reader.read(byte_count)
|
||||
class TCPWriter(ITCPWriter):
|
||||
def __init__(self, writer: StreamWriter):
|
||||
self._writer = writer
|
||||
|
||||
def write(self, data: bytes):
|
||||
self._writer.write(data)
|
||||
|
||||
async def drain(self):
|
||||
await self._writer.drain()
|
||||
|
||||
class TCPTransport(ITCPTransport):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: bool,
|
||||
tls_verify: bool=True,
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
|
||||
cur_ssl: Optional[SSLContext] = None
|
||||
if tls:
|
||||
cur_ssl = tls_context(tls_verify)
|
||||
|
||||
reader, writer = await open_connection(
|
||||
hostname,
|
||||
port,
|
||||
ssl=cur_ssl,
|
||||
local_addr=(bindhost, 0))
|
||||
return (TCPReader(reader), TCPWriter(writer))
|
Loading…
Reference in a new issue