diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index baa110b..5b798ed 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -1,5 +1,5 @@ from .bot import Bot from .server import Server from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, - STSPolicy) + STSPolicy, ResumePolicy) from .ircv3 import Capability diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 4477714..279cf79 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -88,6 +88,9 @@ async def resume_transmute(params: ConnectionParams): if params.resume is not None: params.host = params.resume.address +class HandshakeCancel(Exception): + pass + class CAPContext(ServerContext): async def on_ls(self, tokens: Dict[str, str]): await self._sts(tokens) @@ -125,14 +128,28 @@ class CAPContext(ServerContext): line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY])) token = line.params[1] address, port = self.server.server_address() - resume_policy = ResumePolicy(address, token) + + previous_policy = self.server.params.resume self.server.params.resume = resume_policy await self.server.resume_policy(resume_policy) + if previous_policy is not None and not self.server.registered: + await self.server.send(build("RESUME", [previous_policy.token])) + line = await self.server.wait_for(ResponseOr( + Response("RESUME", ["SUCCESS"]), + Response("FAIL", ["RESUME"]) + )) + if line.command == "RESUME": + raise HandshakeCancel() + async def handshake(self): - await self.on_ls(self.server.available_caps) - await self.server.send(build("CAP", ["END"])) + try: + await self.on_ls(self.server.available_caps) + except HandshakeCancel: + return + else: + await self.server.send(build("CAP", ["END"])) async def _sts(self, tokens: Dict[str, str]): cap_sts = CAP_STS.available(tokens)