simplify reading lines, simplify wait_for, wait_for from outside reads

This commit is contained in:
jesopo 2020-06-13 00:21:39 +01:00
parent db851e0ba2
commit b9a543031a
4 changed files with 78 additions and 93 deletions

View file

@ -17,21 +17,16 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object):
def __init__(self,
wait_fut: "Future[WaitFor]",
response: IMatchResponse,
label: Optional[str]):
self._wait_fut = wait_fut
response: IMatchResponse):
self.response = response
self._label = label
self.deferred = False
self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]:
self._wait_fut.set_result(self)
return self._our_fut.__await__()
async def defer(self):
self.deferred = True
return await self
def with_label(self, label: str):
self._label = label
def match(self, server: IServer, line: Line):
if (self._label is not None and

View file

@ -37,7 +37,10 @@ class Bot(IBot):
async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg:
await tg.spawn(server._read_lines)
async def _read():
while True:
await server._read_lines()
await tg.spawn(_read)
await tg.spawn(server._send_lines)
await self.disconnected(server)

View file

@ -46,7 +46,7 @@ class SentLine(object):
self.id = id
self.priority = priority
self.line = line
self.future: Future = Future()
self.future: "Future[SentLine]" = Future()
def __lt__(self, other: "SentLine") -> bool:
return self.priority < other.priority

View file

@ -64,6 +64,7 @@ class Server(IServer):
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = []
self._wait_for_fut: Optional[Future[WaitFor]] = None
def hostmask(self) -> str:
@ -77,8 +78,11 @@ class Server(IServer):
def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
return self.send(tokenise(line), priority)
def send(self, line: Line, priority=SendPriority.DEFAULT
def send(self,
line: Line,
priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line)
self._sent_count += 1
@ -93,10 +97,7 @@ class Server(IServer):
self._send_queue.put_nowait(sent_line)
async def _assure() -> SentLine:
await sent_line.future
return sent_line
return MaybeAwait(_assure)
return sent_line.future
def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate
@ -182,39 +183,34 @@ class Server(IServer):
await self.line_read(line)
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
if self._read_queue:
both = self._read_queue.popleft()
else:
ping_sent = False
while True:
try:
async with timeout(PING_TIMEOUT):
data = await self._reader.read(1024)
except asyncio.TimeoutError:
if ping_sent:
data = b"" # empty data means the socket disconnected
else:
ping_sent = True
await self.send(build("PING", ["hello"]))
continue
async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
ping_sent = False
while True:
try:
async with timeout(PING_TIMEOUT):
data = await self._reader.read(1024)
except asyncio.TimeoutError:
if ping_sent:
data = b"" # empty data means the socket disconnected
else:
ping_sent = True
await self.send(build("PING", ["hello"]))
continue
self.last_read = monotonic()
ping_sent = False
self.last_read = monotonic()
ping_sent = False
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
raise
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
raise
if lines:
self._read_queue.extend(lines[1:])
both = lines[0]
break
return both
return lines
async def _line_or_wait(self, line_aw: Awaitable):
async def _line_or_wait(self,
line_aw: Awaitable
) -> Optional[Tuple[WaitFor, Awaitable]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
@ -225,43 +221,34 @@ class Server(IServer):
new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw
else:
return None, None
return None
async def _read_lines(self):
waited_reads: Deque[Tuple[Line, Optional[Emit]]] = deque()
wait_for: Optional[WaitFor] = None
wait_for_aw: Optional[Awaitable] = None
async def _line() -> Tuple[Line, Optional[Emit]]:
both = await self._next_line()
waited_reads.append(both)
line, emit = both
async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
lines = await self._next_lines()
for line, emit in lines:
self.line_preread(line)
return both
while True:
if wait_for is not None:
line, emit = await _line()
if wait_for.response.match(self, line):
for i, (wait_for, aw) in enumerate(self._wait_fors):
if wait_for.match(self, line):
wait_for.resolve(line)
wait_for, wait_for_aw = await self._line_or_wait(
wait_for_aw)
else:
if not waited_reads:
await _line()
if aw is not None:
new_wait_for = await self._line_or_wait(aw)
if new_wait_for is not None:
self._wait_fors.append(new_wait_for)
self._wait_fors.pop(i)
break
while waited_reads:
new_line, new_emit = waited_reads.popleft()
line_aw = self._on_read(new_line, new_emit)
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
if wait_for is not None:
break
line_aw = self._on_read(line, emit)
new_wait_for = await self._line_or_wait(line_aw)
if new_wait_for is not None:
self._wait_fors.append(new_wait_for)
return lines
def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_line: Optional[SentLine]=None
) -> Awaitable[Line]:
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None
) -> Line:
response_obj: IMatchResponse
if isinstance(response, set):
@ -269,17 +256,21 @@ class Server(IServer):
else:
response_obj = response
our_wait_for = WaitFor(response_obj)
wait_for_fut = self._wait_for_fut
if wait_for_fut is not None:
self._wait_for_fut = None
wait_for_fut.set_result(our_wait_for)
else:
self._wait_fors.append((our_wait_for, None))
label: Optional[str] = None
if sent_line is not None:
label = str(sent_line.id)
if sent_aw is not None:
sent_line = await sent_aw
label = str(sent_line.id)
our_wait_for.with_label(label)
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
return our_wait_for
raise Exception()
return (await our_wait_for)
async def _on_send_line(self, line: Line):
if (line.command == "PRIVMSG" and
@ -337,7 +328,6 @@ class Server(IServer):
def send_nick(self, new_nick: str) -> Awaitable[bool]:
fut = self.send(build("NICK", [new_nick]))
async def _assure() -> bool:
await fut
line = await self.wait_for({
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
Responses([
@ -350,7 +340,7 @@ class Server(IServer):
ERR_ERRONEUSNICKNAME,
ERR_UNAVAILRESOURCE
], [ANY, Folded(new_nick)])
})
}, fut)
return line.command == "NICK"
return MaybeAwait(_assure)
@ -368,9 +358,9 @@ class Server(IServer):
fut = self.send(build("PART", [name]))
async def _assure():
await fut
line = await self.wait_for(
Response("PART", [Folded(name)], source=MASK_SELF)
Response("PART", [Folded(name)], source=MASK_SELF),
fut
)
return
return MaybeAwait(_assure)
@ -388,8 +378,6 @@ class Server(IServer):
fut = self.send(build("JOIN", [",".join(names)]+keys))
async def _assure():
await fut
channels: List[Channel] = []
while folded_names:
@ -398,7 +386,7 @@ class Server(IServer):
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
})
}, fut)
chan: Optional[str] = None
if line.command == RPL_CHANNELMODEIS:
@ -429,9 +417,9 @@ class Server(IServer):
) -> Awaitable[Optional[str]]:
fut = self.send(build("PRIVMSG", [target, message]))
async def _assure():
await fut
line = await self.wait_for(
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF)
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
fut
)
if line.command == "PRIVMSG":
return line.params[1]
@ -449,7 +437,6 @@ class Server(IServer):
fut = self.send(build("WHOIS", args))
async def _assure() -> Optional[Whois]:
await fut
params = [ANY, Folded(self.casefold(target))]
obj = Whois()
while True:
@ -465,7 +452,7 @@ class Server(IServer):
RPL_WHOISACCOUNT,
RPL_WHOISSECURE,
RPL_ENDOFWHOIS
], params))
], params), fut)
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
return None
elif line.command == RPL_WHOISUSER: