don't try SASL twice, try SASL on CAP NEW
This commit is contained in:
parent
c139879670
commit
61f1cdba9d
2 changed files with 21 additions and 3 deletions
|
@ -15,6 +15,7 @@ SASL_USERPASS_MECHANISMS = [
|
|||
]
|
||||
|
||||
class SASLResult(Enum):
|
||||
NONE = 0
|
||||
SUCCESS = 1
|
||||
FAILURE = 2
|
||||
ALREADY = 3
|
||||
|
|
|
@ -9,7 +9,7 @@ from irctokens import build, Line, tokenise
|
|||
from .ircv3 import Capability, CAPS, CAP_SASL
|
||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||
from .matching import BaseResponse
|
||||
from .sasl import SASLContext
|
||||
from .sasl import SASLContext, SASLResult
|
||||
|
||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
|
||||
|
@ -26,6 +26,7 @@ class Server(IServer):
|
|||
|
||||
self.throttle = Throttler(
|
||||
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
||||
self.sasl_state = SASLResult.NONE
|
||||
|
||||
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||
|
||||
|
@ -66,6 +67,8 @@ class Server(IServer):
|
|||
await self._cap_ls_done()
|
||||
elif emit.subcommand in ["ACK", "NAK"]:
|
||||
await self._cap_ack(emit)
|
||||
elif emit.subcommand == "NEW":
|
||||
await self._cap_new()
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
for i, (response, future) in enumerate(self._wait_for):
|
||||
|
@ -124,13 +127,27 @@ class Server(IServer):
|
|||
if matches:
|
||||
self._requested_caps = matches
|
||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||
|
||||
async def _cap_ack(self, emit: Emit):
|
||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
||||
await SASLContext(self).from_params(self.params.sasl)
|
||||
await self._maybe_sasl()
|
||||
|
||||
for cap in (emit.tokens or []):
|
||||
if cap in self._requested_caps:
|
||||
self._requested_caps.remove(cap)
|
||||
if not self._requested_caps:
|
||||
await self.send(build("CAP", ["END"]))
|
||||
|
||||
async def _cap_new(self):
|
||||
await self._maybe_sasl()
|
||||
|
||||
async def _maybe_sasl(self) -> bool:
|
||||
if (self.sasl_state == SASLResult.NONE and
|
||||
not self.params.sasl is None and
|
||||
self.cap_agreed(CAP_SASL)):
|
||||
res = await SASLContext(self).from_params(self.params.sasl)
|
||||
self.sasl_state = res
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# /CAP-related
|
||||
|
|
Loading…
Reference in a new issue