add basic code for using labeled-responses as wait_for matches
This commit is contained in:
parent
899c9c0b49
commit
33bcba8001
3 changed files with 51 additions and 11 deletions
|
@ -1,8 +1,10 @@
|
|||
from asyncio import Future
|
||||
from irctokens import Line
|
||||
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
|
||||
from .matching import IMatchResponse
|
||||
from asyncio import Future
|
||||
from irctokens import Line
|
||||
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
|
||||
TypeVar)
|
||||
from .matching import IMatchResponse
|
||||
from .interface import IServer
|
||||
from .ircv3 import TAG_LABEL
|
||||
|
||||
TEvent = TypeVar("TEvent")
|
||||
class MaybeAwait(Generic[TEvent]):
|
||||
|
@ -16,9 +18,11 @@ class MaybeAwait(Generic[TEvent]):
|
|||
class WaitFor(object):
|
||||
def __init__(self,
|
||||
wait_fut: "Future[WaitFor]",
|
||||
response: IMatchResponse):
|
||||
response: IMatchResponse,
|
||||
label: Optional[str]):
|
||||
self._wait_fut = wait_fut
|
||||
self.response = response
|
||||
self._label = label
|
||||
self.deferred = False
|
||||
self._our_fut: "Future[Line]" = Future()
|
||||
|
||||
|
@ -30,6 +34,12 @@ class WaitFor(object):
|
|||
return await self
|
||||
|
||||
def match(self, server: IServer, line: Line):
|
||||
if (self._label is not None and
|
||||
line.tags is not None):
|
||||
label = TAG_LABEL.get(line.tags)
|
||||
if (label is not None and
|
||||
label == self._label):
|
||||
return True
|
||||
return self.response.match(server, line)
|
||||
|
||||
def resolve(self, line: Line):
|
||||
|
|
|
@ -40,13 +40,36 @@ class Capability(ICapability):
|
|||
alias=self.alias,
|
||||
depends_on=self.depends_on[:])
|
||||
|
||||
class MessageTag(object):
|
||||
def __init__(self,
|
||||
name: Optional[str],
|
||||
draft_name: Optional[str]=None):
|
||||
self.name = name
|
||||
self.draft = draft_name
|
||||
self._tags = [self.name, self.draft]
|
||||
|
||||
def available(self, tags: Iterable[str]) -> Optional[str]:
|
||||
for tag in self._tags:
|
||||
if tag is not None and tag in tags:
|
||||
return tag
|
||||
else:
|
||||
return None
|
||||
|
||||
def get(self, tags: Dict[str, str]) -> Optional[str]:
|
||||
name = self.available(tags)
|
||||
if name is not None:
|
||||
return tags[name]
|
||||
else:
|
||||
return None
|
||||
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
|
||||
|
||||
LABEL_TAG = {
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
TAG_LABEL = MessageTag("label", "draft/label")
|
||||
LABEL_TAG_MAP = {
|
||||
"draft/labeled-response-0.2": "draft/label",
|
||||
"labeled-response": "label"
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ from ircstates.server import ServerDisconnectedException
|
|||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
|
||||
CAP_LABEL, LABEL_TAG, resume_transmute)
|
||||
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .join_info import WHOContext
|
||||
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
|
||||
|
@ -85,7 +85,7 @@ class Server(IServer):
|
|||
|
||||
label = self.cap_available(CAP_LABEL)
|
||||
if not label is None:
|
||||
tag = LABEL_TAG[label]
|
||||
tag = LABEL_TAG_MAP[label]
|
||||
if line.tags is None or not tag in line.tags:
|
||||
if line.tags is None:
|
||||
line.tags = {}
|
||||
|
@ -259,8 +259,10 @@ class Server(IServer):
|
|||
break
|
||||
|
||||
def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_line: Optional[SentLine]=None
|
||||
) -> Awaitable[Line]:
|
||||
|
||||
response_obj: IMatchResponse
|
||||
if isinstance(response, set):
|
||||
response_obj = ResponseOr(*response)
|
||||
|
@ -270,7 +272,12 @@ class Server(IServer):
|
|||
wait_for_fut = self._wait_for_fut
|
||||
if wait_for_fut is not None:
|
||||
self._wait_for_fut = None
|
||||
our_wait_for = WaitFor(wait_for_fut, response_obj)
|
||||
|
||||
label: Optional[str] = None
|
||||
if sent_line is not None:
|
||||
label = str(sent_line.id)
|
||||
|
||||
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
|
||||
return our_wait_for
|
||||
raise Exception()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue