diff --git a/.travis.yml b/.travis.yml index 4311574..8b924d4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ cache: pip python: - "3.7" - "3.8" - - "3.8-dev" + - "3.9" install: - pip3 install mypy -r requirements.txt script: diff --git a/README.md b/README.md index c8e88cd..409a2f5 100644 --- a/README.md +++ b/README.md @@ -11,4 +11,4 @@ see [examples/](examples/) for some usage demonstration. ## contact -Come say hi at [##irctokens on freenode](https://webchat.freenode.net/?channels=%23%23irctokens) +Come say hi at `#irctokens` on irc.libera.chat diff --git a/VERSION b/VERSION index 0f82685..faef31a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.7 +0.7.0 diff --git a/examples/factoids.py b/examples/factoids.py index bea0b4c..336dc2b 100644 --- a/examples/factoids.py +++ b/examples/factoids.py @@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str): params = ConnectionParams( nickname, hostname, - 6697, - tls=True) + 6697 + ) await bot.add_server("freenode", params) await bot.run() diff --git a/examples/sasl.py b/examples/sasl.py index e8e9818..97c81fa 100644 --- a/examples/sasl.py +++ b/examples/sasl.py @@ -23,7 +23,6 @@ async def main(): "MyNickname", host = "chat.freenode.invalid", port = 6697, - tls = True, sasl = sasl_params) await bot.add_server("freenode", params) diff --git a/examples/simple.py b/examples/simple.py index 54892b4..e47fc1b 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -25,7 +25,7 @@ class Bot(BaseBot): async def main(): bot = Bot() for name, host in SERVERS: - params = ConnectionParams("BitBotNewTest", host, 6697, True) + params = ConnectionParams("BitBotNewTest", host, 6697) await bot.add_server(name, params) await bot.run() diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index 5b798ed..c033b8f 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -3,3 +3,4 @@ from .server import Server from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, STSPolicy, ResumePolicy) from .ircv3 import Capability +from .security import TLS diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index 7b2f4e2..54d0b3b 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -19,9 +19,17 @@ class MaybeAwait(Generic[TEvent]): class WaitFor(object): def __init__(self, response: IMatchResponse, - label: Optional[str]=None): + deadline: float): self.response = response - self._label = label + self.deadline = deadline + self._label: Optional[str] = None + self._our_fut: "Future[Line]" = Future() + + def __await__(self) -> Generator[Any, None, Line]: + return self._our_fut.__await__() + + def with_label(self, label: str): + self._label = label def match(self, server: IServer, line: Line): if (self._label is not None and @@ -31,3 +39,6 @@ class WaitFor(object): label == self._label): return True return self.response.match(server, line) + + def resolve(self, line: Line): + self._our_fut.set_result(line) diff --git a/ircrobots/bot.py b/ircrobots/bot.py index cdfe3ab..809fb5e 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -1,4 +1,4 @@ -import asyncio +import asyncio, traceback import anyio from typing import Dict @@ -6,40 +6,53 @@ from ircstates.server import ServerDisconnectedException from .server import ConnectionParams, Server from .transport import TCPTransport -from .interface import IBot, IServer +from .interface import IBot, IServer, ITCPTransport class Bot(IBot): def __init__(self): self.servers: Dict[str, Server] = {} self._server_queue: asyncio.Queue[Server] = asyncio.Queue() - # methods designed to be overridden def create_server(self, name: str): return Server(self, name) + async def disconnected(self, server: IServer): if (server.name in self.servers and server.params is not None and server.disconnected): - await asyncio.sleep(server.params.reconnect) - await self.add_server(server.name, server.params) - # /methods designed to be overridden + + reconnect = server.params.reconnect + + while True: + await asyncio.sleep(reconnect) + try: + await self.add_server(server.name, server.params) + except Exception as e: + traceback.print_exc() + # let's try again, exponential backoff up to 5 mins + reconnect = min(reconnect*2, 300) + else: + break async def disconnect(self, server: IServer): - await server.disconnect() del self.servers[server.name] + await server.disconnect() - async def add_server(self, name: str, params: ConnectionParams) -> Server: + async def add_server(self, + name: str, + params: ConnectionParams, + transport: ITCPTransport = TCPTransport()) -> Server: server = self.create_server(name) self.servers[name] = server - await server.connect(TCPTransport(), params) + await server.connect(transport, params) await self._server_queue.put(server) return server async def _run_server(self, server: Server): try: async with anyio.create_task_group() as tg: - await tg.spawn(server._read_lines) - await tg.spawn(server._send_lines) + tg.start_soon(server._read_lines) + tg.start_soon(server._send_lines) except ServerDisconnectedException: server.disconnected = True @@ -49,4 +62,4 @@ class Bot(IBot): async with anyio.create_task_group() as tg: while not tg.cancel_scope.cancel_called: server = await self._server_queue.get() - await tg.spawn(self._run_server, server) + tg.start_soon(self._run_server, server) diff --git a/ircrobots/interface.py b/ircrobots/interface.py index f680f5f..db66353 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -6,6 +6,7 @@ from ircstates import Server, Emit from irctokens import Line, Hostmask from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy +from .security import TLS class ITCPReader(object): async def read(self, byte_count: int): @@ -24,11 +25,10 @@ class ITCPWriter(object): class ITCPTransport(object): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index b26ab8f..3795237 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -8,6 +8,7 @@ from .contexts import ServerContext from .matching import Response, ANY from .interface import ICapability from .params import ConnectionParams, STSPolicy, ResumePolicy +from .security import TLSVerifyChain class Capability(ICapability): def __init__(self, @@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]: return d async def sts_transmute(params: ConnectionParams): - if not params.sts is None and not params.tls: + if not params.sts is None and params.tls is None: now = time() since = (now-params.sts.created) if since <= params.sts.duration: params.port = params.sts.port - params.tls = True + params.tls = TLSVerifyChain() async def resume_transmute(params: ConnectionParams): if params.resume is not None: params.host = params.resume.address @@ -182,7 +183,7 @@ class CAPContext(ServerContext): if not params.tls: if "port" in sts_dict: params.port = int(sts_dict["port"]) - params.tls = True + params.tls = TLSVerifyChain() await self.server.bot.disconnect(self.server) await self.server.bot.add_server(self.server.name, params) diff --git a/ircrobots/matching/params.py b/ircrobots/matching/params.py index c038db4..ec2e8ff 100644 --- a/ircrobots/matching/params.py +++ b/ircrobots/matching/params.py @@ -73,8 +73,7 @@ class Formatless(IMatchResponseParam): def __init__(self, value: TYPE_MAYBELIT_VALUE): self._value = _assure_lit(value) def __repr__(self) -> str: - brepr = super().__repr__() - return f"Formatless({brepr})" + return f"Formatless({self._value!r})" def match(self, server: IServer, arg: str) -> bool: strip = formatting.strip(arg) return self._value.match(server, strip) diff --git a/ircrobots/params.py b/ircrobots/params.py index fcbdbc2..e2699fe 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -1,6 +1,9 @@ +from re import compile as re_compile from typing import List, Optional from dataclasses import dataclass, field +from .security import TLS, TLSNoVerify, TLSVerifyChain + class SASLParams(object): mechanism: str @@ -28,19 +31,24 @@ class ResumePolicy(object): address: str token: str +RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]") + +_TLS_TYPES = { + "+": TLSVerifyChain, + "~": TLSNoVerify, +} @dataclass class ConnectionParams(object): nickname: str host: str port: int - tls: bool + tls: Optional[TLS] = field(default_factory=TLSVerifyChain) username: Optional[str] = None realname: Optional[str] = None bindhost: Optional[str] = None password: Optional[str] = None - tls_verify: bool = True sasl: Optional[SASLParams] = None sts: Optional[STSPolicy] = None @@ -50,3 +58,26 @@ class ConnectionParams(object): alt_nicknames: List[str] = field(default_factory=list) autojoin: List[str] = field(default_factory=list) + + @staticmethod + def from_hoststring( + nickname: str, + hoststring: str + ) -> "ConnectionParams": + + ipv6host = RE_IPV6HOST.search(hoststring) + if ipv6host is not None and ipv6host.start() == 0: + host = ipv6host.group(1) + port_s = hoststring[ipv6host.end()+1:] + else: + host, _, port_s = hoststring.strip().partition(":") + + tls_type: Optional[TLS] = None + if not port_s: + port_s = "6667" + else: + tls_type = _TLS_TYPES.get(port_s[0], lambda: None)() + if tls_type is not None: + port_s = port_s[1:] or "6697" + + return ConnectionParams(nickname, host, int(port_s), tls_type) diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index 887484c..8f3e21c 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -32,7 +32,9 @@ AUTH_BYTE_MAX = 400 AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) NUMERICS_FAIL = Response(ERR_SASLFAIL) -NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS]) +NUMERICS_INITIAL = Responses([ + ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED +]) NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) def _b64e(s: str): @@ -150,6 +152,8 @@ class SASLContext(ServerContext): return SASLResult.SUCCESS elif line.command == "904": match.pop(0) + else: + break return SASLResult.FAILURE diff --git a/ircrobots/security.py b/ircrobots/security.py index 17d1b78..96c9f5c 100644 --- a/ircrobots/security.py +++ b/ircrobots/security.py @@ -1,13 +1,29 @@ import ssl +from dataclasses import dataclass +from typing import Optional, Tuple + +@dataclass +class TLS: + client_keypair: Optional[Tuple[str, str]] = None + +# tls without verification +class TLSNoVerify(TLS): + pass + +# verify via CAs +class TLSVerifyChain(TLS): + pass + +# verify by a pinned hash +class TLSVerifyHash(TLSNoVerify): + def __init__(self, sum: str): + self.sum = sum.lower() +class TLSVerifySHA512(TLSVerifyHash): + pass def tls_context(verify: bool=True) -> ssl.SSLContext: - context = ssl.SSLContext(ssl.PROTOCOL_TLS) - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_TLSv1 - context.load_default_certs() - - if verify: - context.verify_mode = ssl.CERT_REQUIRED - - return context + ctx = ssl.create_default_context() + if not verify: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx diff --git a/ircrobots/server.py b/ircrobots/server.py index a70f355..64ff5e9 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -6,6 +6,7 @@ from collections import deque from time import monotonic import anyio +from asyncio_rlock import RLock from asyncio_throttle import Throttler from async_timeout import timeout as timeout_ from ircstates import Emit, Channel, ChannelUser @@ -28,7 +29,7 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds -PING_INTERVAL = 60 # seconds +PING_TIMEOUT = 60 # seconds WAIT_TIMEOUT = 20 # seconds JOIN_ERR_FIRST = [ @@ -54,8 +55,7 @@ class Server(IServer): self.disconnected = False - self.throttle = Throttler( - rate_limit=100, period=THROTTLE_TIME) + self.throttle = Throttler(rate_limit=100, period=1) self.sasl_state = SASLResult.NONE self.last_read = monotonic() @@ -64,12 +64,14 @@ class Server(IServer): self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self.desired_caps: Set[ICapability] = set([]) - self.read_lock = asyncio.Lock() self._read_queue: Deque[Line] = deque() - self._process_queue: Deque[Line] = deque() + self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() - self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None - self._wait_for_fut: Optional[Future[WaitFor]] = None + self._ping_sent = False + self._read_lguard = RLock() + self.read_lock = self._read_lguard + self._read_lwork = asyncio.Lock() + self._wait_for = asyncio.Event() self._pending_who: Deque[str] = deque() self._alt_nicks: List[str] = [] @@ -122,9 +124,8 @@ class Server(IServer): reader, writer = await transport.connect( params.host, params.port, - tls =params.tls, - tls_verify=params.tls_verify, - bindhost =params.bindhost) + tls =params.tls, + bindhost =params.bindhost) self._reader = reader self._writer = writer @@ -179,9 +180,9 @@ class Server(IServer): self._pending_who[0] == chan): self._pending_who.popleft() await self._next_who() - - elif (line.command == ERR_NICKNAMEINUSE and - not self.registered): + elif (line.command in { + ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE + } and not self.registered): if self._alt_nicks: nick = self._alt_nicks.pop(0) await self.send(build("NICK", [nick])) @@ -262,44 +263,61 @@ class Server(IServer): else: await self.send(build("WHO", [chan])) - async def _read_line(self) -> Line: + async def _read_line(self, timeout: float) -> Optional[Line]: while True: - async with self.read_lock: - if self._read_queue: - return self._read_queue.popleft() + if self._read_queue: + return self._read_queue.popleft() - data = await self._reader.read(1024) - lines = self.recv(data) - # last_read under self.recv() as recv might throw Disconnected - self.last_read = monotonic() - for line in lines: - self._read_queue.append(line) + try: + async with timeout_(timeout): + data = await self._reader.read(1024) + except asyncio.TimeoutError: + return None + + self.last_read = monotonic() + lines = self.recv(data) + for line in lines: + self.line_preread(line) + self._read_queue.append(line) async def _read_lines(self): - sent_ping = False while True: - if not self._process_queue: - try: - async with timeout_(PING_INTERVAL): - line = await self._read_line() - except asyncio.TimeoutError: - if not sent_ping: - sent_ping = True - await self.send(build("PING", ["hello"])) - continue - else: - raise ServerDisconnectedException() - else: - sent_ping = False - self._process_queue.append(line) + async with self._read_lguard: + pass - line = self._process_queue.popleft() - emit = self.parse_tokens(line) - await self._on_read(line, emit) + if not self._process_queue: + async with self._read_lwork: + read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT)) + wait_aw = asyncio.create_task(self._wait_for.wait()) + dones, notdones = await asyncio.wait( + [read_aw, wait_aw], + return_when=asyncio.FIRST_COMPLETED + ) + self._wait_for.clear() + + for done in dones: + if isinstance(done.result(), Line): + self._ping_sent = False + line = done.result() + emit = self.parse_tokens(line) + self._process_queue.append((line, emit)) + elif done.result() is None: + if not self._ping_sent: + await self.send(build("PING", ["hello"])) + self._ping_sent = True + else: + await self.disconnect() + raise ServerDisconnectedException() + for notdone in notdones: + notdone.cancel() + + else: + line, emit = self._process_queue.popleft() + await self._on_read(line, emit) async def wait_for(self, response: Union[IMatchResponse, Set[IMatchResponse]], - label: Optional[str]=None, + sent_aw: Optional[Awaitable[SentLine]]=None, timeout: float=WAIT_TIMEOUT ) -> Line: @@ -309,13 +327,18 @@ class Server(IServer): else: response_obj = response - wait_for = WaitFor(response_obj, label) - async with timeout_(timeout): - while True: - line = await self._read_line() - self._process_queue.append(line) - if wait_for.match(self, line): - return line + async with self._read_lguard: + self._wait_for.set() + async with self._read_lwork: + async with timeout_(timeout): + while True: + line = await self._read_line(timeout) + if line: + self._ping_sent = False + emit = self.parse_tokens(line) + self._process_queue.append((line, emit)) + if response_obj.match(self, line): + return line async def _on_send_line(self, line: Line): if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and @@ -531,7 +554,7 @@ class Server(IServer): for symbol in symbols: mode = self.isupport.prefix.from_prefix(symbol) if mode is not None: - channel_user.modes.append(mode) + channel_user.modes.add(mode) obj.channels.append(channel_user) elif line.command == RPL_ENDOFWHOIS: diff --git a/ircrobots/transport.py b/ircrobots/transport.py index 291409c..3a43cb3 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -1,10 +1,12 @@ +from hashlib import sha512 from ssl import SSLContext from typing import Optional, Tuple from asyncio import StreamReader, StreamWriter from async_stagger import open_connection from .interface import ITCPTransport, ITCPReader, ITCPWriter -from .security import tls_context +from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, + TLSVerifySHA512) class TCPReader(ITCPReader): def __init__(self, reader: StreamReader): @@ -32,16 +34,18 @@ class TCPWriter(ITCPWriter): class TCPTransport(ITCPTransport): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: cur_ssl: Optional[SSLContext] = None - if tls: - cur_ssl = tls_context(tls_verify) + if tls is not None: + cur_ssl = tls_context(not isinstance(tls, TLSNoVerify)) + if tls.client_keypair is not None: + (client_cert, client_key) = tls.client_keypair + cur_ssl.load_cert_chain(client_cert, keyfile=client_key) local_addr: Optional[Tuple[str, int]] = None if not bindhost is None: @@ -55,5 +59,20 @@ class TCPTransport(ITCPTransport): server_hostname=server_hostname, ssl =cur_ssl, local_addr =local_addr) + + if isinstance(tls, TLSVerifyHash): + cert: bytes = writer.transport.get_extra_info( + "ssl_object" + ).getpeercert(True) + if isinstance(tls, TLSVerifySHA512): + sum = sha512(cert).hexdigest() + else: + raise ValueError(f"unknown hash pinning {type(tls)}") + + if not sum == tls.sum: + raise ValueError( + f"pinned hash for {hostname} does not match ({sum})" + ) + return (TCPReader(reader), TCPWriter(writer)) diff --git a/requirements.txt b/requirements.txt index e9918bb..4193582 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ anyio ~=2.0.2 +asyncio-rlock ~=0.1.0 asyncio-throttle ~=1.0.1 -dataclasses ~=0.6; python_version<"3.7" -ircstates ~=0.11.7 +ircstates ~=0.13.0 async_stagger ~=0.3.0 -async_timeout ~=3.0.1 +async_timeout ~=4.0.2 diff --git a/setup.py b/setup.py index 7bc3e5c..b4b6b17 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,6 @@ setup( "Operating System :: Microsoft :: Windows", "Topic :: Communications :: Chat :: Internet Relay Chat" ], - python_requires='>=3.6', + python_requires='>=3.7', install_requires=install_requires )