Compare commits
1 commit
master
...
better-wai
Author | SHA1 | Date | |
---|---|---|---|
|
2cab5b3002 |
19 changed files with 105 additions and 222 deletions
|
@ -3,7 +3,7 @@ cache: pip
|
||||||
python:
|
python:
|
||||||
- "3.7"
|
- "3.7"
|
||||||
- "3.8"
|
- "3.8"
|
||||||
- "3.9"
|
- "3.8-dev"
|
||||||
install:
|
install:
|
||||||
- pip3 install mypy -r requirements.txt
|
- pip3 install mypy -r requirements.txt
|
||||||
script:
|
script:
|
||||||
|
|
|
@ -11,4 +11,4 @@ see [examples/](examples/) for some usage demonstration.
|
||||||
|
|
||||||
## contact
|
## contact
|
||||||
|
|
||||||
Come say hi at `#irctokens` on irc.libera.chat
|
Come say hi at [##irctokens on freenode](https://webchat.freenode.net/?channels=%23%23irctokens)
|
||||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
0.7.0
|
0.3.7
|
||||||
|
|
|
@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
|
||||||
params = ConnectionParams(
|
params = ConnectionParams(
|
||||||
nickname,
|
nickname,
|
||||||
hostname,
|
hostname,
|
||||||
6697
|
6697,
|
||||||
)
|
tls=True)
|
||||||
await bot.add_server("freenode", params)
|
await bot.add_server("freenode", params)
|
||||||
await bot.run()
|
await bot.run()
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ async def main():
|
||||||
"MyNickname",
|
"MyNickname",
|
||||||
host = "chat.freenode.invalid",
|
host = "chat.freenode.invalid",
|
||||||
port = 6697,
|
port = 6697,
|
||||||
|
tls = True,
|
||||||
sasl = sasl_params)
|
sasl = sasl_params)
|
||||||
|
|
||||||
await bot.add_server("freenode", params)
|
await bot.add_server("freenode", params)
|
||||||
|
|
|
@ -25,7 +25,7 @@ class Bot(BaseBot):
|
||||||
async def main():
|
async def main():
|
||||||
bot = Bot()
|
bot = Bot()
|
||||||
for name, host in SERVERS:
|
for name, host in SERVERS:
|
||||||
params = ConnectionParams("BitBotNewTest", host, 6697)
|
params = ConnectionParams("BitBotNewTest", host, 6697, True)
|
||||||
await bot.add_server(name, params)
|
await bot.add_server(name, params)
|
||||||
|
|
||||||
await bot.run()
|
await bot.run()
|
||||||
|
|
|
@ -3,4 +3,3 @@ from .server import Server
|
||||||
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
||||||
STSPolicy, ResumePolicy)
|
STSPolicy, ResumePolicy)
|
||||||
from .ircv3 import Capability
|
from .ircv3 import Capability
|
||||||
from .security import TLS
|
|
||||||
|
|
|
@ -19,17 +19,9 @@ class MaybeAwait(Generic[TEvent]):
|
||||||
class WaitFor(object):
|
class WaitFor(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
response: IMatchResponse,
|
response: IMatchResponse,
|
||||||
deadline: float):
|
label: Optional[str]=None):
|
||||||
self.response = response
|
self.response = response
|
||||||
self.deadline = deadline
|
self._label = label
|
||||||
self._label: Optional[str] = None
|
|
||||||
self._our_fut: "Future[Line]" = Future()
|
|
||||||
|
|
||||||
def __await__(self) -> Generator[Any, None, Line]:
|
|
||||||
return self._our_fut.__await__()
|
|
||||||
|
|
||||||
def with_label(self, label: str):
|
|
||||||
self._label = label
|
|
||||||
|
|
||||||
def match(self, server: IServer, line: Line):
|
def match(self, server: IServer, line: Line):
|
||||||
if (self._label is not None and
|
if (self._label is not None and
|
||||||
|
@ -39,6 +31,3 @@ class WaitFor(object):
|
||||||
label == self._label):
|
label == self._label):
|
||||||
return True
|
return True
|
||||||
return self.response.match(server, line)
|
return self.response.match(server, line)
|
||||||
|
|
||||||
def resolve(self, line: Line):
|
|
||||||
self._our_fut.set_result(line)
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import asyncio, traceback
|
import asyncio
|
||||||
import anyio
|
import anyio
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
@ -6,53 +6,40 @@ from ircstates.server import ServerDisconnectedException
|
||||||
|
|
||||||
from .server import ConnectionParams, Server
|
from .server import ConnectionParams, Server
|
||||||
from .transport import TCPTransport
|
from .transport import TCPTransport
|
||||||
from .interface import IBot, IServer, ITCPTransport
|
from .interface import IBot, IServer
|
||||||
|
|
||||||
class Bot(IBot):
|
class Bot(IBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.servers: Dict[str, Server] = {}
|
self.servers: Dict[str, Server] = {}
|
||||||
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
||||||
|
|
||||||
|
# methods designed to be overridden
|
||||||
def create_server(self, name: str):
|
def create_server(self, name: str):
|
||||||
return Server(self, name)
|
return Server(self, name)
|
||||||
|
|
||||||
async def disconnected(self, server: IServer):
|
async def disconnected(self, server: IServer):
|
||||||
if (server.name in self.servers and
|
if (server.name in self.servers and
|
||||||
server.params is not None and
|
server.params is not None and
|
||||||
server.disconnected):
|
server.disconnected):
|
||||||
|
await asyncio.sleep(server.params.reconnect)
|
||||||
reconnect = server.params.reconnect
|
await self.add_server(server.name, server.params)
|
||||||
|
# /methods designed to be overridden
|
||||||
while True:
|
|
||||||
await asyncio.sleep(reconnect)
|
|
||||||
try:
|
|
||||||
await self.add_server(server.name, server.params)
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
# let's try again, exponential backoff up to 5 mins
|
|
||||||
reconnect = min(reconnect*2, 300)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
async def disconnect(self, server: IServer):
|
async def disconnect(self, server: IServer):
|
||||||
del self.servers[server.name]
|
|
||||||
await server.disconnect()
|
await server.disconnect()
|
||||||
|
del self.servers[server.name]
|
||||||
|
|
||||||
async def add_server(self,
|
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||||
name: str,
|
|
||||||
params: ConnectionParams,
|
|
||||||
transport: ITCPTransport = TCPTransport()) -> Server:
|
|
||||||
server = self.create_server(name)
|
server = self.create_server(name)
|
||||||
self.servers[name] = server
|
self.servers[name] = server
|
||||||
await server.connect(transport, params)
|
await server.connect(TCPTransport(), params)
|
||||||
await self._server_queue.put(server)
|
await self._server_queue.put(server)
|
||||||
return server
|
return server
|
||||||
|
|
||||||
async def _run_server(self, server: Server):
|
async def _run_server(self, server: Server):
|
||||||
try:
|
try:
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(server._read_lines)
|
await tg.spawn(server._read_lines)
|
||||||
tg.start_soon(server._send_lines)
|
await tg.spawn(server._send_lines)
|
||||||
except ServerDisconnectedException:
|
except ServerDisconnectedException:
|
||||||
server.disconnected = True
|
server.disconnected = True
|
||||||
|
|
||||||
|
@ -62,4 +49,4 @@ class Bot(IBot):
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
while not tg.cancel_scope.cancel_called:
|
while not tg.cancel_scope.cancel_called:
|
||||||
server = await self._server_queue.get()
|
server = await self._server_queue.get()
|
||||||
tg.start_soon(self._run_server, server)
|
await tg.spawn(self._run_server, server)
|
||||||
|
|
|
@ -6,7 +6,6 @@ from ircstates import Server, Emit
|
||||||
from irctokens import Line, Hostmask
|
from irctokens import Line, Hostmask
|
||||||
|
|
||||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||||
from .security import TLS
|
|
||||||
|
|
||||||
class ITCPReader(object):
|
class ITCPReader(object):
|
||||||
async def read(self, byte_count: int):
|
async def read(self, byte_count: int):
|
||||||
|
@ -25,10 +24,11 @@ class ITCPWriter(object):
|
||||||
|
|
||||||
class ITCPTransport(object):
|
class ITCPTransport(object):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
port: int,
|
port: int,
|
||||||
tls: Optional[TLS],
|
tls: bool,
|
||||||
bindhost: Optional[str]=None
|
tls_verify: bool=True,
|
||||||
|
bindhost: Optional[str]=None
|
||||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ from .contexts import ServerContext
|
||||||
from .matching import Response, ANY
|
from .matching import Response, ANY
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||||
from .security import TLSVerifyChain
|
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -102,12 +101,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
async def sts_transmute(params: ConnectionParams):
|
async def sts_transmute(params: ConnectionParams):
|
||||||
if not params.sts is None and params.tls is None:
|
if not params.sts is None and not params.tls:
|
||||||
now = time()
|
now = time()
|
||||||
since = (now-params.sts.created)
|
since = (now-params.sts.created)
|
||||||
if since <= params.sts.duration:
|
if since <= params.sts.duration:
|
||||||
params.port = params.sts.port
|
params.port = params.sts.port
|
||||||
params.tls = TLSVerifyChain()
|
params.tls = True
|
||||||
async def resume_transmute(params: ConnectionParams):
|
async def resume_transmute(params: ConnectionParams):
|
||||||
if params.resume is not None:
|
if params.resume is not None:
|
||||||
params.host = params.resume.address
|
params.host = params.resume.address
|
||||||
|
@ -183,7 +182,7 @@ class CAPContext(ServerContext):
|
||||||
if not params.tls:
|
if not params.tls:
|
||||||
if "port" in sts_dict:
|
if "port" in sts_dict:
|
||||||
params.port = int(sts_dict["port"])
|
params.port = int(sts_dict["port"])
|
||||||
params.tls = TLSVerifyChain()
|
params.tls = True
|
||||||
|
|
||||||
await self.server.bot.disconnect(self.server)
|
await self.server.bot.disconnect(self.server)
|
||||||
await self.server.bot.add_server(self.server.name, params)
|
await self.server.bot.add_server(self.server.name, params)
|
||||||
|
|
|
@ -73,7 +73,8 @@ class Formatless(IMatchResponseParam):
|
||||||
def __init__(self, value: TYPE_MAYBELIT_VALUE):
|
def __init__(self, value: TYPE_MAYBELIT_VALUE):
|
||||||
self._value = _assure_lit(value)
|
self._value = _assure_lit(value)
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Formatless({self._value!r})"
|
brepr = super().__repr__()
|
||||||
|
return f"Formatless({brepr})"
|
||||||
def match(self, server: IServer, arg: str) -> bool:
|
def match(self, server: IServer, arg: str) -> bool:
|
||||||
strip = formatting.strip(arg)
|
strip = formatting.strip(arg)
|
||||||
return self._value.match(server, strip)
|
return self._value.match(server, strip)
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
from re import compile as re_compile
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from .security import TLS, TLSNoVerify, TLSVerifyChain
|
|
||||||
|
|
||||||
class SASLParams(object):
|
class SASLParams(object):
|
||||||
mechanism: str
|
mechanism: str
|
||||||
|
|
||||||
|
@ -31,24 +28,19 @@ class ResumePolicy(object):
|
||||||
address: str
|
address: str
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
|
|
||||||
|
|
||||||
_TLS_TYPES = {
|
|
||||||
"+": TLSVerifyChain,
|
|
||||||
"~": TLSNoVerify,
|
|
||||||
}
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams(object):
|
class ConnectionParams(object):
|
||||||
nickname: str
|
nickname: str
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
tls: Optional[TLS] = field(default_factory=TLSVerifyChain)
|
tls: bool
|
||||||
|
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
realname: Optional[str] = None
|
realname: Optional[str] = None
|
||||||
bindhost: Optional[str] = None
|
bindhost: Optional[str] = None
|
||||||
|
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
|
tls_verify: bool = True
|
||||||
sasl: Optional[SASLParams] = None
|
sasl: Optional[SASLParams] = None
|
||||||
|
|
||||||
sts: Optional[STSPolicy] = None
|
sts: Optional[STSPolicy] = None
|
||||||
|
@ -58,26 +50,3 @@ class ConnectionParams(object):
|
||||||
alt_nicknames: List[str] = field(default_factory=list)
|
alt_nicknames: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
autojoin: List[str] = field(default_factory=list)
|
autojoin: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_hoststring(
|
|
||||||
nickname: str,
|
|
||||||
hoststring: str
|
|
||||||
) -> "ConnectionParams":
|
|
||||||
|
|
||||||
ipv6host = RE_IPV6HOST.search(hoststring)
|
|
||||||
if ipv6host is not None and ipv6host.start() == 0:
|
|
||||||
host = ipv6host.group(1)
|
|
||||||
port_s = hoststring[ipv6host.end()+1:]
|
|
||||||
else:
|
|
||||||
host, _, port_s = hoststring.strip().partition(":")
|
|
||||||
|
|
||||||
tls_type: Optional[TLS] = None
|
|
||||||
if not port_s:
|
|
||||||
port_s = "6667"
|
|
||||||
else:
|
|
||||||
tls_type = _TLS_TYPES.get(port_s[0], lambda: None)()
|
|
||||||
if tls_type is not None:
|
|
||||||
port_s = port_s[1:] or "6697"
|
|
||||||
|
|
||||||
return ConnectionParams(nickname, host, int(port_s), tls_type)
|
|
||||||
|
|
|
@ -32,9 +32,7 @@ AUTH_BYTE_MAX = 400
|
||||||
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
|
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
|
||||||
|
|
||||||
NUMERICS_FAIL = Response(ERR_SASLFAIL)
|
NUMERICS_FAIL = Response(ERR_SASLFAIL)
|
||||||
NUMERICS_INITIAL = Responses([
|
NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS])
|
||||||
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
|
|
||||||
])
|
|
||||||
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
|
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
|
||||||
|
|
||||||
def _b64e(s: str):
|
def _b64e(s: str):
|
||||||
|
@ -152,8 +150,6 @@ class SASLContext(ServerContext):
|
||||||
return SASLResult.SUCCESS
|
return SASLResult.SUCCESS
|
||||||
elif line.command == "904":
|
elif line.command == "904":
|
||||||
match.pop(0)
|
match.pop(0)
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
return SASLResult.FAILURE
|
return SASLResult.FAILURE
|
||||||
|
|
||||||
|
|
|
@ -1,29 +1,13 @@
|
||||||
import ssl
|
import ssl
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TLS:
|
|
||||||
client_keypair: Optional[Tuple[str, str]] = None
|
|
||||||
|
|
||||||
# tls without verification
|
|
||||||
class TLSNoVerify(TLS):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# verify via CAs
|
|
||||||
class TLSVerifyChain(TLS):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# verify by a pinned hash
|
|
||||||
class TLSVerifyHash(TLSNoVerify):
|
|
||||||
def __init__(self, sum: str):
|
|
||||||
self.sum = sum.lower()
|
|
||||||
class TLSVerifySHA512(TLSVerifyHash):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
||||||
ctx = ssl.create_default_context()
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||||
if not verify:
|
context.options |= ssl.OP_NO_SSLv2
|
||||||
ctx.check_hostname = False
|
context.options |= ssl.OP_NO_SSLv3
|
||||||
ctx.verify_mode = ssl.CERT_NONE
|
context.options |= ssl.OP_NO_TLSv1
|
||||||
return ctx
|
context.load_default_certs()
|
||||||
|
|
||||||
|
if verify:
|
||||||
|
context.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
|
@ -6,7 +6,6 @@ from collections import deque
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from asyncio_rlock import RLock
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from async_timeout import timeout as timeout_
|
from async_timeout import timeout as timeout_
|
||||||
from ircstates import Emit, Channel, ChannelUser
|
from ircstates import Emit, Channel, ChannelUser
|
||||||
|
@ -29,7 +28,7 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||||
|
|
||||||
THROTTLE_RATE = 4 # lines
|
THROTTLE_RATE = 4 # lines
|
||||||
THROTTLE_TIME = 2 # seconds
|
THROTTLE_TIME = 2 # seconds
|
||||||
PING_TIMEOUT = 60 # seconds
|
PING_INTERVAL = 60 # seconds
|
||||||
WAIT_TIMEOUT = 20 # seconds
|
WAIT_TIMEOUT = 20 # seconds
|
||||||
|
|
||||||
JOIN_ERR_FIRST = [
|
JOIN_ERR_FIRST = [
|
||||||
|
@ -55,7 +54,8 @@ class Server(IServer):
|
||||||
|
|
||||||
self.disconnected = False
|
self.disconnected = False
|
||||||
|
|
||||||
self.throttle = Throttler(rate_limit=100, period=1)
|
self.throttle = Throttler(
|
||||||
|
rate_limit=100, period=THROTTLE_TIME)
|
||||||
|
|
||||||
self.sasl_state = SASLResult.NONE
|
self.sasl_state = SASLResult.NONE
|
||||||
self.last_read = monotonic()
|
self.last_read = monotonic()
|
||||||
|
@ -64,14 +64,12 @@ class Server(IServer):
|
||||||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||||
self.desired_caps: Set[ICapability] = set([])
|
self.desired_caps: Set[ICapability] = set([])
|
||||||
|
|
||||||
|
self.read_lock = asyncio.Lock()
|
||||||
self._read_queue: Deque[Line] = deque()
|
self._read_queue: Deque[Line] = deque()
|
||||||
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
self._process_queue: Deque[Line] = deque()
|
||||||
|
|
||||||
self._ping_sent = False
|
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
|
||||||
self._read_lguard = RLock()
|
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||||
self.read_lock = self._read_lguard
|
|
||||||
self._read_lwork = asyncio.Lock()
|
|
||||||
self._wait_for = asyncio.Event()
|
|
||||||
|
|
||||||
self._pending_who: Deque[str] = deque()
|
self._pending_who: Deque[str] = deque()
|
||||||
self._alt_nicks: List[str] = []
|
self._alt_nicks: List[str] = []
|
||||||
|
@ -124,8 +122,9 @@ class Server(IServer):
|
||||||
reader, writer = await transport.connect(
|
reader, writer = await transport.connect(
|
||||||
params.host,
|
params.host,
|
||||||
params.port,
|
params.port,
|
||||||
tls =params.tls,
|
tls =params.tls,
|
||||||
bindhost =params.bindhost)
|
tls_verify=params.tls_verify,
|
||||||
|
bindhost =params.bindhost)
|
||||||
|
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
|
@ -180,9 +179,9 @@ class Server(IServer):
|
||||||
self._pending_who[0] == chan):
|
self._pending_who[0] == chan):
|
||||||
self._pending_who.popleft()
|
self._pending_who.popleft()
|
||||||
await self._next_who()
|
await self._next_who()
|
||||||
elif (line.command in {
|
|
||||||
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
|
elif (line.command == ERR_NICKNAMEINUSE and
|
||||||
} and not self.registered):
|
not self.registered):
|
||||||
if self._alt_nicks:
|
if self._alt_nicks:
|
||||||
nick = self._alt_nicks.pop(0)
|
nick = self._alt_nicks.pop(0)
|
||||||
await self.send(build("NICK", [nick]))
|
await self.send(build("NICK", [nick]))
|
||||||
|
@ -263,61 +262,44 @@ class Server(IServer):
|
||||||
else:
|
else:
|
||||||
await self.send(build("WHO", [chan]))
|
await self.send(build("WHO", [chan]))
|
||||||
|
|
||||||
async def _read_line(self, timeout: float) -> Optional[Line]:
|
async def _read_line(self) -> Line:
|
||||||
while True:
|
while True:
|
||||||
if self._read_queue:
|
async with self.read_lock:
|
||||||
return self._read_queue.popleft()
|
if self._read_queue:
|
||||||
|
return self._read_queue.popleft()
|
||||||
|
|
||||||
try:
|
data = await self._reader.read(1024)
|
||||||
async with timeout_(timeout):
|
lines = self.recv(data)
|
||||||
data = await self._reader.read(1024)
|
# last_read under self.recv() as recv might throw Disconnected
|
||||||
except asyncio.TimeoutError:
|
self.last_read = monotonic()
|
||||||
return None
|
for line in lines:
|
||||||
|
self._read_queue.append(line)
|
||||||
self.last_read = monotonic()
|
|
||||||
lines = self.recv(data)
|
|
||||||
for line in lines:
|
|
||||||
self.line_preread(line)
|
|
||||||
self._read_queue.append(line)
|
|
||||||
|
|
||||||
async def _read_lines(self):
|
async def _read_lines(self):
|
||||||
|
sent_ping = False
|
||||||
while True:
|
while True:
|
||||||
async with self._read_lguard:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not self._process_queue:
|
if not self._process_queue:
|
||||||
async with self._read_lwork:
|
try:
|
||||||
read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT))
|
async with timeout_(PING_INTERVAL):
|
||||||
wait_aw = asyncio.create_task(self._wait_for.wait())
|
line = await self._read_line()
|
||||||
dones, notdones = await asyncio.wait(
|
except asyncio.TimeoutError:
|
||||||
[read_aw, wait_aw],
|
if not sent_ping:
|
||||||
return_when=asyncio.FIRST_COMPLETED
|
sent_ping = True
|
||||||
)
|
await self.send(build("PING", ["hello"]))
|
||||||
self._wait_for.clear()
|
continue
|
||||||
|
else:
|
||||||
|
raise ServerDisconnectedException()
|
||||||
|
else:
|
||||||
|
sent_ping = False
|
||||||
|
self._process_queue.append(line)
|
||||||
|
|
||||||
for done in dones:
|
line = self._process_queue.popleft()
|
||||||
if isinstance(done.result(), Line):
|
emit = self.parse_tokens(line)
|
||||||
self._ping_sent = False
|
await self._on_read(line, emit)
|
||||||
line = done.result()
|
|
||||||
emit = self.parse_tokens(line)
|
|
||||||
self._process_queue.append((line, emit))
|
|
||||||
elif done.result() is None:
|
|
||||||
if not self._ping_sent:
|
|
||||||
await self.send(build("PING", ["hello"]))
|
|
||||||
self._ping_sent = True
|
|
||||||
else:
|
|
||||||
await self.disconnect()
|
|
||||||
raise ServerDisconnectedException()
|
|
||||||
for notdone in notdones:
|
|
||||||
notdone.cancel()
|
|
||||||
|
|
||||||
else:
|
|
||||||
line, emit = self._process_queue.popleft()
|
|
||||||
await self._on_read(line, emit)
|
|
||||||
|
|
||||||
async def wait_for(self,
|
async def wait_for(self,
|
||||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||||
sent_aw: Optional[Awaitable[SentLine]]=None,
|
label: Optional[str]=None,
|
||||||
timeout: float=WAIT_TIMEOUT
|
timeout: float=WAIT_TIMEOUT
|
||||||
) -> Line:
|
) -> Line:
|
||||||
|
|
||||||
|
@ -327,18 +309,13 @@ class Server(IServer):
|
||||||
else:
|
else:
|
||||||
response_obj = response
|
response_obj = response
|
||||||
|
|
||||||
async with self._read_lguard:
|
wait_for = WaitFor(response_obj, label)
|
||||||
self._wait_for.set()
|
async with timeout_(timeout):
|
||||||
async with self._read_lwork:
|
while True:
|
||||||
async with timeout_(timeout):
|
line = await self._read_line()
|
||||||
while True:
|
self._process_queue.append(line)
|
||||||
line = await self._read_line(timeout)
|
if wait_for.match(self, line):
|
||||||
if line:
|
return line
|
||||||
self._ping_sent = False
|
|
||||||
emit = self.parse_tokens(line)
|
|
||||||
self._process_queue.append((line, emit))
|
|
||||||
if response_obj.match(self, line):
|
|
||||||
return line
|
|
||||||
|
|
||||||
async def _on_send_line(self, line: Line):
|
async def _on_send_line(self, line: Line):
|
||||||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||||
|
@ -554,7 +531,7 @@ class Server(IServer):
|
||||||
for symbol in symbols:
|
for symbol in symbols:
|
||||||
mode = self.isupport.prefix.from_prefix(symbol)
|
mode = self.isupport.prefix.from_prefix(symbol)
|
||||||
if mode is not None:
|
if mode is not None:
|
||||||
channel_user.modes.add(mode)
|
channel_user.modes.append(mode)
|
||||||
|
|
||||||
obj.channels.append(channel_user)
|
obj.channels.append(channel_user)
|
||||||
elif line.command == RPL_ENDOFWHOIS:
|
elif line.command == RPL_ENDOFWHOIS:
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
from hashlib import sha512
|
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from asyncio import StreamReader, StreamWriter
|
from asyncio import StreamReader, StreamWriter
|
||||||
from async_stagger import open_connection
|
from async_stagger import open_connection
|
||||||
|
|
||||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||||
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
|
from .security import tls_context
|
||||||
TLSVerifySHA512)
|
|
||||||
|
|
||||||
class TCPReader(ITCPReader):
|
class TCPReader(ITCPReader):
|
||||||
def __init__(self, reader: StreamReader):
|
def __init__(self, reader: StreamReader):
|
||||||
|
@ -34,18 +32,16 @@ class TCPWriter(ITCPWriter):
|
||||||
|
|
||||||
class TCPTransport(ITCPTransport):
|
class TCPTransport(ITCPTransport):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
port: int,
|
port: int,
|
||||||
tls: Optional[TLS],
|
tls: bool,
|
||||||
bindhost: Optional[str]=None
|
tls_verify: bool=True,
|
||||||
|
bindhost: Optional[str]=None
|
||||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||||
|
|
||||||
cur_ssl: Optional[SSLContext] = None
|
cur_ssl: Optional[SSLContext] = None
|
||||||
if tls is not None:
|
if tls:
|
||||||
cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
|
cur_ssl = tls_context(tls_verify)
|
||||||
if tls.client_keypair is not None:
|
|
||||||
(client_cert, client_key) = tls.client_keypair
|
|
||||||
cur_ssl.load_cert_chain(client_cert, keyfile=client_key)
|
|
||||||
|
|
||||||
local_addr: Optional[Tuple[str, int]] = None
|
local_addr: Optional[Tuple[str, int]] = None
|
||||||
if not bindhost is None:
|
if not bindhost is None:
|
||||||
|
@ -59,20 +55,5 @@ class TCPTransport(ITCPTransport):
|
||||||
server_hostname=server_hostname,
|
server_hostname=server_hostname,
|
||||||
ssl =cur_ssl,
|
ssl =cur_ssl,
|
||||||
local_addr =local_addr)
|
local_addr =local_addr)
|
||||||
|
|
||||||
if isinstance(tls, TLSVerifyHash):
|
|
||||||
cert: bytes = writer.transport.get_extra_info(
|
|
||||||
"ssl_object"
|
|
||||||
).getpeercert(True)
|
|
||||||
if isinstance(tls, TLSVerifySHA512):
|
|
||||||
sum = sha512(cert).hexdigest()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown hash pinning {type(tls)}")
|
|
||||||
|
|
||||||
if not sum == tls.sum:
|
|
||||||
raise ValueError(
|
|
||||||
f"pinned hash for {hostname} does not match ({sum})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return (TCPReader(reader), TCPWriter(writer))
|
return (TCPReader(reader), TCPWriter(writer))
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
anyio ~=2.0.2
|
anyio ~=2.0.2
|
||||||
asyncio-rlock ~=0.1.0
|
|
||||||
asyncio-throttle ~=1.0.1
|
asyncio-throttle ~=1.0.1
|
||||||
ircstates ~=0.13.0
|
dataclasses ~=0.6; python_version<"3.7"
|
||||||
|
ircstates ~=0.11.7
|
||||||
async_stagger ~=0.3.0
|
async_stagger ~=0.3.0
|
||||||
async_timeout ~=4.0.2
|
async_timeout ~=3.0.1
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -26,6 +26,6 @@ setup(
|
||||||
"Operating System :: Microsoft :: Windows",
|
"Operating System :: Microsoft :: Windows",
|
||||||
"Topic :: Communications :: Chat :: Internet Relay Chat"
|
"Topic :: Communications :: Chat :: Internet Relay Chat"
|
||||||
],
|
],
|
||||||
python_requires='>=3.7',
|
python_requires='>=3.6',
|
||||||
install_requires=install_requires
|
install_requires=install_requires
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue