enact provided ResumePolicy (incl. cancelling handshake)
This commit is contained in:
parent
15b97ab3da
commit
5b927beb25
2 changed files with 21 additions and 4 deletions
|
@ -1,5 +1,5 @@
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .server import Server
|
from .server import Server
|
||||||
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
||||||
STSPolicy)
|
STSPolicy, ResumePolicy)
|
||||||
from .ircv3 import Capability
|
from .ircv3 import Capability
|
||||||
|
|
|
@ -88,6 +88,9 @@ async def resume_transmute(params: ConnectionParams):
|
||||||
if params.resume is not None:
|
if params.resume is not None:
|
||||||
params.host = params.resume.address
|
params.host = params.resume.address
|
||||||
|
|
||||||
|
class HandshakeCancel(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class CAPContext(ServerContext):
|
class CAPContext(ServerContext):
|
||||||
async def on_ls(self, tokens: Dict[str, str]):
|
async def on_ls(self, tokens: Dict[str, str]):
|
||||||
await self._sts(tokens)
|
await self._sts(tokens)
|
||||||
|
@ -125,14 +128,28 @@ class CAPContext(ServerContext):
|
||||||
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
|
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
|
||||||
token = line.params[1]
|
token = line.params[1]
|
||||||
address, port = self.server.server_address()
|
address, port = self.server.server_address()
|
||||||
|
|
||||||
resume_policy = ResumePolicy(address, token)
|
resume_policy = ResumePolicy(address, token)
|
||||||
|
|
||||||
|
previous_policy = self.server.params.resume
|
||||||
self.server.params.resume = resume_policy
|
self.server.params.resume = resume_policy
|
||||||
await self.server.resume_policy(resume_policy)
|
await self.server.resume_policy(resume_policy)
|
||||||
|
|
||||||
|
if previous_policy is not None and not self.server.registered:
|
||||||
|
await self.server.send(build("RESUME", [previous_policy.token]))
|
||||||
|
line = await self.server.wait_for(ResponseOr(
|
||||||
|
Response("RESUME", ["SUCCESS"]),
|
||||||
|
Response("FAIL", ["RESUME"])
|
||||||
|
))
|
||||||
|
if line.command == "RESUME":
|
||||||
|
raise HandshakeCancel()
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
await self.on_ls(self.server.available_caps)
|
try:
|
||||||
await self.server.send(build("CAP", ["END"]))
|
await self.on_ls(self.server.available_caps)
|
||||||
|
except HandshakeCancel:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await self.server.send(build("CAP", ["END"]))
|
||||||
|
|
||||||
async def _sts(self, tokens: Dict[str, str]):
|
async def _sts(self, tokens: Dict[str, str]):
|
||||||
cap_sts = CAP_STS.available(tokens)
|
cap_sts = CAP_STS.available(tokens)
|
||||||
|
|
Loading…
Reference in a new issue