implement draft/resume-0.5

This commit is contained in:
jesopo 2020-04-25 19:30:36 +01:00
parent 064c786db7
commit 15b97ab3da
5 changed files with 54 additions and 10 deletions

View file

@ -5,7 +5,7 @@ from enum import IntEnum
from ircstates import Server, Emit
from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
class ITCPReader(object):
async def read(self, byte_count: int):
@ -13,8 +13,14 @@ class ITCPReader(object):
class ITCPWriter(object):
def write(self, data: bytes):
pass
def get_peer(self) -> Tuple[str, int]:
pass
async def drain(self):
pass
async def close(self):
pass
class ITCPTransport(object):
async def connect(self,
@ -84,6 +90,9 @@ class IServer(Server):
def set_throttle(self, rate: int, time: float):
pass
def server_address(self) -> Tuple[str, int]:
pass
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
@ -97,6 +106,8 @@ class IServer(Server):
pass
async def sts_policy(self, sts: STSPolicy):
pass
async def resume_policy(self, resume: ResumePolicy):
pass
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
pass

View file

@ -7,7 +7,7 @@ from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY
from .interface import ICapability
from .params import ConnectionParams, STSPolicy
from .params import ConnectionParams, STSPolicy, ResumePolicy
class Capability(ICapability):
def __init__(self,
@ -40,10 +40,11 @@ class Capability(ICapability):
alias=self.alias,
depends_on=self.depends_on[:])
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts")
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
LABEL_TAG = {
"draft/labeled-response-0.2": "draft/label",
@ -65,7 +66,8 @@ CAPS: List[ICapability] = [
Capability("batch"),
Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname")
Capability("setname", "draft/setname"),
CAP_RESUME
]
def _cap_dict(s: str) -> Dict[str, str]:
@ -82,6 +84,9 @@ async def sts_transmute(params: ConnectionParams):
if since <= params.sts.duration:
params.port = params.sts.port
params.tls = True
async def resume_transmute(params: ConnectionParams):
if params.resume is not None:
params.host = params.resume.address
class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]):
@ -109,10 +114,22 @@ class CAPContext(ServerContext):
for cap in current_caps:
if cap in cap_names:
cap_names.remove(cap)
if CAP_RESUME.available(current_caps):
await self.resume_token()
if (self.server.cap_agreed(CAP_SASL) and
not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl)
async def resume_token(self):
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
token = line.params[1]
address, port = self.server.server_address()
resume_policy = ResumePolicy(address, token)
self.server.params.resume = resume_policy
await self.server.resume_policy(resume_policy)
async def handshake(self):
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))

View file

@ -27,6 +27,11 @@ class STSPolicy(object):
duration: int
preload: bool
@dataclass
class ResumePolicy(object):
address: str
token: str
@dataclass
class ConnectionParams(object):
nickname: str
@ -42,4 +47,5 @@ class ConnectionParams(object):
tls_verify: bool = True
sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None
sts: Optional[STSPolicy] = None
resume: Optional[ResumePolicy] = None

View file

@ -10,13 +10,13 @@ from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG)
CAP_LABEL, LABEL_TAG, resume_transmute)
from .sasl import SASLContext, SASLResult
from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
@ -84,10 +84,14 @@ class Server(IServer):
self.throttle.rate_limit = rate
self.throttle.period = time
def server_address(self) -> Tuple[str, int]:
return self._writer.get_peer()
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
await sts_transmute(params)
await resume_transmute(params)
reader, writer = await transport.connect(
params.host,
@ -126,6 +130,8 @@ class Server(IServer):
pass
async def sts_policy(self, sts: STSPolicy):
pass
async def resume_policy(self, resume: ResumePolicy):
pass
# /to be overriden
async def _on_read_emit(self, line: Line, emit: Emit):

View file

@ -16,6 +16,10 @@ class TCPWriter(ITCPWriter):
def __init__(self, writer: StreamWriter):
self._writer = writer
def get_peer(self) -> Tuple[str, int]:
address, port, *_ = self._writer.transport.get_extra_info("peername")
return (address, port)
def write(self, data: bytes):
self._writer.write(data)