Compare commits

..

1 commit

Author SHA1 Message Date
jesopo
2cab5b3002 simplify wait_for 2021-02-18 14:50:01 +00:00
19 changed files with 105 additions and 222 deletions

View file

@ -3,7 +3,7 @@ cache: pip
python: python:
- "3.7" - "3.7"
- "3.8" - "3.8"
- "3.9" - "3.8-dev"
install: install:
- pip3 install mypy -r requirements.txt - pip3 install mypy -r requirements.txt
script: script:

View file

@ -11,4 +11,4 @@ see [examples/](examples/) for some usage demonstration.
## contact ## contact
Come say hi at `#irctokens` on irc.libera.chat Come say hi at [##irctokens on freenode](https://webchat.freenode.net/?channels=%23%23irctokens)

View file

@ -1 +1 @@
0.7.0 0.3.7

View file

@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
params = ConnectionParams( params = ConnectionParams(
nickname, nickname,
hostname, hostname,
6697 6697,
) tls=True)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()

View file

@ -23,6 +23,7 @@ async def main():
"MyNickname", "MyNickname",
host = "chat.freenode.invalid", host = "chat.freenode.invalid",
port = 6697, port = 6697,
tls = True,
sasl = sasl_params) sasl = sasl_params)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)

View file

@ -25,7 +25,7 @@ class Bot(BaseBot):
async def main(): async def main():
bot = Bot() bot = Bot()
for name, host in SERVERS: for name, host in SERVERS:
params = ConnectionParams("BitBotNewTest", host, 6697) params = ConnectionParams("BitBotNewTest", host, 6697, True)
await bot.add_server(name, params) await bot.add_server(name, params)
await bot.run() await bot.run()

View file

@ -3,4 +3,3 @@ from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy, ResumePolicy) STSPolicy, ResumePolicy)
from .ircv3 import Capability from .ircv3 import Capability
from .security import TLS

View file

@ -19,17 +19,9 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object): class WaitFor(object):
def __init__(self, def __init__(self,
response: IMatchResponse, response: IMatchResponse,
deadline: float): label: Optional[str]=None):
self.response = response self.response = response
self.deadline = deadline self._label = label
self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]:
return self._our_fut.__await__()
def with_label(self, label: str):
self._label = label
def match(self, server: IServer, line: Line): def match(self, server: IServer, line: Line):
if (self._label is not None and if (self._label is not None and
@ -39,6 +31,3 @@ class WaitFor(object):
label == self._label): label == self._label):
return True return True
return self.response.match(server, line) return self.response.match(server, line)
def resolve(self, line: Line):
self._our_fut.set_result(line)

View file

@ -1,4 +1,4 @@
import asyncio, traceback import asyncio
import anyio import anyio
from typing import Dict from typing import Dict
@ -6,53 +6,40 @@ from ircstates.server import ServerDisconnectedException
from .server import ConnectionParams, Server from .server import ConnectionParams, Server
from .transport import TCPTransport from .transport import TCPTransport
from .interface import IBot, IServer, ITCPTransport from .interface import IBot, IServer
class Bot(IBot): class Bot(IBot):
def __init__(self): def __init__(self):
self.servers: Dict[str, Server] = {} self.servers: Dict[str, Server] = {}
self._server_queue: asyncio.Queue[Server] = asyncio.Queue() self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
# methods designed to be overridden
def create_server(self, name: str): def create_server(self, name: str):
return Server(self, name) return Server(self, name)
async def disconnected(self, server: IServer): async def disconnected(self, server: IServer):
if (server.name in self.servers and if (server.name in self.servers and
server.params is not None and server.params is not None and
server.disconnected): server.disconnected):
await asyncio.sleep(server.params.reconnect)
reconnect = server.params.reconnect await self.add_server(server.name, server.params)
# /methods designed to be overridden
while True:
await asyncio.sleep(reconnect)
try:
await self.add_server(server.name, server.params)
except Exception as e:
traceback.print_exc()
# let's try again, exponential backoff up to 5 mins
reconnect = min(reconnect*2, 300)
else:
break
async def disconnect(self, server: IServer): async def disconnect(self, server: IServer):
del self.servers[server.name]
await server.disconnect() await server.disconnect()
del self.servers[server.name]
async def add_server(self, async def add_server(self, name: str, params: ConnectionParams) -> Server:
name: str,
params: ConnectionParams,
transport: ITCPTransport = TCPTransport()) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
await server.connect(transport, params) await server.connect(TCPTransport(), params)
await self._server_queue.put(server) await self._server_queue.put(server)
return server return server
async def _run_server(self, server: Server): async def _run_server(self, server: Server):
try: try:
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(server._read_lines) await tg.spawn(server._read_lines)
tg.start_soon(server._send_lines) await tg.spawn(server._send_lines)
except ServerDisconnectedException: except ServerDisconnectedException:
server.disconnected = True server.disconnected = True
@ -62,4 +49,4 @@ class Bot(IBot):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
while not tg.cancel_scope.cancel_called: while not tg.cancel_scope.cancel_called:
server = await self._server_queue.get() server = await self._server_queue.get()
tg.start_soon(self._run_server, server) await tg.spawn(self._run_server, server)

View file

@ -6,7 +6,6 @@ from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -25,10 +24,11 @@ class ITCPWriter(object):
class ITCPTransport(object): class ITCPTransport(object):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: Optional[TLS], tls: bool,
bindhost: Optional[str]=None tls_verify: bool=True,
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
pass pass

View file

@ -8,7 +8,6 @@ from .contexts import ServerContext
from .matching import Response, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLSVerifyChain
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -102,12 +101,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
return d return d
async def sts_transmute(params: ConnectionParams): async def sts_transmute(params: ConnectionParams):
if not params.sts is None and params.tls is None: if not params.sts is None and not params.tls:
now = time() now = time()
since = (now-params.sts.created) since = (now-params.sts.created)
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = TLSVerifyChain() params.tls = True
async def resume_transmute(params: ConnectionParams): async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
@ -183,7 +182,7 @@ class CAPContext(ServerContext):
if not params.tls: if not params.tls:
if "port" in sts_dict: if "port" in sts_dict:
params.port = int(sts_dict["port"]) params.port = int(sts_dict["port"])
params.tls = TLSVerifyChain() params.tls = True
await self.server.bot.disconnect(self.server) await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params) await self.server.bot.add_server(self.server.name, params)

View file

@ -73,7 +73,8 @@ class Formatless(IMatchResponseParam):
def __init__(self, value: TYPE_MAYBELIT_VALUE): def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value) self._value = _assure_lit(value)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Formatless({self._value!r})" brepr = super().__repr__()
return f"Formatless({brepr})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
strip = formatting.strip(arg) strip = formatting.strip(arg)
return self._value.match(server, strip) return self._value.match(server, strip)

View file

@ -1,9 +1,6 @@
from re import compile as re_compile
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .security import TLS, TLSNoVerify, TLSVerifyChain
class SASLParams(object): class SASLParams(object):
mechanism: str mechanism: str
@ -31,24 +28,19 @@ class ResumePolicy(object):
address: str address: str
token: str token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLSVerifyChain,
"~": TLSNoVerify,
}
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
host: str host: str
port: int port: int
tls: Optional[TLS] = field(default_factory=TLSVerifyChain) tls: bool
username: Optional[str] = None username: Optional[str] = None
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
@ -58,26 +50,3 @@ class ConnectionParams(object):
alt_nicknames: List[str] = field(default_factory=list) alt_nicknames: List[str] = field(default_factory=list)
autojoin: List[str] = field(default_factory=list) autojoin: List[str] = field(default_factory=list)
@staticmethod
def from_hoststring(
nickname: str,
hoststring: str
) -> "ConnectionParams":
ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
else:
host, _, port_s = hoststring.strip().partition(":")
tls_type: Optional[TLS] = None
if not port_s:
port_s = "6667"
else:
tls_type = _TLS_TYPES.get(port_s[0], lambda: None)()
if tls_type is not None:
port_s = port_s[1:] or "6697"
return ConnectionParams(nickname, host, int(port_s), tls_type)

View file

@ -32,9 +32,7 @@ AUTH_BYTE_MAX = 400
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
NUMERICS_FAIL = Response(ERR_SASLFAIL) NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses([ NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS])
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
])
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
def _b64e(s: str): def _b64e(s: str):
@ -152,8 +150,6 @@ class SASLContext(ServerContext):
return SASLResult.SUCCESS return SASLResult.SUCCESS
elif line.command == "904": elif line.command == "904":
match.pop(0) match.pop(0)
else:
break
return SASLResult.FAILURE return SASLResult.FAILURE

View file

@ -1,29 +1,13 @@
import ssl import ssl
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class TLS:
client_keypair: Optional[Tuple[str, str]] = None
# tls without verification
class TLSNoVerify(TLS):
pass
# verify via CAs
class TLSVerifyChain(TLS):
pass
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext: def tls_context(verify: bool=True) -> ssl.SSLContext:
ctx = ssl.create_default_context() context = ssl.SSLContext(ssl.PROTOCOL_TLS)
if not verify: context.options |= ssl.OP_NO_SSLv2
ctx.check_hostname = False context.options |= ssl.OP_NO_SSLv3
ctx.verify_mode = ssl.CERT_NONE context.options |= ssl.OP_NO_TLSv1
return ctx context.load_default_certs()
if verify:
context.verify_mode = ssl.CERT_REQUIRED
return context

View file

@ -6,7 +6,6 @@ from collections import deque
from time import monotonic from time import monotonic
import anyio import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_ from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser from ircstates import Emit, Channel, ChannelUser
@ -29,7 +28,7 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds THROTTLE_TIME = 2 # seconds
PING_TIMEOUT = 60 # seconds PING_INTERVAL = 60 # seconds
WAIT_TIMEOUT = 20 # seconds WAIT_TIMEOUT = 20 # seconds
JOIN_ERR_FIRST = [ JOIN_ERR_FIRST = [
@ -55,7 +54,8 @@ class Server(IServer):
self.disconnected = False self.disconnected = False
self.throttle = Throttler(rate_limit=100, period=1) self.throttle = Throttler(
rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self.last_read = monotonic() self.last_read = monotonic()
@ -64,14 +64,12 @@ class Server(IServer):
self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self.read_lock = asyncio.Lock()
self._read_queue: Deque[Line] = deque() self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._process_queue: Deque[Line] = deque()
self._ping_sent = False self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
self._read_lguard = RLock() self._wait_for_fut: Optional[Future[WaitFor]] = None
self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = [] self._alt_nicks: List[str] = []
@ -124,8 +122,9 @@ class Server(IServer):
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
params.port, params.port,
tls =params.tls, tls =params.tls,
bindhost =params.bindhost) tls_verify=params.tls_verify,
bindhost =params.bindhost)
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
@ -180,9 +179,9 @@ class Server(IServer):
self._pending_who[0] == chan): self._pending_who[0] == chan):
self._pending_who.popleft() self._pending_who.popleft()
await self._next_who() await self._next_who()
elif (line.command in {
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE elif (line.command == ERR_NICKNAMEINUSE and
} and not self.registered): not self.registered):
if self._alt_nicks: if self._alt_nicks:
nick = self._alt_nicks.pop(0) nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick])) await self.send(build("NICK", [nick]))
@ -263,61 +262,44 @@ class Server(IServer):
else: else:
await self.send(build("WHO", [chan])) await self.send(build("WHO", [chan]))
async def _read_line(self, timeout: float) -> Optional[Line]: async def _read_line(self) -> Line:
while True: while True:
if self._read_queue: async with self.read_lock:
return self._read_queue.popleft() if self._read_queue:
return self._read_queue.popleft()
try: data = await self._reader.read(1024)
async with timeout_(timeout): lines = self.recv(data)
data = await self._reader.read(1024) # last_read under self.recv() as recv might throw Disconnected
except asyncio.TimeoutError: self.last_read = monotonic()
return None for line in lines:
self._read_queue.append(line)
self.last_read = monotonic()
lines = self.recv(data)
for line in lines:
self.line_preread(line)
self._read_queue.append(line)
async def _read_lines(self): async def _read_lines(self):
sent_ping = False
while True: while True:
async with self._read_lguard:
pass
if not self._process_queue: if not self._process_queue:
async with self._read_lwork: try:
read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT)) async with timeout_(PING_INTERVAL):
wait_aw = asyncio.create_task(self._wait_for.wait()) line = await self._read_line()
dones, notdones = await asyncio.wait( except asyncio.TimeoutError:
[read_aw, wait_aw], if not sent_ping:
return_when=asyncio.FIRST_COMPLETED sent_ping = True
) await self.send(build("PING", ["hello"]))
self._wait_for.clear() continue
else:
raise ServerDisconnectedException()
else:
sent_ping = False
self._process_queue.append(line)
for done in dones: line = self._process_queue.popleft()
if isinstance(done.result(), Line): emit = self.parse_tokens(line)
self._ping_sent = False await self._on_read(line, emit)
line = done.result()
emit = self.parse_tokens(line)
self._process_queue.append((line, emit))
elif done.result() is None:
if not self._ping_sent:
await self.send(build("PING", ["hello"]))
self._ping_sent = True
else:
await self.disconnect()
raise ServerDisconnectedException()
for notdone in notdones:
notdone.cancel()
else:
line, emit = self._process_queue.popleft()
await self._on_read(line, emit)
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None, label: Optional[str]=None,
timeout: float=WAIT_TIMEOUT timeout: float=WAIT_TIMEOUT
) -> Line: ) -> Line:
@ -327,18 +309,13 @@ class Server(IServer):
else: else:
response_obj = response response_obj = response
async with self._read_lguard: wait_for = WaitFor(response_obj, label)
self._wait_for.set() async with timeout_(timeout):
async with self._read_lwork: while True:
async with timeout_(timeout): line = await self._read_line()
while True: self._process_queue.append(line)
line = await self._read_line(timeout) if wait_for.match(self, line):
if line: return line
self._ping_sent = False
emit = self.parse_tokens(line)
self._process_queue.append((line, emit))
if response_obj.match(self, line):
return line
async def _on_send_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
@ -554,7 +531,7 @@ class Server(IServer):
for symbol in symbols: for symbol in symbols:
mode = self.isupport.prefix.from_prefix(symbol) mode = self.isupport.prefix.from_prefix(symbol)
if mode is not None: if mode is not None:
channel_user.modes.add(mode) channel_user.modes.append(mode)
obj.channels.append(channel_user) obj.channels.append(channel_user)
elif line.command == RPL_ENDOFWHOIS: elif line.command == RPL_ENDOFWHOIS:

