implement draft/resume-0.5
This commit is contained in:
parent
064c786db7
commit
15b97ab3da
5 changed files with 54 additions and 10 deletions
|
@ -5,7 +5,7 @@ from enum import IntEnum
|
||||||
from ircstates import Server, Emit
|
from ircstates import Server, Emit
|
||||||
from irctokens import Line, Hostmask
|
from irctokens import Line, Hostmask
|
||||||
|
|
||||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||||
|
|
||||||
class ITCPReader(object):
|
class ITCPReader(object):
|
||||||
async def read(self, byte_count: int):
|
async def read(self, byte_count: int):
|
||||||
|
@ -13,8 +13,14 @@ class ITCPReader(object):
|
||||||
class ITCPWriter(object):
|
class ITCPWriter(object):
|
||||||
def write(self, data: bytes):
|
def write(self, data: bytes):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_peer(self) -> Tuple[str, int]:
|
||||||
|
pass
|
||||||
|
|
||||||
async def drain(self):
|
async def drain(self):
|
||||||
pass
|
pass
|
||||||
|
async def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
class ITCPTransport(object):
|
class ITCPTransport(object):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
|
@ -84,6 +90,9 @@ class IServer(Server):
|
||||||
def set_throttle(self, rate: int, time: float):
|
def set_throttle(self, rate: int, time: float):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def server_address(self) -> Tuple[str, int]:
|
||||||
|
pass
|
||||||
|
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
transport: ITCPTransport,
|
transport: ITCPTransport,
|
||||||
params: ConnectionParams):
|
params: ConnectionParams):
|
||||||
|
@ -97,6 +106,8 @@ class IServer(Server):
|
||||||
pass
|
pass
|
||||||
async def sts_policy(self, sts: STSPolicy):
|
async def sts_policy(self, sts: STSPolicy):
|
||||||
pass
|
pass
|
||||||
|
async def resume_policy(self, resume: ResumePolicy):
|
||||||
|
pass
|
||||||
|
|
||||||
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
|
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -7,7 +7,7 @@ from ircstates.server import ServerDisconnectedException
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ANY
|
from .matching import Response, ResponseOr, ANY
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
from .params import ConnectionParams, STSPolicy
|
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -44,6 +44,7 @@ CAP_SASL = Capability("sasl")
|
||||||
CAP_ECHO = Capability("echo-message")
|
CAP_ECHO = Capability("echo-message")
|
||||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||||
CAP_STS = Capability("sts", "draft/sts")
|
CAP_STS = Capability("sts", "draft/sts")
|
||||||
|
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
|
||||||
|
|
||||||
LABEL_TAG = {
|
LABEL_TAG = {
|
||||||
"draft/labeled-response-0.2": "draft/label",
|
"draft/labeled-response-0.2": "draft/label",
|
||||||
|
@ -65,7 +66,8 @@ CAPS: List[ICapability] = [
|
||||||
Capability("batch"),
|
Capability("batch"),
|
||||||
|
|
||||||
Capability(None, "draft/rename", alias="rename"),
|
Capability(None, "draft/rename", alias="rename"),
|
||||||
Capability("setname", "draft/setname")
|
Capability("setname", "draft/setname"),
|
||||||
|
CAP_RESUME
|
||||||
]
|
]
|
||||||
|
|
||||||
def _cap_dict(s: str) -> Dict[str, str]:
|
def _cap_dict(s: str) -> Dict[str, str]:
|
||||||
|
@ -82,6 +84,9 @@ async def sts_transmute(params: ConnectionParams):
|
||||||
if since <= params.sts.duration:
|
if since <= params.sts.duration:
|
||||||
params.port = params.sts.port
|
params.port = params.sts.port
|
||||||
params.tls = True
|
params.tls = True
|
||||||
|
async def resume_transmute(params: ConnectionParams):
|
||||||
|
if params.resume is not None:
|
||||||
|
params.host = params.resume.address
|
||||||
|
|
||||||
class CAPContext(ServerContext):
|
class CAPContext(ServerContext):
|
||||||
async def on_ls(self, tokens: Dict[str, str]):
|
async def on_ls(self, tokens: Dict[str, str]):
|
||||||
|
@ -109,10 +114,22 @@ class CAPContext(ServerContext):
|
||||||
for cap in current_caps:
|
for cap in current_caps:
|
||||||
if cap in cap_names:
|
if cap in cap_names:
|
||||||
cap_names.remove(cap)
|
cap_names.remove(cap)
|
||||||
|
if CAP_RESUME.available(current_caps):
|
||||||
|
await self.resume_token()
|
||||||
|
|
||||||
if (self.server.cap_agreed(CAP_SASL) and
|
if (self.server.cap_agreed(CAP_SASL) and
|
||||||
not self.server.params.sasl is None):
|
not self.server.params.sasl is None):
|
||||||
await self.server.sasl_auth(self.server.params.sasl)
|
await self.server.sasl_auth(self.server.params.sasl)
|
||||||
|
|
||||||
|
async def resume_token(self):
|
||||||
|
line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY]))
|
||||||
|
token = line.params[1]
|
||||||
|
address, port = self.server.server_address()
|
||||||
|
|
||||||
|
resume_policy = ResumePolicy(address, token)
|
||||||
|
self.server.params.resume = resume_policy
|
||||||
|
await self.server.resume_policy(resume_policy)
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
await self.on_ls(self.server.available_caps)
|
await self.on_ls(self.server.available_caps)
|
||||||
await self.server.send(build("CAP", ["END"]))
|
await self.server.send(build("CAP", ["END"]))
|
||||||
|
|
|
@ -27,6 +27,11 @@ class STSPolicy(object):
|
||||||
duration: int
|
duration: int
|
||||||
preload: bool
|
preload: bool
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResumePolicy(object):
|
||||||
|
address: str
|
||||||
|
token: str
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams(object):
|
class ConnectionParams(object):
|
||||||
nickname: str
|
nickname: str
|
||||||
|
@ -43,3 +48,4 @@ class ConnectionParams(object):
|
||||||
sasl: Optional[SASLParams] = None
|
sasl: Optional[SASLParams] = None
|
||||||
|
|
||||||
sts: Optional[STSPolicy] = None
|
sts: Optional[STSPolicy] = None
|
||||||
|
resume: Optional[ResumePolicy] = None
|
||||||
|
|
|
@ -10,13 +10,13 @@ from ircstates.server import ServerDisconnectedException
|
||||||
from irctokens import build, Line, tokenise
|
from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
|
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
|
||||||
CAP_LABEL, LABEL_TAG)
|
CAP_LABEL, LABEL_TAG, resume_transmute)
|
||||||
from .sasl import SASLContext, SASLResult
|
from .sasl import SASLContext, SASLResult
|
||||||
from .join_info import WHOContext
|
from .join_info import WHOContext
|
||||||
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
|
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
|
||||||
from .asyncs import MaybeAwait, WaitFor
|
from .asyncs import MaybeAwait, WaitFor
|
||||||
from .struct import Whois
|
from .struct import Whois
|
||||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||||
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
||||||
IMatchResponse)
|
IMatchResponse)
|
||||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||||
|
@ -84,10 +84,14 @@ class Server(IServer):
|
||||||
self.throttle.rate_limit = rate
|
self.throttle.rate_limit = rate
|
||||||
self.throttle.period = time
|
self.throttle.period = time
|
||||||
|
|
||||||
|
def server_address(self) -> Tuple[str, int]:
|
||||||
|
return self._writer.get_peer()
|
||||||
|
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
transport: ITCPTransport,
|
transport: ITCPTransport,
|
||||||
params: ConnectionParams):
|
params: ConnectionParams):
|
||||||
await sts_transmute(params)
|
await sts_transmute(params)
|
||||||
|
await resume_transmute(params)
|
||||||
|
|
||||||
reader, writer = await transport.connect(
|
reader, writer = await transport.connect(
|
||||||
params.host,
|
params.host,
|
||||||
|
@ -126,6 +130,8 @@ class Server(IServer):
|
||||||
pass
|
pass
|
||||||
async def sts_policy(self, sts: STSPolicy):
|
async def sts_policy(self, sts: STSPolicy):
|
||||||
pass
|
pass
|
||||||
|
async def resume_policy(self, resume: ResumePolicy):
|
||||||
|
pass
|
||||||
# /to be overriden
|
# /to be overriden
|
||||||
|
|
||||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||||
|
|
|
@ -16,6 +16,10 @@ class TCPWriter(ITCPWriter):
|
||||||
def __init__(self, writer: StreamWriter):
|
def __init__(self, writer: StreamWriter):
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
|
|
||||||
|
def get_peer(self) -> Tuple[str, int]:
|
||||||
|
address, port, *_ = self._writer.transport.get_extra_info("peername")
|
||||||
|
return (address, port)
|
||||||
|
|
||||||
def write(self, data: bytes):
|
def write(self, data: bytes):
|
||||||
self._writer.write(data)
|
self._writer.write(data)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue