don't try SASL twice, try SASL on CAP NEW

This commit is contained in:
jesopo 2020-04-02 18:38:37 +01:00
parent c139879670
commit 61f1cdba9d
2 changed files with 21 additions and 3 deletions

View file

@ -15,6 +15,7 @@ SASL_USERPASS_MECHANISMS = [
] ]
class SASLResult(Enum): class SASLResult(Enum):
NONE = 0
SUCCESS = 1 SUCCESS = 1
FAILURE = 2 FAILURE = 2
ALREADY = 3 ALREADY = 3

View file

@ -9,7 +9,7 @@ from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS, CAP_SASL from .ircv3 import Capability, CAPS, CAP_SASL
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
from .matching import BaseResponse from .matching import BaseResponse
from .sasl import SASLContext from .sasl import SASLContext, SASLResult
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
@ -26,6 +26,7 @@ class Server(IServer):
self.throttle = Throttler( self.throttle = Throttler(
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME) rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
@ -66,6 +67,8 @@ class Server(IServer):
await self._cap_ls_done() await self._cap_ls_done()
elif emit.subcommand in ["ACK", "NAK"]: elif emit.subcommand in ["ACK", "NAK"]:
await self._cap_ack(emit) await self._cap_ack(emit)
elif emit.subcommand == "NEW":
await self._cap_new()
async def _on_read_line(self, line: Line): async def _on_read_line(self, line: Line):
for i, (response, future) in enumerate(self._wait_for): for i, (response, future) in enumerate(self._wait_for):
@ -124,13 +127,27 @@ class Server(IServer):
if matches: if matches:
self._requested_caps = matches self._requested_caps = matches
await self.send(build("CAP", ["REQ", " ".join(matches)])) await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, emit: Emit): async def _cap_ack(self, emit: Emit):
if not self.params.sasl is None and self.cap_agreed(CAP_SASL): await self._maybe_sasl()
await SASLContext(self).from_params(self.params.sasl)
for cap in (emit.tokens or []): for cap in (emit.tokens or []):
if cap in self._requested_caps: if cap in self._requested_caps:
self._requested_caps.remove(cap) self._requested_caps.remove(cap)
if not self._requested_caps: if not self._requested_caps:
await self.send(build("CAP", ["END"])) await self.send(build("CAP", ["END"]))
async def _cap_new(self):
await self._maybe_sasl()
async def _maybe_sasl(self) -> bool:
if (self.sasl_state == SASLResult.NONE and
not self.params.sasl is None and
self.cap_agreed(CAP_SASL)):
res = await SASLContext(self).from_params(self.params.sasl)
self.sasl_state = res
return True
else:
return False
# /CAP-related # /CAP-related