dont tg.spawn() for each message, allow wait_for to read
This commit is contained in:
parent
45ac3be550
commit
688418df04
4 changed files with 125 additions and 94 deletions
|
@ -35,30 +35,29 @@ class Bot(object):
|
|||
|
||||
async def _run_server(self, server: Server):
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _read_query():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
await server._read_lines()
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _read():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._read_lines()
|
||||
|
||||
for line, emits in lines:
|
||||
for emit in emits:
|
||||
await tg.spawn(server._on_read_emit, line, emit)
|
||||
await tg.spawn(server._on_read_line, line)
|
||||
await tg.spawn(self.line_read, server, line)
|
||||
line = await server.next_line()
|
||||
await self.line_read(server, line)
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _write():
|
||||
try:
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._write_lines()
|
||||
for line in lines:
|
||||
await self.line_send(server, line)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._write_lines()
|
||||
for line in lines:
|
||||
await self.line_send(server, line)
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
await tg.spawn(server.handshake)
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(_write)
|
||||
await tg.spawn(_read)
|
||||
await server.handshake()
|
||||
await tg.spawn(_read_query)
|
||||
|
||||
del self.servers[server.name]
|
||||
await self.disconnected(server)
|
||||
|
||||
|
|
|
@ -52,7 +52,12 @@ class IServer(Server):
|
|||
async def queue_capability(self, cap: ICapability):
|
||||
pass
|
||||
|
||||
async def line_written(self, line: Line):
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Line:
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
|
|
|
@ -26,12 +26,22 @@ class SASLError(Exception):
|
|||
class SASLUnknownMechanismError(SASLError):
|
||||
pass
|
||||
|
||||
|
||||
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()])
|
||||
NUMERICS_FAIL = Numerics(["ERR_SASLFAIL"])
|
||||
NUMERICS_INITIAL = Numerics(
|
||||
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
|
||||
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
|
||||
|
||||
def _b64e(s: str):
|
||||
return b64encode(s.encode("utf8")).decode("ascii")
|
||||
|
||||
def _b64eb(s: bytes) -> str:
|
||||
# encode-from-bytes
|
||||
return b64encode(s).decode("ascii")
|
||||
def _b64db(s: str) -> bytes:
|
||||
# decode-to-bytes
|
||||
return b64decode(s)
|
||||
|
||||
class SASLContext(ServerContext):
|
||||
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||
if params.mechanism == "USERPASS":
|
||||
|
@ -79,12 +89,14 @@ class SASLContext(ServerContext):
|
|||
password: str,
|
||||
mechanisms: List[str]=SASL_USERPASS_MECHANISMS
|
||||
) -> SASLResult:
|
||||
# this will, in the future, offer SCRAM support
|
||||
|
||||
def _common(server_mechs) -> str:
|
||||
def _common(server_mechs) -> List[str]:
|
||||
mechs: List[str] = []
|
||||
for our_mech in mechanisms:
|
||||
if our_mech in server_mechs:
|
||||
return our_mech
|
||||
mechs.append(our_mech)
|
||||
|
||||
if mechs:
|
||||
return mechs
|
||||
else:
|
||||
raise SASLUnknownMechanismError(
|
||||
"No matching SASL mechanims. "
|
||||
|
@ -98,66 +110,62 @@ class SASLContext(ServerContext):
|
|||
else:
|
||||
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
|
||||
# what mechanisms are supported
|
||||
match = mechanisms[0]
|
||||
match = mechanisms
|
||||
|
||||
await self.server.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
while match:
|
||||
await self.server.send(build("AUTHENTICATE", [match[0]]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
return SASLResult.ALREADY
|
||||
elif line.command == "908":
|
||||
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
||||
available = line.params[1].split(",")
|
||||
match = _common(available)
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
return SASLResult.ALREADY
|
||||
elif line.command == "908":
|
||||
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
||||
available = line.params[1].split(",")
|
||||
match = _common(available)
|
||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
auth_text = ""
|
||||
|
||||
await self.server.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
if match == "PLAIN":
|
||||
auth_text = f"{username}\0{username}\0{password}"
|
||||
elif match[0].startswith("SCRAM-SHA-"):
|
||||
auth_text = await self._scram(
|
||||
match.pop(0), username, password)
|
||||
|
||||
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
def _b64e(s: str):
|
||||
return b64encode(s.encode("utf8")).decode("ascii")
|
||||
if not auth_text == "+":
|
||||
auth_text = _b64e(auth_text)
|
||||
|
||||
def _b64eb(s: bytes) -> str:
|
||||
# encode-from-bytes
|
||||
return b64encode(s).decode("ascii")
|
||||
def _b64db(s: str) -> bytes:
|
||||
# decode-to-bytes
|
||||
return b64decode(s)
|
||||
if auth_text:
|
||||
await self.server.send(build("AUTHENTICATE", [auth_text]))
|
||||
|
||||
auth_text = ""
|
||||
if match == "PLAIN":
|
||||
auth_text = f"{username}\0{username}\0{password}"
|
||||
elif match.startswith("SCRAM-SHA-"):
|
||||
algo = match.replace("SCRAM-", "", 1)
|
||||
scram = SCRAMContext(algo, username, password)
|
||||
|
||||
client_first = _b64eb(scram.client_first())
|
||||
await self.server.send(build("AUTHENTICATE", [client_first]))
|
||||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
|
||||
server_first = _b64db(line.params[0])
|
||||
client_final = _b64eb(scram.server_first(server_first))
|
||||
if not client_final == "":
|
||||
await self.server.send(build("AUTHENTICATE", [client_final]))
|
||||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
|
||||
server_final = _b64db(line.params[0])
|
||||
verified = scram.server_final(server_final)
|
||||
#TODO PANIC if verified is false!
|
||||
auth_text = "+"
|
||||
else:
|
||||
auth_text = ""
|
||||
|
||||
if not auth_text == "+":
|
||||
auth_text = _b64e(auth_text)
|
||||
if auth_text:
|
||||
await self.server.send(build("AUTHENTICATE", [auth_text]))
|
||||
line = await self.server.wait_for(NUMERICS_LAST)
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
|
||||
return SASLResult.FAILURE
|
||||
|
||||
async def _scram(self, algo: str, username: str, password: str) -> str:
|
||||
algo = algo.replace("SCRAM-", "", 1)
|
||||
scram = SCRAMContext(algo, username, password)
|
||||
|
||||
client_first = _b64eb(scram.client_first())
|
||||
await self.server.send(build("AUTHENTICATE", [client_first]))
|
||||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
|
||||
server_first = _b64db(line.params[0])
|
||||
client_final = _b64eb(scram.server_first(server_first))
|
||||
if not client_final == "":
|
||||
await self.server.send(build("AUTHENTICATE", [client_final]))
|
||||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
|
||||
server_final = _b64db(line.params[0])
|
||||
verified = scram.server_final(server_final)
|
||||
#TODO PANIC if verified is false!
|
||||
return "+"
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from asyncio import Future, PriorityQueue
|
||||
from asyncio import Future, PriorityQueue, Queue
|
||||
from typing import Awaitable, List, Optional, Set, Tuple
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
|
@ -26,17 +26,19 @@ class Server(IServer):
|
|||
super().__init__(name)
|
||||
|
||||
self.throttle = Throttler(
|
||||
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
||||
rate_limit=100, period=THROTTLE_TIME)
|
||||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self._cap_queue: Set[ICapability] = set([])
|
||||
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
|
||||
self._cap_queue: Set[ICapability] = set([])
|
||||
|
||||
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||
) -> SentLine:
|
||||
await self.send(tokenise(line), priority)
|
||||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||
async def send(self, line: Line, priority=SendPriority.DEFAULT) -> SentLine:
|
||||
prio_line = SentLine(priority, line)
|
||||
await self._write_queue.put(prio_line)
|
||||
await prio_line.future
|
||||
|
@ -72,7 +74,9 @@ class Server(IServer):
|
|||
await CAPContext(self).handshake()
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
if emit.command == "CAP":
|
||||
if emit.command == "001":
|
||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_new(emit)
|
||||
elif emit.command == "JOIN":
|
||||
|
@ -80,26 +84,41 @@ class Server(IServer):
|
|||
await self.send(build("MODE", [emit.channel.name]))
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
for i, (response, future) in enumerate(self._wait_for):
|
||||
if response.match(line):
|
||||
self._wait_for.pop(i)
|
||||
future.set_result(line)
|
||||
break
|
||||
|
||||
if line.command == "PING":
|
||||
await self.send(build("PONG", line.params))
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
|
||||
data = await self._reader.read(1024)
|
||||
lines = self.recv(data)
|
||||
for line, emits in lines:
|
||||
for emit in emits:
|
||||
await self._on_read_emit(line, emit)
|
||||
|
||||
await self._on_read_line(line)
|
||||
await self.line_read(line)
|
||||
|
||||
await self._read_queue.put((line, emits))
|
||||
return lines
|
||||
async def next_line(self) -> Line:
|
||||
line, emits = await self._read_queue.get()
|
||||
return line
|
||||
|
||||
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
|
||||
future: "Future[Line]" = asyncio.Future()
|
||||
self._wait_for.append((response, future))
|
||||
return future
|
||||
async def wait_for(self, response: BaseResponse) -> Line:
|
||||
while True:
|
||||
lines = self._wait_for_cache.copy()
|
||||
self._wait_for_cache.clear()
|
||||
|
||||
async def line_written(self, line: Line):
|
||||
if not lines:
|
||||
lines += await self._read_lines()
|
||||
|
||||
for i, (line, emits) in enumerate(lines):
|
||||
if response.match(line):
|
||||
self._wait_for_cache = lines[i+1:]
|
||||
return line
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
async def _write_lines(self) -> List[Line]:
|
||||
lines: List[SentLine] = []
|
||||
|
|
Loading…
Reference in a new issue