diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 64cc005..0120642 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,5 +1,5 @@ from asyncio import Future -from typing import Awaitable, Iterable, List, Optional, Set, Tuple +from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union from enum import IntEnum from ircstates import Server, Emit @@ -85,7 +85,9 @@ class IServer(Server): ) -> Awaitable[SentLine]: pass - def wait_for(self, response: IMatchResponse) -> Awaitable[Line]: + def wait_for(self, + response: Union[IMatchResponse, Set[IMatchResponse]] + ) -> Awaitable[Line]: pass def set_throttle(self, rate: int, time: float): diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 279cf79..f7ba42d 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -5,7 +5,7 @@ from irctokens import build from ircstates.server import ServerDisconnectedException from .contexts import ServerContext -from .matching import Response, ResponseOr, ANY +from .matching import Response, ANY from .interface import ICapability from .params import ConnectionParams, STSPolicy, ResumePolicy @@ -108,10 +108,10 @@ class CAPContext(ServerContext): await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) while cap_names: - line = await self.server.wait_for(ResponseOr( + line = await self.server.wait_for({ Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "NAK"]) - )) + }) current_caps = line.params[2].split(" ") for cap in current_caps: @@ -136,10 +136,10 @@ class CAPContext(ServerContext): if previous_policy is not None and not self.server.registered: await self.server.send(build("RESUME", [previous_policy.token])) - line = await self.server.wait_for(ResponseOr( + line = await self.server.wait_for({ Response("RESUME", ["SUCCESS"]), Response("FAIL", ["RESUME"]) - )) + }) if line.command == "RESUME": raise HandshakeCancel() diff --git a/ircrobots/join_info.py b/ircrobots/join_info.py index 2d95abe..2e68c6b 100644 --- a/ircrobots/join_info.py +++ b/ircrobots/join_info.py @@ -3,7 +3,7 @@ from irctokens import build from ircstates.numerics import * from .contexts import ServerContext -from .matching import Response, ResponseOr, ANY, Folded +from .matching import Response, ANY, Folded class WHOContext(ServerContext): async def ensure(self, channel: str): diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index 301b42c..026ac5c 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -4,7 +4,7 @@ from base64 import b64decode, b64encode from irctokens import build from ircstates.numerics import * -from .matching import ResponseOr, Responses, Response, ANY +from .matching import Responses, Response, ANY from .contexts import ServerContext from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal from .scram import SCRAMContext, SCRAMAlgorithm @@ -60,10 +60,10 @@ class SASLContext(ServerContext): async def external(self) -> SASLResult: await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) - line = await self.server.wait_for(ResponseOr( + line = await self.server.wait_for({ AUTHENTICATE_ANY, NUMERICS_INITIAL - )) + }) if line.command == "907": # we've done SASL already. cleanly abort @@ -117,10 +117,10 @@ class SASLContext(ServerContext): while match: await self.server.send(build("AUTHENTICATE", [match[0]])) - line = await self.server.wait_for(ResponseOr( + line = await self.server.wait_for({ AUTHENTICATE_ANY, NUMERICS_INITIAL - )) + }) if line.command == "907": # we've done SASL already. cleanly abort diff --git a/ircrobots/server.py b/ircrobots/server.py index 646c6f5..7e08f9d 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,6 +1,7 @@ import asyncio from asyncio import Future, PriorityQueue -from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple +from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple, + Union) from collections import deque from time import monotonic @@ -222,12 +223,20 @@ class Server(IServer): return both - async def wait_for(self, response: IMatchResponse) -> Line: + async def wait_for(self, + response: Union[IMatchResponse, Set[IMatchResponse]] + ) -> Line: + response_obj: IMatchResponse + if isinstance(response, set): + response_obj = ResponseOr(*response) + else: + response_obj = response + wait_for_fut = self._wait_for_fut if wait_for_fut is not None: self._wait_for_fut = None - our_wait_for = WaitFor(response) + our_wait_for = WaitFor(response_obj) wait_for_fut.set_result(our_wait_for) return await our_wait_for raise Exception()