don't try SASL twice, try SASL on CAP NEW
This commit is contained in:
parent
c139879670
commit
61f1cdba9d
2 changed files with 21 additions and 3 deletions
|
@ -15,6 +15,7 @@ SASL_USERPASS_MECHANISMS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
class SASLResult(Enum):
|
class SASLResult(Enum):
|
||||||
|
NONE = 0
|
||||||
SUCCESS = 1
|
SUCCESS = 1
|
||||||
FAILURE = 2
|
FAILURE = 2
|
||||||
ALREADY = 3
|
ALREADY = 3
|
||||||
|
|
|
@ -9,7 +9,7 @@ from irctokens import build, Line, tokenise
|
||||||
from .ircv3 import Capability, CAPS, CAP_SASL
|
from .ircv3 import Capability, CAPS, CAP_SASL
|
||||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||||
from .matching import BaseResponse
|
from .matching import BaseResponse
|
||||||
from .sasl import SASLContext
|
from .sasl import SASLContext, SASLResult
|
||||||
|
|
||||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ class Server(IServer):
|
||||||
|
|
||||||
self.throttle = Throttler(
|
self.throttle = Throttler(
|
||||||
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
||||||
|
self.sasl_state = SASLResult.NONE
|
||||||
|
|
||||||
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||||
|
|
||||||
|
@ -66,6 +67,8 @@ class Server(IServer):
|
||||||
await self._cap_ls_done()
|
await self._cap_ls_done()
|
||||||
elif emit.subcommand in ["ACK", "NAK"]:
|
elif emit.subcommand in ["ACK", "NAK"]:
|
||||||
await self._cap_ack(emit)
|
await self._cap_ack(emit)
|
||||||
|
elif emit.subcommand == "NEW":
|
||||||
|
await self._cap_new()
|
||||||
|
|
||||||
async def _on_read_line(self, line: Line):
|
async def _on_read_line(self, line: Line):
|
||||||
for i, (response, future) in enumerate(self._wait_for):
|
for i, (response, future) in enumerate(self._wait_for):
|
||||||
|
@ -124,13 +127,27 @@ class Server(IServer):
|
||||||
if matches:
|
if matches:
|
||||||
self._requested_caps = matches
|
self._requested_caps = matches
|
||||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||||
|
|
||||||
async def _cap_ack(self, emit: Emit):
|
async def _cap_ack(self, emit: Emit):
|
||||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
await self._maybe_sasl()
|
||||||
await SASLContext(self).from_params(self.params.sasl)
|
|
||||||
|
|
||||||
for cap in (emit.tokens or []):
|
for cap in (emit.tokens or []):
|
||||||
if cap in self._requested_caps:
|
if cap in self._requested_caps:
|
||||||
self._requested_caps.remove(cap)
|
self._requested_caps.remove(cap)
|
||||||
if not self._requested_caps:
|
if not self._requested_caps:
|
||||||
await self.send(build("CAP", ["END"]))
|
await self.send(build("CAP", ["END"]))
|
||||||
|
|
||||||
|
async def _cap_new(self):
|
||||||
|
await self._maybe_sasl()
|
||||||
|
|
||||||
|
async def _maybe_sasl(self) -> bool:
|
||||||
|
if (self.sasl_state == SASLResult.NONE and
|
||||||
|
not self.params.sasl is None and
|
||||||
|
self.cap_agreed(CAP_SASL)):
|
||||||
|
res = await SASLContext(self).from_params(self.params.sasl)
|
||||||
|
self.sasl_state = res
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
# /CAP-related
|
# /CAP-related
|
||||||
|
|
Loading…
Add table
Reference in a new issue