View file

@ -1,12 +1,10 @@
from hashlib import sha512
from ssl import SSLContext from ssl import SSLContext
from typing import Optional, Tuple from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, from .security import tls_context
TLSVerifySHA512)
class TCPReader(ITCPReader): class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader): def __init__(self, reader: StreamReader):
@ -34,18 +32,16 @@ class TCPWriter(ITCPWriter):
class TCPTransport(ITCPTransport): class TCPTransport(ITCPTransport):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: Optional[TLS], tls: bool,
bindhost: Optional[str]=None tls_verify: bool=True,
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None cur_ssl: Optional[SSLContext] = None
if tls is not None: if tls:
cur_ssl = tls_context(not isinstance(tls, TLSNoVerify)) cur_ssl = tls_context(tls_verify)
if tls.client_keypair is not None:
(client_cert, client_key) = tls.client_keypair
cur_ssl.load_cert_chain(client_cert, keyfile=client_key)
local_addr: Optional[Tuple[str, int]] = None local_addr: Optional[Tuple[str, int]] = None
if not bindhost is None: if not bindhost is None:
@ -59,20 +55,5 @@ class TCPTransport(ITCPTransport):
server_hostname=server_hostname, server_hostname=server_hostname,
ssl =cur_ssl, ssl =cur_ssl,
local_addr =local_addr) local_addr =local_addr)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer)) return (TCPReader(reader), TCPWriter(writer))

View file

@ -1,6 +1,6 @@
anyio ~=2.0.2 anyio ~=2.0.2
asyncio-rlock ~=0.1.0
asyncio-throttle ~=1.0.1 asyncio-throttle ~=1.0.1
ircstates ~=0.13.0 dataclasses ~=0.6; python_version<"3.7"
ircstates ~=0.11.7
async_stagger ~=0.3.0 async_stagger ~=0.3.0
async_timeout ~=4.0.2 async_timeout ~=3.0.1

View file

@ -26,6 +26,6 @@ setup(
"Operating System :: Microsoft :: Windows", "Operating System :: Microsoft :: Windows",
"Topic :: Communications :: Chat :: Internet Relay Chat" "Topic :: Communications :: Chat :: Internet Relay Chat"
], ],
python_requires='>=3.7', python_requires='>=3.6',
install_requires=install_requires install_requires=install_requires
) )