simplify reading lines, simplify wait_for, wait_for from outside reads
This commit is contained in:
parent
db851e0ba2
commit
b9a543031a
4 changed files with 78 additions and 93 deletions
|
@ -17,21 +17,16 @@ class MaybeAwait(Generic[TEvent]):
|
||||||
|
|
||||||
class WaitFor(object):
|
class WaitFor(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
wait_fut: "Future[WaitFor]",
|
response: IMatchResponse):
|
||||||
response: IMatchResponse,
|
|
||||||
label: Optional[str]):
|
|
||||||
self._wait_fut = wait_fut
|
|
||||||
self.response = response
|
self.response = response
|
||||||
self._label = label
|
self._label: Optional[str] = None
|
||||||
self.deferred = False
|
|
||||||
self._our_fut: "Future[Line]" = Future()
|
self._our_fut: "Future[Line]" = Future()
|
||||||
|
|
||||||
def __await__(self) -> Generator[Any, None, Line]:
|
def __await__(self) -> Generator[Any, None, Line]:
|
||||||
self._wait_fut.set_result(self)
|
|
||||||
return self._our_fut.__await__()
|
return self._our_fut.__await__()
|
||||||
async def defer(self):
|
|
||||||
self.deferred = True
|
def with_label(self, label: str):
|
||||||
return await self
|
self._label = label
|
||||||
|
|
||||||
def match(self, server: IServer, line: Line):
|
def match(self, server: IServer, line: Line):
|
||||||
if (self._label is not None and
|
if (self._label is not None and
|
||||||
|
|
|
@ -37,7 +37,10 @@ class Bot(IBot):
|
||||||
|
|
||||||
async def _run_server(self, server: Server):
|
async def _run_server(self, server: Server):
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
await tg.spawn(server._read_lines)
|
async def _read():
|
||||||
|
while True:
|
||||||
|
await server._read_lines()
|
||||||
|
await tg.spawn(_read)
|
||||||
await tg.spawn(server._send_lines)
|
await tg.spawn(server._send_lines)
|
||||||
|
|
||||||
await self.disconnected(server)
|
await self.disconnected(server)
|
||||||
|
|
|
@ -46,7 +46,7 @@ class SentLine(object):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.priority = priority
|
self.priority = priority
|
||||||
self.line = line
|
self.line = line
|
||||||
self.future: Future = Future()
|
self.future: "Future[SentLine]" = Future()
|
||||||
|
|
||||||
def __lt__(self, other: "SentLine") -> bool:
|
def __lt__(self, other: "SentLine") -> bool:
|
||||||
return self.priority < other.priority
|
return self.priority < other.priority
|
||||||
|
|
|
@ -64,6 +64,7 @@ class Server(IServer):
|
||||||
|
|
||||||
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||||
|
|
||||||
|
self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = []
|
||||||
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||||
|
|
||||||
def hostmask(self) -> str:
|
def hostmask(self) -> str:
|
||||||
|
@ -77,8 +78,11 @@ class Server(IServer):
|
||||||
def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||||
) -> Awaitable[SentLine]:
|
) -> Awaitable[SentLine]:
|
||||||
return self.send(tokenise(line), priority)
|
return self.send(tokenise(line), priority)
|
||||||
def send(self, line: Line, priority=SendPriority.DEFAULT
|
def send(self,
|
||||||
|
line: Line,
|
||||||
|
priority=SendPriority.DEFAULT
|
||||||
) -> Awaitable[SentLine]:
|
) -> Awaitable[SentLine]:
|
||||||
|
|
||||||
self.line_presend(line)
|
self.line_presend(line)
|
||||||
sent_line = SentLine(self._sent_count, priority, line)
|
sent_line = SentLine(self._sent_count, priority, line)
|
||||||
self._sent_count += 1
|
self._sent_count += 1
|
||||||
|
@ -93,10 +97,7 @@ class Server(IServer):
|
||||||
|
|
||||||
self._send_queue.put_nowait(sent_line)
|
self._send_queue.put_nowait(sent_line)
|
||||||
|
|
||||||
async def _assure() -> SentLine:
|
return sent_line.future
|
||||||
await sent_line.future
|
|
||||||
return sent_line
|
|
||||||
return MaybeAwait(_assure)
|
|
||||||
|
|
||||||
def set_throttle(self, rate: int, time: float):
|
def set_throttle(self, rate: int, time: float):
|
||||||
self.throttle.rate_limit = rate
|
self.throttle.rate_limit = rate
|
||||||
|
@ -182,39 +183,34 @@ class Server(IServer):
|
||||||
|
|
||||||
await self.line_read(line)
|
await self.line_read(line)
|
||||||
|
|
||||||
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
|
async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
|
||||||
if self._read_queue:
|
ping_sent = False
|
||||||
both = self._read_queue.popleft()
|
while True:
|
||||||
else:
|
try:
|
||||||
ping_sent = False
|
async with timeout(PING_TIMEOUT):
|
||||||
while True:
|
data = await self._reader.read(1024)
|
||||||
try:
|
except asyncio.TimeoutError:
|
||||||
async with timeout(PING_TIMEOUT):
|
if ping_sent:
|
||||||
data = await self._reader.read(1024)
|
data = b"" # empty data means the socket disconnected
|
||||||
except asyncio.TimeoutError:
|
else:
|
||||||
if ping_sent:
|
ping_sent = True
|
||||||
data = b"" # empty data means the socket disconnected
|
await self.send(build("PING", ["hello"]))
|
||||||
else:
|
continue
|
||||||
ping_sent = True
|
|
||||||
await self.send(build("PING", ["hello"]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.last_read = monotonic()
|
self.last_read = monotonic()
|
||||||
ping_sent = False
|
ping_sent = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lines = self.recv(data)
|
lines = self.recv(data)
|
||||||
except ServerDisconnectedException:
|
except ServerDisconnectedException:
|
||||||
self.disconnected = True
|
self.disconnected = True
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if lines:
|
return lines
|
||||||
self._read_queue.extend(lines[1:])
|
|
||||||
both = lines[0]
|
|
||||||
break
|
|
||||||
return both
|
|
||||||
|
|
||||||
async def _line_or_wait(self, line_aw: Awaitable):
|
async def _line_or_wait(self,
|
||||||
|
line_aw: Awaitable
|
||||||
|
) -> Optional[Tuple[WaitFor, Awaitable]]:
|
||||||
wait_for_fut: Future[WaitFor] = Future()
|
wait_for_fut: Future[WaitFor] = Future()
|
||||||
self._wait_for_fut = wait_for_fut
|
self._wait_for_fut = wait_for_fut
|
||||||
|
|
||||||
|
@ -225,43 +221,34 @@ class Server(IServer):
|
||||||
new_line_aw = list(pend)[0]
|
new_line_aw = list(pend)[0]
|
||||||
return (await wait_for_fut), new_line_aw
|
return (await wait_for_fut), new_line_aw
|
||||||
else:
|
else:
|
||||||
return None, None
|
return None
|
||||||
|
|
||||||
async def _read_lines(self):
|
async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
|
||||||
waited_reads: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
lines = await self._next_lines()
|
||||||
wait_for: Optional[WaitFor] = None
|
for line, emit in lines:
|
||||||
wait_for_aw: Optional[Awaitable] = None
|
|
||||||
|
|
||||||
async def _line() -> Tuple[Line, Optional[Emit]]:
|
|
||||||
both = await self._next_line()
|
|
||||||
waited_reads.append(both)
|
|
||||||
line, emit = both
|
|
||||||
self.line_preread(line)
|
self.line_preread(line)
|
||||||
return both
|
|
||||||
|
|
||||||
while True:
|
for i, (wait_for, aw) in enumerate(self._wait_fors):
|
||||||
if wait_for is not None:
|
if wait_for.match(self, line):
|
||||||
line, emit = await _line()
|
|
||||||
if wait_for.response.match(self, line):
|
|
||||||
wait_for.resolve(line)
|
wait_for.resolve(line)
|
||||||
|
|
||||||
wait_for, wait_for_aw = await self._line_or_wait(
|
if aw is not None:
|
||||||
wait_for_aw)
|
new_wait_for = await self._line_or_wait(aw)
|
||||||
else:
|
if new_wait_for is not None:
|
||||||
if not waited_reads:
|
self._wait_fors.append(new_wait_for)
|
||||||
await _line()
|
self._wait_fors.pop(i)
|
||||||
|
break
|
||||||
|
|
||||||
while waited_reads:
|
line_aw = self._on_read(line, emit)
|
||||||
new_line, new_emit = waited_reads.popleft()
|
new_wait_for = await self._line_or_wait(line_aw)
|
||||||
line_aw = self._on_read(new_line, new_emit)
|
if new_wait_for is not None:
|
||||||
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
|
self._wait_fors.append(new_wait_for)
|
||||||
if wait_for is not None:
|
return lines
|
||||||
break
|
|
||||||
|
|
||||||
def wait_for(self,
|
async def wait_for(self,
|
||||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||||
sent_line: Optional[SentLine]=None
|
sent_aw: Optional[Awaitable[SentLine]]=None
|
||||||
) -> Awaitable[Line]:
|
) -> Line:
|
||||||
|
|
||||||
response_obj: IMatchResponse
|
response_obj: IMatchResponse
|
||||||
if isinstance(response, set):
|
if isinstance(response, set):
|
||||||
|
@ -269,17 +256,21 @@ class Server(IServer):
|
||||||
else:
|
else:
|
||||||
response_obj = response
|
response_obj = response
|
||||||
|
|
||||||
|
our_wait_for = WaitFor(response_obj)
|
||||||
|
|
||||||
wait_for_fut = self._wait_for_fut
|
wait_for_fut = self._wait_for_fut
|
||||||
if wait_for_fut is not None:
|
if wait_for_fut is not None:
|
||||||
self._wait_for_fut = None
|
self._wait_for_fut = None
|
||||||
|
wait_for_fut.set_result(our_wait_for)
|
||||||
|
else:
|
||||||
|
self._wait_fors.append((our_wait_for, None))
|
||||||
|
|
||||||
label: Optional[str] = None
|
if sent_aw is not None:
|
||||||
if sent_line is not None:
|
sent_line = await sent_aw
|
||||||
label = str(sent_line.id)
|
label = str(sent_line.id)
|
||||||
|
our_wait_for.with_label(label)
|
||||||
|
|
||||||
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
|
return (await our_wait_for)
|
||||||
return our_wait_for
|
|
||||||
raise Exception()
|
|
||||||
|
|
||||||
async def _on_send_line(self, line: Line):
|
async def _on_send_line(self, line: Line):
|
||||||
if (line.command == "PRIVMSG" and
|
if (line.command == "PRIVMSG" and
|
||||||
|
@ -337,7 +328,6 @@ class Server(IServer):
|
||||||
def send_nick(self, new_nick: str) -> Awaitable[bool]:
|
def send_nick(self, new_nick: str) -> Awaitable[bool]:
|
||||||
fut = self.send(build("NICK", [new_nick]))
|
fut = self.send(build("NICK", [new_nick]))
|
||||||
async def _assure() -> bool:
|
async def _assure() -> bool:
|
||||||
await fut
|
|
||||||
line = await self.wait_for({
|
line = await self.wait_for({
|
||||||
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
|
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
|
||||||
Responses([
|
Responses([
|
||||||
|
@ -350,7 +340,7 @@ class Server(IServer):
|
||||||
ERR_ERRONEUSNICKNAME,
|
ERR_ERRONEUSNICKNAME,
|
||||||
ERR_UNAVAILRESOURCE
|
ERR_UNAVAILRESOURCE
|
||||||
], [ANY, Folded(new_nick)])
|
], [ANY, Folded(new_nick)])
|
||||||
})
|
}, fut)
|
||||||
return line.command == "NICK"
|
return line.command == "NICK"
|
||||||
return MaybeAwait(_assure)
|
return MaybeAwait(_assure)
|
||||||
|
|
||||||
|
@ -368,9 +358,9 @@ class Server(IServer):
|
||||||
fut = self.send(build("PART", [name]))
|
fut = self.send(build("PART", [name]))
|
||||||
|
|
||||||
async def _assure():
|
async def _assure():
|
||||||
await fut
|
|
||||||
line = await self.wait_for(
|
line = await self.wait_for(
|
||||||
Response("PART", [Folded(name)], source=MASK_SELF)
|
Response("PART", [Folded(name)], source=MASK_SELF),
|
||||||
|
fut
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
return MaybeAwait(_assure)
|
return MaybeAwait(_assure)
|
||||||
|
@ -388,8 +378,6 @@ class Server(IServer):
|
||||||
fut = self.send(build("JOIN", [",".join(names)]+keys))
|
fut = self.send(build("JOIN", [",".join(names)]+keys))
|
||||||
|
|
||||||
async def _assure():
|
async def _assure():
|
||||||
await fut
|
|
||||||
|
|
||||||
channels: List[Channel] = []
|
channels: List[Channel] = []
|
||||||
|
|
||||||
while folded_names:
|
while folded_names:
|
||||||
|
@ -398,7 +386,7 @@ class Server(IServer):
|
||||||
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
|
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
|
||||||
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
|
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
|
||||||
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
|
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
|
||||||
})
|
}, fut)
|
||||||
|
|
||||||
chan: Optional[str] = None
|
chan: Optional[str] = None
|
||||||
if line.command == RPL_CHANNELMODEIS:
|
if line.command == RPL_CHANNELMODEIS:
|
||||||
|
@ -429,9 +417,9 @@ class Server(IServer):
|
||||||
) -> Awaitable[Optional[str]]:
|
) -> Awaitable[Optional[str]]:
|
||||||
fut = self.send(build("PRIVMSG", [target, message]))
|
fut = self.send(build("PRIVMSG", [target, message]))
|
||||||
async def _assure():
|
async def _assure():
|
||||||
await fut
|
|
||||||
line = await self.wait_for(
|
line = await self.wait_for(
|
||||||
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF)
|
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
|
||||||
|
fut
|
||||||
)
|
)
|
||||||
if line.command == "PRIVMSG":
|
if line.command == "PRIVMSG":
|
||||||
return line.params[1]
|
return line.params[1]
|
||||||
|
@ -449,7 +437,6 @@ class Server(IServer):
|
||||||
|
|
||||||
fut = self.send(build("WHOIS", args))
|
fut = self.send(build("WHOIS", args))
|
||||||
async def _assure() -> Optional[Whois]:
|
async def _assure() -> Optional[Whois]:
|
||||||
await fut
|
|
||||||
params = [ANY, Folded(self.casefold(target))]
|
params = [ANY, Folded(self.casefold(target))]
|
||||||
obj = Whois()
|
obj = Whois()
|
||||||
while True:
|
while True:
|
||||||
|
@ -465,7 +452,7 @@ class Server(IServer):
|
||||||
RPL_WHOISACCOUNT,
|
RPL_WHOISACCOUNT,
|
||||||
RPL_WHOISSECURE,
|
RPL_WHOISSECURE,
|
||||||
RPL_ENDOFWHOIS
|
RPL_ENDOFWHOIS
|
||||||
], params))
|
], params), fut)
|
||||||
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
|
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
|
||||||
return None
|
return None
|
||||||
elif line.command == RPL_WHOISUSER:
|
elif line.command == RPL_WHOISUSER:
|
||||||
|
|
Loading…
Reference in a new issue