allow ResponseOr to be shorthanded as a Set[IMatchResponse]

This commit is contained in:
jesopo 2020-04-27 01:28:46 +01:00
parent 769390baf7
commit 0921cb8086
5 changed files with 27 additions and 16 deletions

View file

@ -1,5 +1,5 @@
from asyncio import Future from asyncio import Future
from typing import Awaitable, Iterable, List, Optional, Set, Tuple from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
from enum import IntEnum from enum import IntEnum
from ircstates import Server, Emit from ircstates import Server, Emit
@ -85,7 +85,9 @@ class IServer(Server):
) -> Awaitable[SentLine]: ) -> Awaitable[SentLine]:
pass pass
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]: def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Awaitable[Line]:
pass pass
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):

View file

@ -5,7 +5,7 @@ from irctokens import build
from ircstates.server import ServerDisconnectedException from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
@ -108,10 +108,10 @@ class CAPContext(ServerContext):
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
while cap_names: while cap_names:
line = await self.server.wait_for(ResponseOr( line = await self.server.wait_for({
Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "ACK"]),
Response("CAP", [ANY, "NAK"]) Response("CAP", [ANY, "NAK"])
)) })
current_caps = line.params[2].split(" ") current_caps = line.params[2].split(" ")
for cap in current_caps: for cap in current_caps:
@ -136,10 +136,10 @@ class CAPContext(ServerContext):
if previous_policy is not None and not self.server.registered: if previous_policy is not None and not self.server.registered:
await self.server.send(build("RESUME", [previous_policy.token])) await self.server.send(build("RESUME", [previous_policy.token]))
line = await self.server.wait_for(ResponseOr( line = await self.server.wait_for({
Response("RESUME", ["SUCCESS"]), Response("RESUME", ["SUCCESS"]),
Response("FAIL", ["RESUME"]) Response("FAIL", ["RESUME"])
)) })
if line.command == "RESUME": if line.command == "RESUME":
raise HandshakeCancel() raise HandshakeCancel()

View file

@ -3,7 +3,7 @@ from irctokens import build
from ircstates.numerics import * from ircstates.numerics import *
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY, Folded from .matching import Response, ANY, Folded
class WHOContext(ServerContext): class WHOContext(ServerContext):
async def ensure(self, channel: str): async def ensure(self, channel: str):

View file

@ -4,7 +4,7 @@ from base64 import b64decode, b64encode
from irctokens import build from irctokens import build
from ircstates.numerics import * from ircstates.numerics import *
from .matching import ResponseOr, Responses, Response, ANY from .matching import Responses, Response, ANY
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm from .scram import SCRAMContext, SCRAMAlgorithm
@ -60,10 +60,10 @@ class SASLContext(ServerContext):
async def external(self) -> SASLResult: async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(ResponseOr( line = await self.server.wait_for({
AUTHENTICATE_ANY, AUTHENTICATE_ANY,
NUMERICS_INITIAL NUMERICS_INITIAL
)) })
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort
@ -117,10 +117,10 @@ class SASLContext(ServerContext):
while match: while match:
await self.server.send(build("AUTHENTICATE", [match[0]])) await self.server.send(build("AUTHENTICATE", [match[0]]))
line = await self.server.wait_for(ResponseOr( line = await self.server.wait_for({
AUTHENTICATE_ANY, AUTHENTICATE_ANY,
NUMERICS_INITIAL NUMERICS_INITIAL
)) })
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
from asyncio import Future, PriorityQueue from asyncio import Future, PriorityQueue
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple,
Union)
from collections import deque from collections import deque
from time import monotonic from time import monotonic
@ -222,12 +223,20 @@ class Server(IServer):
return both return both
async def wait_for(self, response: IMatchResponse) -> Line: async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Line:
response_obj: IMatchResponse
if isinstance(response, set):
response_obj = ResponseOr(*response)
else:
response_obj = response
wait_for_fut = self._wait_for_fut wait_for_fut = self._wait_for_fut
if wait_for_fut is not None: if wait_for_fut is not None:
self._wait_for_fut = None self._wait_for_fut = None
our_wait_for = WaitFor(response) our_wait_for = WaitFor(response_obj)
wait_for_fut.set_result(our_wait_for) wait_for_fut.set_result(our_wait_for)
return await our_wait_for return await our_wait_for
raise Exception() raise Exception()