implement draft/resume-0.5
This commit is contained in:
parent
064c786db7
commit
15b97ab3da
5 changed files with 54 additions and 10 deletions
|
@ -5,7 +5,7 @@ from enum import IntEnum
|
|||
from ircstates import Server, Emit
|
||||
from irctokens import Line, Hostmask
|
||||
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||
|
||||
class ITCPReader(object):
|
||||
async def read(self, byte_count: int):
|
||||
|
@ -13,8 +13,14 @@ class ITCPReader(object):
|
|||
class ITCPWriter(object):
|
||||
def write(self, data: bytes):
|
||||
pass
|
||||
|
||||
def get_peer(self) -> Tuple[str, int]:
|
||||
pass
|
||||
|
||||
async def drain(self):
|
||||
pass
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
class ITCPTransport(object):
|
||||
async def connect(self,
|
||||
|
@ -84,6 +90,9 @@ class IServer(Server):
|
|||
def set_throttle(self, rate: int, time: float):
|
||||
pass
|
||||
|
||||
def server_address(self) -> Tuple[str, int]:
|
||||
pass
|
||||
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
|
@ -97,6 +106,8 @@ class IServer(Server):
|
|||
pass
|
||||
async def sts_policy(self, sts: STSPolicy):
|
||||
pass
|
||||
async def resume_policy(self, resume: ResumePolicy):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
|
||||
pass
|
||||
|
|
|
@ -7,7 +7,7 @@ from ircstates.server import ServerDisconnectedException
|
|||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ANY
|
||||
from .interface import ICapability
|
||||
from .params import ConnectionParams, STSPolicy
|
||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||
|
||||
class Capability(ICapability):
|
||||
def __init__(self,
|
||||
|
@ -40,10 +40,11 @@ class Capability(ICapability):
|
|||
alias=self.alias,
|
||||
depends_on=self.depends_on[:])
|
||||
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
|
||||
|
||||
LABEL_TAG = {
|
||||
"draft/labeled-response-0.2": "draft/label",
|
||||
|
@ -65,7 +66,8 @@ CAPS: List[ICapability] = [
|
|||
Capability("batch"),
|
||||
|
||||
Capability(None, "draft/rename", alias="rename"),
|
||||
Capability("setname", "draft/setname")
|
||||
Capability("setname", "draft/setname"),
|
||||
CAP_RESUME
|
||||
]
|
||||
|
||||
def _cap_dict(s: str) -> Dict[str, str]:
|
||||
|
@ -82,6 +84,9 @@ async def sts_transmute(params: ConnectionParams):
|
|||
if since <= params.sts.duration:
|
||||
params.port = params.sts.port
|
||||
params.tls = True
|
||||
async def resume_transmute(params: ConnectionParams):
|
||||
if params.resume is not None:
|
||||
params.host = params.resume.address
|
||||
|
||||
class CAPContext(ServerContext):
|
||||
async def on_ls(self, tokens: Dict[str, str]):
|
||||
|
@ -109,10 +114,22 @@ class CAPContext(ServerContext):
|
|||
for cap in current_caps:
|
||||
if cap in cap_names:
|
||||
cap_names.remove(cap)
|
||||
if CAP_RESUME.available(current_caps):
|
||||
await self.resume_token()
|
||||
|
||||
if (self.server.cap_agreed(CAP_SASL) and
|
||||
not self.server.params.sasl is None):
|
||||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
async def resume_token(self):
|
||||
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
|
||||
token = line.params[1]
|
||||
address, port = self.server.server_address()
|
||||
|
||||
resume_policy = ResumePolicy(address, token)
|
||||
self.server.params.resume = resume_policy
|
||||
await self.server.resume_policy(resume_policy)
|
||||
|
||||
async def handshake(self):
|
||||
await self.on_ls(self.server.available_caps)
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
|
|
|
@ -27,6 +27,11 @@ class STSPolicy(object):
|
|||
duration: int
|
||||
preload: bool
|
||||
|
||||
@dataclass
|
||||
class ResumePolicy(object):
|
||||
address: str
|
||||
token: str
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams(object):
|
||||
nickname: str
|
||||
|
@ -42,4 +47,5 @@ class ConnectionParams(object):
|
|||
tls_verify: bool = True
|
||||
sasl: Optional[SASLParams] = None
|
||||
|
||||
sts: Optional[STSPolicy] = None
|
||||
sts: Optional[STSPolicy] = None
|
||||
resume: Optional[ResumePolicy] = None
|
||||
|
|
|
@ -10,13 +10,13 @@ from ircstates.server import ServerDisconnectedException
|
|||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
|
||||
CAP_LABEL, LABEL_TAG)
|
||||
CAP_LABEL, LABEL_TAG, resume_transmute)
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .join_info import WHOContext
|
||||
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
|
||||
from .asyncs import MaybeAwait, WaitFor
|
||||
from .struct import Whois
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
||||
IMatchResponse)
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
|
@ -84,10 +84,14 @@ class Server(IServer):
|
|||
self.throttle.rate_limit = rate
|
||||
self.throttle.period = time
|
||||
|
||||
def server_address(self) -> Tuple[str, int]:
|
||||
return self._writer.get_peer()
|
||||
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
await sts_transmute(params)
|
||||
await resume_transmute(params)
|
||||
|
||||
reader, writer = await transport.connect(
|
||||
params.host,
|
||||
|
@ -126,6 +130,8 @@ class Server(IServer):
|
|||
pass
|
||||
async def sts_policy(self, sts: STSPolicy):
|
||||
pass
|
||||
async def resume_policy(self, resume: ResumePolicy):
|
||||
pass
|
||||
# /to be overriden
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
|
|
|
@ -16,6 +16,10 @@ class TCPWriter(ITCPWriter):
|
|||
def __init__(self, writer: StreamWriter):
|
||||
self._writer = writer
|
||||
|
||||
def get_peer(self) -> Tuple[str, int]:
|
||||
address, port, *_ = self._writer.transport.get_extra_info("peername")
|
||||
return (address, port)
|
||||
|
||||
def write(self, data: bytes):
|
||||
self._writer.write(data)
|
||||
|
||||
|
|
Loading…
Reference in a new issue