dont tg.spawn() for each message, allow wait_for to read

This commit is contained in:
jesopo 2020-04-03 13:04:02 +01:00
parent 45ac3be550
commit 688418df04
4 changed files with 125 additions and 94 deletions

View file

@ -35,30 +35,29 @@ class Bot(object):
async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg:
async def _read_query():
while not tg.cancel_scope.cancel_called:
await server._read_lines()
await tg.cancel_scope.cancel()
async def _read():
while not tg.cancel_scope.cancel_called:
lines = await server._read_lines()
for line, emits in lines:
for emit in emits:
await tg.spawn(server._on_read_emit, line, emit)
await tg.spawn(server._on_read_line, line)
await tg.spawn(self.line_read, server, line)
line = await server.next_line()
await self.line_read(server, line)
await tg.cancel_scope.cancel()
async def _write():
try:
while not tg.cancel_scope.cancel_called:
lines = await server._write_lines()
for line in lines:
await self.line_send(server, line)
except Exception as e:
print(e)
while not tg.cancel_scope.cancel_called:
lines = await server._write_lines()
for line in lines:
await self.line_send(server, line)
await tg.cancel_scope.cancel()
await tg.spawn(server.handshake)
await tg.spawn(_read)
await tg.spawn(_write)
await tg.spawn(_read)
await server.handshake()
await tg.spawn(_read_query)
del self.servers[server.name]
await self.disconnected(server)

View file

@ -52,7 +52,12 @@ class IServer(Server):
async def queue_capability(self, cap: ICapability):
pass
async def line_written(self, line: Line):
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
pass
async def next_line(self) -> Line:
pass
def cap_agreed(self, capability: ICapability) -> bool:

View file

@ -26,12 +26,22 @@ class SASLError(Exception):
class SASLUnknownMechanismError(SASLError):
pass
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()])
NUMERICS_FAIL = Numerics(["ERR_SASLFAIL"])
NUMERICS_INITIAL = Numerics(
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult:
if params.mechanism == "USERPASS":
@ -79,12 +89,14 @@ class SASLContext(ServerContext):
password: str,
mechanisms: List[str]=SASL_USERPASS_MECHANISMS
) -> SASLResult:
# this will, in the future, offer SCRAM support
def _common(server_mechs) -> str:
def _common(server_mechs) -> List[str]:
mechs: List[str] = []
for our_mech in mechanisms:
if our_mech in server_mechs:
return our_mech
mechs.append(our_mech)
if mechs:
return mechs
else:
raise SASLUnknownMechanismError(
"No matching SASL mechanims. "
@ -98,66 +110,62 @@ class SASLContext(ServerContext):
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = mechanisms[0]
match = mechanisms
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(ResponseOr(
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
while match:
await self.server.send(build("AUTHENTICATE", [match[0]]))
line = await self.server.wait_for(ResponseOr(
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text = ""
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
elif match[0].startswith("SCRAM-SHA-"):
auth_text = await self._scram(
match.pop(0), username, password)
if line.command == "AUTHENTICATE" and line.params[0] == "+":
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
if not auth_text == "+":
auth_text = _b64e(auth_text)
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
if auth_text:
await self.server.send(build("AUTHENTICATE", [auth_text]))
auth_text = ""
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
elif match.startswith("SCRAM-SHA-"):
algo = match.replace("SCRAM-", "", 1)
scram = SCRAMContext(algo, username, password)
client_first = _b64eb(scram.client_first())
await self.server.send(build("AUTHENTICATE", [client_first]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_first = _b64db(line.params[0])
client_final = _b64eb(scram.server_first(server_first))
if not client_final == "":
await self.server.send(build("AUTHENTICATE", [client_final]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_final = _b64db(line.params[0])
verified = scram.server_final(server_final)
#TODO PANIC if verified is false!
auth_text = "+"
else:
auth_text = ""
if not auth_text == "+":
auth_text = _b64e(auth_text)
if auth_text:
await self.server.send(build("AUTHENTICATE", [auth_text]))
line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE
async def _scram(self, algo: str, username: str, password: str) -> str:
algo = algo.replace("SCRAM-", "", 1)
scram = SCRAMContext(algo, username, password)
client_first = _b64eb(scram.client_first())
await self.server.send(build("AUTHENTICATE", [client_first]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_first = _b64db(line.params[0])
client_final = _b64eb(scram.server_first(server_first))
if not client_final == "":
await self.server.send(build("AUTHENTICATE", [client_final]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_final = _b64db(line.params[0])
verified = scram.server_final(server_final)
#TODO PANIC if verified is false!
return "+"
else:
return ""

View file

@ -1,6 +1,6 @@
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue
from asyncio import Future, PriorityQueue, Queue
from typing import Awaitable, List, Optional, Set, Tuple
from asyncio_throttle import Throttler
@ -26,17 +26,19 @@ class Server(IServer):
super().__init__(name)
self.throttle = Throttler(
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._cap_queue: Set[ICapability] = set([])
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
self._cap_queue: Set[ICapability] = set([])
self._wait_for: List[Tuple[BaseResponse, Future]] = []
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> SentLine:
await self.send(tokenise(line), priority)
async def send(self, line: Line, priority=SendPriority.DEFAULT):
async def send(self, line: Line, priority=SendPriority.DEFAULT) -> SentLine:
prio_line = SentLine(priority, line)
await self._write_queue.put(prio_line)
await prio_line.future
@ -72,7 +74,9 @@ class Server(IServer):
await CAPContext(self).handshake()
async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "CAP":
if emit.command == "001":
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_new(emit)
elif emit.command == "JOIN":
@ -80,26 +84,41 @@ class Server(IServer):
await self.send(build("MODE", [emit.channel.name]))
async def _on_read_line(self, line: Line):
for i, (response, future) in enumerate(self._wait_for):
if response.match(line):
self._wait_for.pop(i)
future.set_result(line)
break
if line.command == "PING":
await self.send(build("PONG", line.params))
async def line_read(self, line: Line):
pass
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
data = await self._reader.read(1024)
lines = self.recv(data)
for line, emits in lines:
for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
await self._read_queue.put((line, emits))
return lines
async def next_line(self) -> Line:
line, emits = await self._read_queue.get()
return line
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
future: "Future[Line]" = asyncio.Future()
self._wait_for.append((response, future))
return future
async def wait_for(self, response: BaseResponse) -> Line:
while True:
lines = self._wait_for_cache.copy()
self._wait_for_cache.clear()
async def line_written(self, line: Line):
if not lines:
lines += await self._read_lines()
for i, (line, emits) in enumerate(lines):
if response.match(line):
self._wait_for_cache = lines[i+1:]
return line
async def line_send(self, line: Line):
pass
async def _write_lines(self) -> List[Line]:
lines: List[SentLine] = []