add basic code for using labeled-responses as wait_for matches

This commit is contained in:
jesopo 2020-05-24 01:05:51 +01:00
parent 899c9c0b49
commit 33bcba8001
3 changed files with 51 additions and 11 deletions

View file

@ -1,8 +1,10 @@
from asyncio import Future from asyncio import Future
from irctokens import Line from irctokens import Line
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
from .matching import IMatchResponse TypeVar)
from .matching import IMatchResponse
from .interface import IServer from .interface import IServer
from .ircv3 import TAG_LABEL
TEvent = TypeVar("TEvent") TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]): class MaybeAwait(Generic[TEvent]):
@ -16,9 +18,11 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object): class WaitFor(object):
def __init__(self, def __init__(self,
wait_fut: "Future[WaitFor]", wait_fut: "Future[WaitFor]",
response: IMatchResponse): response: IMatchResponse,
label: Optional[str]):
self._wait_fut = wait_fut self._wait_fut = wait_fut
self.response = response self.response = response
self._label = label
self.deferred = False self.deferred = False
self._our_fut: "Future[Line]" = Future() self._our_fut: "Future[Line]" = Future()
@ -30,6 +34,12 @@ class WaitFor(object):
return await self return await self
def match(self, server: IServer, line: Line): def match(self, server: IServer, line: Line):
if (self._label is not None and
line.tags is not None):
label = TAG_LABEL.get(line.tags)
if (label is not None and
label == self._label):
return True
return self.response.match(server, line) return self.response.match(server, line)
def resolve(self, line: Line): def resolve(self, line: Line):

View file

@ -40,13 +40,36 @@ class Capability(ICapability):
alias=self.alias, alias=self.alias,
depends_on=self.depends_on[:]) depends_on=self.depends_on[:])
class MessageTag(object):
def __init__(self,
name: Optional[str],
draft_name: Optional[str]=None):
self.name = name
self.draft = draft_name
self._tags = [self.name, self.draft]
def available(self, tags: Iterable[str]) -> Optional[str]:
for tag in self._tags:
if tag is not None and tag in tags:
return tag
else:
return None
def get(self, tags: Dict[str, str]) -> Optional[str]:
name = self.available(tags)
if name is not None:
return tags[name]
else:
return None
CAP_SASL = Capability("sasl") CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message") CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts") CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume") CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
LABEL_TAG = { CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
TAG_LABEL = MessageTag("label", "draft/label")
LABEL_TAG_MAP = {
"draft/labeled-response-0.2": "draft/label", "draft/labeled-response-0.2": "draft/label",
"labeled-response": "label" "labeled-response": "label"
} }

View file

@ -13,7 +13,7 @@ from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG, resume_transmute) CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .join_info import WHOContext from .join_info import WHOContext
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
@ -85,7 +85,7 @@ class Server(IServer):
label = self.cap_available(CAP_LABEL) label = self.cap_available(CAP_LABEL)
if not label is None: if not label is None:
tag = LABEL_TAG[label] tag = LABEL_TAG_MAP[label]
if line.tags is None or not tag in line.tags: if line.tags is None or not tag in line.tags:
if line.tags is None: if line.tags is None:
line.tags = {} line.tags = {}
@ -259,8 +259,10 @@ class Server(IServer):
break break
def wait_for(self, def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]] response: Union[IMatchResponse, Set[IMatchResponse]],
sent_line: Optional[SentLine]=None
) -> Awaitable[Line]: ) -> Awaitable[Line]:
response_obj: IMatchResponse response_obj: IMatchResponse
if isinstance(response, set): if isinstance(response, set):
response_obj = ResponseOr(*response) response_obj = ResponseOr(*response)
@ -270,7 +272,12 @@ class Server(IServer):
wait_for_fut = self._wait_for_fut wait_for_fut = self._wait_for_fut
if wait_for_fut is not None: if wait_for_fut is not None:
self._wait_for_fut = None self._wait_for_fut = None
our_wait_for = WaitFor(wait_for_fut, response_obj)
label: Optional[str] = None
if sent_line is not None:
label = str(sent_line.id)
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
return our_wait_for return our_wait_for
raise Exception() raise Exception()