allow ResponseOr to be shorthanded as a Set[IMatchResponse]
This commit is contained in:
parent
769390baf7
commit
0921cb8086
5 changed files with 27 additions and 16 deletions
|
@ -1,5 +1,5 @@
|
|||
from asyncio import Future
|
||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
|
||||
from enum import IntEnum
|
||||
|
||||
from ircstates import Server, Emit
|
||||
|
@ -85,7 +85,9 @@ class IServer(Server):
|
|||
) -> Awaitable[SentLine]:
|
||||
pass
|
||||
|
||||
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
|
||||
def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
) -> Awaitable[Line]:
|
||||
pass
|
||||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
|
|
|
@ -5,7 +5,7 @@ from irctokens import build
|
|||
from ircstates.server import ServerDisconnectedException
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ANY
|
||||
from .matching import Response, ANY
|
||||
from .interface import ICapability
|
||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||
|
||||
|
@ -108,10 +108,10 @@ class CAPContext(ServerContext):
|
|||
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
|
||||
|
||||
while cap_names:
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
line = await self.server.wait_for({
|
||||
Response("CAP", [ANY, "ACK"]),
|
||||
Response("CAP", [ANY, "NAK"])
|
||||
))
|
||||
})
|
||||
|
||||
current_caps = line.params[2].split(" ")
|
||||
for cap in current_caps:
|
||||
|
@ -136,10 +136,10 @@ class CAPContext(ServerContext):
|
|||
|
||||
if previous_policy is not None and not self.server.registered:
|
||||
await self.server.send(build("RESUME", [previous_policy.token]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
line = await self.server.wait_for({
|
||||
Response("RESUME", ["SUCCESS"]),
|
||||
Response("FAIL", ["RESUME"])
|
||||
))
|
||||
})
|
||||
if line.command == "RESUME":
|
||||
raise HandshakeCancel()
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from irctokens import build
|
|||
from ircstates.numerics import *
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ANY, Folded
|
||||
from .matching import Response, ANY, Folded
|
||||
|
||||
class WHOContext(ServerContext):
|
||||
async def ensure(self, channel: str):
|
||||
|
|
|
@ -4,7 +4,7 @@ from base64 import b64decode, b64encode
|
|||
from irctokens import build
|
||||
from ircstates.numerics import *
|
||||
|
||||
from .matching import ResponseOr, Responses, Response, ANY
|
||||
from .matching import Responses, Response, ANY
|
||||
from .contexts import ServerContext
|
||||
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
||||
from .scram import SCRAMContext, SCRAMAlgorithm
|
||||
|
@ -60,10 +60,10 @@ class SASLContext(ServerContext):
|
|||
|
||||
async def external(self) -> SASLResult:
|
||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
line = await self.server.wait_for({
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
})
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
@ -117,10 +117,10 @@ class SASLContext(ServerContext):
|
|||
|
||||
while match:
|
||||
await self.server.send(build("AUTHENTICATE", [match[0]]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
line = await self.server.wait_for({
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
})
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
from asyncio import Future, PriorityQueue
|
||||
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
|
||||
from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple,
|
||||
Union)
|
||||
from collections import deque
|
||||
from time import monotonic
|
||||
|
||||
|
@ -222,12 +223,20 @@ class Server(IServer):
|
|||
|
||||
return both
|
||||
|
||||
async def wait_for(self, response: IMatchResponse) -> Line:
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
) -> Line:
|
||||
response_obj: IMatchResponse
|
||||
if isinstance(response, set):
|
||||
response_obj = ResponseOr(*response)
|
||||
else:
|
||||
response_obj = response
|
||||
|
||||
wait_for_fut = self._wait_for_fut
|
||||
if wait_for_fut is not None:
|
||||
self._wait_for_fut = None
|
||||
|
||||
our_wait_for = WaitFor(response)
|
||||
our_wait_for = WaitFor(response_obj)
|
||||
wait_for_fut.set_result(our_wait_for)
|
||||
return await our_wait_for
|
||||
raise Exception()
|
||||
|
|
Loading…
Reference in a new issue