allow ResponseOr to be shorthanded as a Set[IMatchResponse]
This commit is contained in:
parent
769390baf7
commit
0921cb8086
5 changed files with 27 additions and 16 deletions
|
@ -1,5 +1,5 @@
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
|
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
from ircstates import Server, Emit
|
from ircstates import Server, Emit
|
||||||
|
@ -85,7 +85,9 @@ class IServer(Server):
|
||||||
) -> Awaitable[SentLine]:
|
) -> Awaitable[SentLine]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
|
def wait_for(self,
|
||||||
|
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||||
|
) -> Awaitable[Line]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_throttle(self, rate: int, time: float):
|
def set_throttle(self, rate: int, time: float):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from irctokens import build
|
||||||
from ircstates.server import ServerDisconnectedException
|
from ircstates.server import ServerDisconnectedException
|
||||||
|
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ANY
|
from .matching import Response, ANY
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||||
|
|
||||||
|
@ -108,10 +108,10 @@ class CAPContext(ServerContext):
|
||||||
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
|
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
|
||||||
|
|
||||||
while cap_names:
|
while cap_names:
|
||||||
line = await self.server.wait_for(ResponseOr(
|
line = await self.server.wait_for({
|
||||||
Response("CAP", [ANY, "ACK"]),
|
Response("CAP", [ANY, "ACK"]),
|
||||||
Response("CAP", [ANY, "NAK"])
|
Response("CAP", [ANY, "NAK"])
|
||||||
))
|
})
|
||||||
|
|
||||||
current_caps = line.params[2].split(" ")
|
current_caps = line.params[2].split(" ")
|
||||||
for cap in current_caps:
|
for cap in current_caps:
|
||||||
|
@ -136,10 +136,10 @@ class CAPContext(ServerContext):
|
||||||
|
|
||||||
if previous_policy is not None and not self.server.registered:
|
if previous_policy is not None and not self.server.registered:
|
||||||
await self.server.send(build("RESUME", [previous_policy.token]))
|
await self.server.send(build("RESUME", [previous_policy.token]))
|
||||||
line = await self.server.wait_for(ResponseOr(
|
line = await self.server.wait_for({
|
||||||
Response("RESUME", ["SUCCESS"]),
|
Response("RESUME", ["SUCCESS"]),
|
||||||
Response("FAIL", ["RESUME"])
|
Response("FAIL", ["RESUME"])
|
||||||
))
|
})
|
||||||
if line.command == "RESUME":
|
if line.command == "RESUME":
|
||||||
raise HandshakeCancel()
|
raise HandshakeCancel()
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from irctokens import build
|
||||||
from ircstates.numerics import *
|
from ircstates.numerics import *
|
||||||
|
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ANY, Folded
|
from .matching import Response, ANY, Folded
|
||||||
|
|
||||||
class WHOContext(ServerContext):
|
class WHOContext(ServerContext):
|
||||||
async def ensure(self, channel: str):
|
async def ensure(self, channel: str):
|
||||||
|
|
|
@ -4,7 +4,7 @@ from base64 import b64decode, b64encode
|
||||||
from irctokens import build
|
from irctokens import build
|
||||||
from ircstates.numerics import *
|
from ircstates.numerics import *
|
||||||
|
|
||||||
from .matching import ResponseOr, Responses, Response, ANY
|
from .matching import Responses, Response, ANY
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
||||||
from .scram import SCRAMContext, SCRAMAlgorithm
|
from .scram import SCRAMContext, SCRAMAlgorithm
|
||||||
|
@ -60,10 +60,10 @@ class SASLContext(ServerContext):
|
||||||
|
|
||||||
async def external(self) -> SASLResult:
|
async def external(self) -> SASLResult:
|
||||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||||
line = await self.server.wait_for(ResponseOr(
|
line = await self.server.wait_for({
|
||||||
AUTHENTICATE_ANY,
|
AUTHENTICATE_ANY,
|
||||||
NUMERICS_INITIAL
|
NUMERICS_INITIAL
|
||||||
))
|
})
|
||||||
|
|
||||||
if line.command == "907":
|
if line.command == "907":
|
||||||
# we've done SASL already. cleanly abort
|
# we've done SASL already. cleanly abort
|
||||||
|
@ -117,10 +117,10 @@ class SASLContext(ServerContext):
|
||||||
|
|
||||||
while match:
|
while match:
|
||||||
await self.server.send(build("AUTHENTICATE", [match[0]]))
|
await self.server.send(build("AUTHENTICATE", [match[0]]))
|
||||||
line = await self.server.wait_for(ResponseOr(
|
line = await self.server.wait_for({
|
||||||
AUTHENTICATE_ANY,
|
AUTHENTICATE_ANY,
|
||||||
NUMERICS_INITIAL
|
NUMERICS_INITIAL
|
||||||
))
|
})
|
||||||
|
|
||||||
if line.command == "907":
|
if line.command == "907":
|
||||||
# we've done SASL already. cleanly abort
|
# we've done SASL already. cleanly abort
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Future, PriorityQueue
|
from asyncio import Future, PriorityQueue
|
||||||
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
|
from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple,
|
||||||
|
Union)
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
|
|
||||||
|
@ -222,12 +223,20 @@ class Server(IServer):
|
||||||
|
|
||||||
return both
|
return both
|
||||||
|
|
||||||
async def wait_for(self, response: IMatchResponse) -> Line:
|
async def wait_for(self,
|
||||||
|
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||||
|
) -> Line:
|
||||||
|
response_obj: IMatchResponse
|
||||||
|
if isinstance(response, set):
|
||||||
|
response_obj = ResponseOr(*response)
|
||||||
|
else:
|
||||||
|
response_obj = response
|
||||||
|
|
||||||
wait_for_fut = self._wait_for_fut
|
wait_for_fut = self._wait_for_fut
|
||||||
if wait_for_fut is not None:
|
if wait_for_fut is not None:
|
||||||
self._wait_for_fut = None
|
self._wait_for_fut = None
|
||||||
|
|
||||||
our_wait_for = WaitFor(response)
|
our_wait_for = WaitFor(response_obj)
|
||||||
wait_for_fut.set_result(our_wait_for)
|
wait_for_fut.set_result(our_wait_for)
|
||||||
return await our_wait_for
|
return await our_wait_for
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
Loading…
Reference in a new issue