simplify reading lines, simplify wait_for, wait_for from outside reads

This commit is contained in:
jesopo 2020-06-13 00:21:39 +01:00
parent db851e0ba2
commit b9a543031a
4 changed files with 78 additions and 93 deletions

View file

@ -17,21 +17,16 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object): class WaitFor(object):
def __init__(self, def __init__(self,
wait_fut: "Future[WaitFor]", response: IMatchResponse):
response: IMatchResponse,
label: Optional[str]):
self._wait_fut = wait_fut
self.response = response self.response = response
self._label = label self._label: Optional[str] = None
self.deferred = False
self._our_fut: "Future[Line]" = Future() self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]: def __await__(self) -> Generator[Any, None, Line]:
self._wait_fut.set_result(self)
return self._our_fut.__await__() return self._our_fut.__await__()
async def defer(self):
self.deferred = True def with_label(self, label: str):
return await self self._label = label
def match(self, server: IServer, line: Line): def match(self, server: IServer, line: Line):
if (self._label is not None and if (self._label is not None and

View file

@ -37,7 +37,10 @@ class Bot(IBot):
async def _run_server(self, server: Server): async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
await tg.spawn(server._read_lines) async def _read():
while True:
await server._read_lines()
await tg.spawn(_read)
await tg.spawn(server._send_lines) await tg.spawn(server._send_lines)
await self.disconnected(server) await self.disconnected(server)

View file

@ -46,7 +46,7 @@ class SentLine(object):
self.id = id self.id = id
self.priority = priority self.priority = priority
self.line = line self.line = line
self.future: Future = Future() self.future: "Future[SentLine]" = Future()
def __lt__(self, other: "SentLine") -> bool: def __lt__(self, other: "SentLine") -> bool:
return self.priority < other.priority return self.priority < other.priority

View file

@ -64,6 +64,7 @@ class Server(IServer):
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = []
self._wait_for_fut: Optional[Future[WaitFor]] = None self._wait_for_fut: Optional[Future[WaitFor]] = None
def hostmask(self) -> str: def hostmask(self) -> str:
@ -77,8 +78,11 @@ class Server(IServer):
def send_raw(self, line: str, priority=SendPriority.DEFAULT def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]: ) -> Awaitable[SentLine]:
return self.send(tokenise(line), priority) return self.send(tokenise(line), priority)
def send(self, line: Line, priority=SendPriority.DEFAULT def send(self,
line: Line,
priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]: ) -> Awaitable[SentLine]:
self.line_presend(line) self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line) sent_line = SentLine(self._sent_count, priority, line)
self._sent_count += 1 self._sent_count += 1
@ -93,10 +97,7 @@ class Server(IServer):
self._send_queue.put_nowait(sent_line) self._send_queue.put_nowait(sent_line)
async def _assure() -> SentLine: return sent_line.future
await sent_line.future
return sent_line
return MaybeAwait(_assure)
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate self.throttle.rate_limit = rate
@ -182,39 +183,34 @@ class Server(IServer):
await self.line_read(line) await self.line_read(line)
async def _next_line(self) -> Tuple[Line, Optional[Emit]]: async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
if self._read_queue: ping_sent = False
both = self._read_queue.popleft() while True:
else: try:
ping_sent = False async with timeout(PING_TIMEOUT):
while True: data = await self._reader.read(1024)
try: except asyncio.TimeoutError:
async with timeout(PING_TIMEOUT): if ping_sent:
data = await self._reader.read(1024) data = b"" # empty data means the socket disconnected
except asyncio.TimeoutError: else:
if ping_sent: ping_sent = True
data = b"" # empty data means the socket disconnected await self.send(build("PING", ["hello"]))
else: continue
ping_sent = True
await self.send(build("PING", ["hello"]))
continue
self.last_read = monotonic() self.last_read = monotonic()
ping_sent = False ping_sent = False
try: try:
lines = self.recv(data) lines = self.recv(data)
except ServerDisconnectedException: except ServerDisconnectedException:
self.disconnected = True self.disconnected = True
raise raise
if lines: return lines
self._read_queue.extend(lines[1:])
both = lines[0]
break
return both
async def _line_or_wait(self, line_aw: Awaitable): async def _line_or_wait(self,
line_aw: Awaitable
) -> Optional[Tuple[WaitFor, Awaitable]]:
wait_for_fut: Future[WaitFor] = Future() wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut self._wait_for_fut = wait_for_fut
@ -225,43 +221,34 @@ class Server(IServer):
new_line_aw = list(pend)[0] new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw return (await wait_for_fut), new_line_aw
else: else:
return None, None return None
async def _read_lines(self): async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
waited_reads: Deque[Tuple[Line, Optional[Emit]]] = deque() lines = await self._next_lines()
wait_for: Optional[WaitFor] = None for line, emit in lines:
wait_for_aw: Optional[Awaitable] = None
async def _line() -> Tuple[Line, Optional[Emit]]:
both = await self._next_line()
waited_reads.append(both)
line, emit = both
self.line_preread(line) self.line_preread(line)
return both
while True: for i, (wait_for, aw) in enumerate(self._wait_fors):
if wait_for is not None: if wait_for.match(self, line):
line, emit = await _line()
if wait_for.response.match(self, line):
wait_for.resolve(line) wait_for.resolve(line)
wait_for, wait_for_aw = await self._line_or_wait( if aw is not None:
wait_for_aw) new_wait_for = await self._line_or_wait(aw)
else: if new_wait_for is not None:
if not waited_reads: self._wait_fors.append(new_wait_for)
await _line() self._wait_fors.pop(i)
break
while waited_reads: line_aw = self._on_read(line, emit)
new_line, new_emit = waited_reads.popleft() new_wait_for = await self._line_or_wait(line_aw)
line_aw = self._on_read(new_line, new_emit) if new_wait_for is not None:
wait_for, wait_for_aw = await self._line_or_wait(line_aw) self._wait_fors.append(new_wait_for)
if wait_for is not None: return lines
break
def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
sent_line: Optional[SentLine]=None sent_aw: Optional[Awaitable[SentLine]]=None
) -> Awaitable[Line]: ) -> Line:
response_obj: IMatchResponse response_obj: IMatchResponse
if isinstance(response, set): if isinstance(response, set):
@ -269,17 +256,21 @@ class Server(IServer):
else: else:
response_obj = response response_obj = response
our_wait_for = WaitFor(response_obj)
wait_for_fut = self._wait_for_fut wait_for_fut = self._wait_for_fut
if wait_for_fut is not None: if wait_for_fut is not None:
self._wait_for_fut = None self._wait_for_fut = None
wait_for_fut.set_result(our_wait_for)
else:
self._wait_fors.append((our_wait_for, None))
label: Optional[str] = None if sent_aw is not None:
if sent_line is not None: sent_line = await sent_aw
label = str(sent_line.id) label = str(sent_line.id)
our_wait_for.with_label(label)
our_wait_for = WaitFor(wait_for_fut, response_obj, label) return (await our_wait_for)
return our_wait_for
raise Exception()
async def _on_send_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command == "PRIVMSG" and if (line.command == "PRIVMSG" and
@ -337,7 +328,6 @@ class Server(IServer):
def send_nick(self, new_nick: str) -> Awaitable[bool]: def send_nick(self, new_nick: str) -> Awaitable[bool]:
fut = self.send(build("NICK", [new_nick])) fut = self.send(build("NICK", [new_nick]))
async def _assure() -> bool: async def _assure() -> bool:
await fut
line = await self.wait_for({ line = await self.wait_for({
Response("NICK", [Folded(new_nick)], source=MASK_SELF), Response("NICK", [Folded(new_nick)], source=MASK_SELF),
Responses([ Responses([
@ -350,7 +340,7 @@ class Server(IServer):
ERR_ERRONEUSNICKNAME, ERR_ERRONEUSNICKNAME,
ERR_UNAVAILRESOURCE ERR_UNAVAILRESOURCE
], [ANY, Folded(new_nick)]) ], [ANY, Folded(new_nick)])
}) }, fut)
return line.command == "NICK" return line.command == "NICK"
return MaybeAwait(_assure) return MaybeAwait(_assure)
@ -368,9 +358,9 @@ class Server(IServer):
fut = self.send(build("PART", [name])) fut = self.send(build("PART", [name]))
async def _assure(): async def _assure():
await fut
line = await self.wait_for( line = await self.wait_for(
Response("PART", [Folded(name)], source=MASK_SELF) Response("PART", [Folded(name)], source=MASK_SELF),
fut
) )
return return
return MaybeAwait(_assure) return MaybeAwait(_assure)
@ -388,8 +378,6 @@ class Server(IServer):
fut = self.send(build("JOIN", [",".join(names)]+keys)) fut = self.send(build("JOIN", [",".join(names)]+keys))
async def _assure(): async def _assure():
await fut
channels: List[Channel] = [] channels: List[Channel] = []
while folded_names: while folded_names:
@ -398,7 +386,7 @@ class Server(IServer):
Responses(JOIN_ERR_FIRST, [ANY, ANY]), Responses(JOIN_ERR_FIRST, [ANY, ANY]),
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]) Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
}) }, fut)
chan: Optional[str] = None chan: Optional[str] = None
if line.command == RPL_CHANNELMODEIS: if line.command == RPL_CHANNELMODEIS:
@ -429,9 +417,9 @@ class Server(IServer):
) -> Awaitable[Optional[str]]: ) -> Awaitable[Optional[str]]:
fut = self.send(build("PRIVMSG", [target, message])) fut = self.send(build("PRIVMSG", [target, message]))
async def _assure(): async def _assure():
await fut
line = await self.wait_for( line = await self.wait_for(
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF) Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
fut
) )
if line.command == "PRIVMSG": if line.command == "PRIVMSG":
return line.params[1] return line.params[1]
@ -449,7 +437,6 @@ class Server(IServer):
fut = self.send(build("WHOIS", args)) fut = self.send(build("WHOIS", args))
async def _assure() -> Optional[Whois]: async def _assure() -> Optional[Whois]:
await fut
params = [ANY, Folded(self.casefold(target))] params = [ANY, Folded(self.casefold(target))]
obj = Whois() obj = Whois()
while True: while True:
@ -465,7 +452,7 @@ class Server(IServer):
RPL_WHOISACCOUNT, RPL_WHOISACCOUNT,
RPL_WHOISSECURE, RPL_WHOISSECURE,
RPL_ENDOFWHOIS RPL_ENDOFWHOIS
], params)) ], params), fut)
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
return None return None
elif line.command == RPL_WHOISUSER: elif line.command == RPL_WHOISUSER: