implement draft/resume-0.5

This commit is contained in:
jesopo 2020-04-25 19:30:36 +01:00
parent 064c786db7
commit 15b97ab3da
5 changed files with 54 additions and 10 deletions

View file

@ -5,7 +5,7 @@ from enum import IntEnum
from ircstates import Server, Emit from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -13,8 +13,14 @@ class ITCPReader(object):
class ITCPWriter(object): class ITCPWriter(object):
def write(self, data: bytes): def write(self, data: bytes):
pass pass
def get_peer(self) -> Tuple[str, int]:
pass
async def drain(self): async def drain(self):
pass pass
async def close(self):
pass
class ITCPTransport(object): class ITCPTransport(object):
async def connect(self, async def connect(self,
@ -84,6 +90,9 @@ class IServer(Server):
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
pass pass
def server_address(self) -> Tuple[str, int]:
pass
async def connect(self, async def connect(self,
transport: ITCPTransport, transport: ITCPTransport,
params: ConnectionParams): params: ConnectionParams):
@ -97,6 +106,8 @@ class IServer(Server):
pass pass
async def sts_policy(self, sts: STSPolicy): async def sts_policy(self, sts: STSPolicy):
pass pass
async def resume_policy(self, resume: ResumePolicy):
pass
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]: async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
pass pass

View file

@ -7,7 +7,7 @@ from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY from .matching import Response, ResponseOr, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -44,6 +44,7 @@ CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message") CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts") CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
LABEL_TAG = { LABEL_TAG = {
"draft/labeled-response-0.2": "draft/label", "draft/labeled-response-0.2": "draft/label",
@ -65,7 +66,8 @@ CAPS: List[ICapability] = [
Capability("batch"), Capability("batch"),
Capability(None, "draft/rename", alias="rename"), Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname") Capability("setname", "draft/setname"),
CAP_RESUME
] ]
def _cap_dict(s: str) -> Dict[str, str]: def _cap_dict(s: str) -> Dict[str, str]:
@ -82,6 +84,9 @@ async def sts_transmute(params: ConnectionParams):
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = True params.tls = True
async def resume_transmute(params: ConnectionParams):
if params.resume is not None:
params.host = params.resume.address
class CAPContext(ServerContext): class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]): async def on_ls(self, tokens: Dict[str, str]):
@ -109,10 +114,22 @@ class CAPContext(ServerContext):
for cap in current_caps: for cap in current_caps:
if cap in cap_names: if cap in cap_names:
cap_names.remove(cap) cap_names.remove(cap)
if CAP_RESUME.available(current_caps):
await self.resume_token()
if (self.server.cap_agreed(CAP_SASL) and if (self.server.cap_agreed(CAP_SASL) and
not self.server.params.sasl is None): not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl) await self.server.sasl_auth(self.server.params.sasl)
async def resume_token(self):
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
token = line.params[1]
address, port = self.server.server_address()
resume_policy = ResumePolicy(address, token)
self.server.params.resume = resume_policy
await self.server.resume_policy(resume_policy)
async def handshake(self): async def handshake(self):
await self.on_ls(self.server.available_caps) await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"])) await self.server.send(build("CAP", ["END"]))

View file

@ -27,6 +27,11 @@ class STSPolicy(object):
duration: int duration: int
preload: bool preload: bool
@dataclass
class ResumePolicy(object):
address: str
token: str
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
@ -43,3 +48,4 @@ class ConnectionParams(object):
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
resume: Optional[ResumePolicy] = None

View file

@ -10,13 +10,13 @@ from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG) CAP_LABEL, LABEL_TAG, resume_transmute)
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .join_info import WHOContext from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
from .asyncs import MaybeAwait, WaitFor from .asyncs import MaybeAwait, WaitFor
from .struct import Whois from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse) IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
@ -84,10 +84,14 @@ class Server(IServer):
self.throttle.rate_limit = rate self.throttle.rate_limit = rate
self.throttle.period = time self.throttle.period = time
def server_address(self) -> Tuple[str, int]:
return self._writer.get_peer()
async def connect(self, async def connect(self,
transport: ITCPTransport, transport: ITCPTransport,
params: ConnectionParams): params: ConnectionParams):
await sts_transmute(params) await sts_transmute(params)
await resume_transmute(params)
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
@ -126,6 +130,8 @@ class Server(IServer):
pass pass
async def sts_policy(self, sts: STSPolicy): async def sts_policy(self, sts: STSPolicy):
pass pass
async def resume_policy(self, resume: ResumePolicy):
pass
# /to be overriden # /to be overriden
async def _on_read_emit(self, line: Line, emit: Emit): async def _on_read_emit(self, line: Line, emit: Emit):

View file

@ -16,6 +16,10 @@ class TCPWriter(ITCPWriter):
def __init__(self, writer: StreamWriter): def __init__(self, writer: StreamWriter):
self._writer = writer self._writer = writer
def get_peer(self) -> Tuple[str, int]:
address, port, *_ = self._writer.transport.get_extra_info("peername")
return (address, port)
def write(self, data: bytes): def write(self, data: bytes):
self._writer.write(data) self._writer.write(data)