enact provided ResumePolicy (incl. cancelling handshake)

This commit is contained in:
jesopo 2020-04-25 20:13:46 +01:00
parent 15b97ab3da
commit 5b927beb25
2 changed files with 21 additions and 4 deletions

View file

@ -1,5 +1,5 @@
from .bot import Bot from .bot import Bot
from .server import Server from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy) STSPolicy, ResumePolicy)
from .ircv3 import Capability from .ircv3 import Capability

View file

@ -88,6 +88,9 @@ async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
class HandshakeCancel(Exception):
pass
class CAPContext(ServerContext): class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]): async def on_ls(self, tokens: Dict[str, str]):
await self._sts(tokens) await self._sts(tokens)
@ -125,14 +128,28 @@ class CAPContext(ServerContext):
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY])) line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
token = line.params[1] token = line.params[1]
address, port = self.server.server_address() address, port = self.server.server_address()
resume_policy = ResumePolicy(address, token) resume_policy = ResumePolicy(address, token)
previous_policy = self.server.params.resume
self.server.params.resume = resume_policy self.server.params.resume = resume_policy
await self.server.resume_policy(resume_policy) await self.server.resume_policy(resume_policy)
if previous_policy is not None and not self.server.registered:
await self.server.send(build("RESUME", [previous_policy.token]))
line = await self.server.wait_for(ResponseOr(
Response("RESUME", ["SUCCESS"]),
Response("FAIL", ["RESUME"])
))
if line.command == "RESUME":
raise HandshakeCancel()
async def handshake(self): async def handshake(self):
await self.on_ls(self.server.available_caps) try:
await self.server.send(build("CAP", ["END"])) await self.on_ls(self.server.available_caps)
except HandshakeCancel:
return
else:
await self.server.send(build("CAP", ["END"]))
async def _sts(self, tokens: Dict[str, str]): async def _sts(self, tokens: Dict[str, str]):
cap_sts = CAP_STS.available(tokens) cap_sts = CAP_STS.available(tokens)