simplify reading lines, simplify wait_for, wait_for from outside reads
This commit is contained in:
parent
db851e0ba2
commit
b9a543031a
4 changed files with 78 additions and 93 deletions
|
@ -17,21 +17,16 @@ class MaybeAwait(Generic[TEvent]):
|
|||
|
||||
class WaitFor(object):
|
||||
def __init__(self,
|
||||
wait_fut: "Future[WaitFor]",
|
||||
response: IMatchResponse,
|
||||
label: Optional[str]):
|
||||
self._wait_fut = wait_fut
|
||||
response: IMatchResponse):
|
||||
self.response = response
|
||||
self._label = label
|
||||
self.deferred = False
|
||||
self._label: Optional[str] = None
|
||||
self._our_fut: "Future[Line]" = Future()
|
||||
|
||||
def __await__(self) -> Generator[Any, None, Line]:
|
||||
self._wait_fut.set_result(self)
|
||||
return self._our_fut.__await__()
|
||||
async def defer(self):
|
||||
self.deferred = True
|
||||
return await self
|
||||
|
||||
def with_label(self, label: str):
|
||||
self._label = label
|
||||
|
||||
def match(self, server: IServer, line: Line):
|
||||
if (self._label is not None and
|
||||
|
|
|
@ -37,7 +37,10 @@ class Bot(IBot):
|
|||
|
||||
async def _run_server(self, server: Server):
|
||||
async with anyio.create_task_group() as tg:
|
||||
await tg.spawn(server._read_lines)
|
||||
async def _read():
|
||||
while True:
|
||||
await server._read_lines()
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(server._send_lines)
|
||||
|
||||
await self.disconnected(server)
|
||||
|
|
|
@ -46,7 +46,7 @@ class SentLine(object):
|
|||
self.id = id
|
||||
self.priority = priority
|
||||
self.line = line
|
||||
self.future: Future = Future()
|
||||
self.future: "Future[SentLine]" = Future()
|
||||
|
||||
def __lt__(self, other: "SentLine") -> bool:
|
||||
return self.priority < other.priority
|
||||
|
|
|
@ -64,6 +64,7 @@ class Server(IServer):
|
|||
|
||||
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
|
||||
self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = []
|
||||
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||
|
||||
def hostmask(self) -> str:
|
||||
|
@ -77,8 +78,11 @@ class Server(IServer):
|
|||
def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
return self.send(tokenise(line), priority)
|
||||
def send(self, line: Line, priority=SendPriority.DEFAULT
|
||||
def send(self,
|
||||
line: Line,
|
||||
priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
|
||||
self.line_presend(line)
|
||||
sent_line = SentLine(self._sent_count, priority, line)
|
||||
self._sent_count += 1
|
||||
|
@ -93,10 +97,7 @@ class Server(IServer):
|
|||
|
||||
self._send_queue.put_nowait(sent_line)
|
||||
|
||||
async def _assure() -> SentLine:
|
||||
await sent_line.future
|
||||
return sent_line
|
||||
return MaybeAwait(_assure)
|
||||
return sent_line.future
|
||||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
self.throttle.rate_limit = rate
|
||||
|
@ -182,39 +183,34 @@ class Server(IServer):
|
|||
|
||||
await self.line_read(line)
|
||||
|
||||
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
|
||||
if self._read_queue:
|
||||
both = self._read_queue.popleft()
|
||||
else:
|
||||
ping_sent = False
|
||||
while True:
|
||||
try:
|
||||
async with timeout(PING_TIMEOUT):
|
||||
data = await self._reader.read(1024)
|
||||
except asyncio.TimeoutError:
|
||||
if ping_sent:
|
||||
data = b"" # empty data means the socket disconnected
|
||||
else:
|
||||
ping_sent = True
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
|
||||
ping_sent = False
|
||||
while True:
|
||||
try:
|
||||
async with timeout(PING_TIMEOUT):
|
||||
data = await self._reader.read(1024)
|
||||
except asyncio.TimeoutError:
|
||||
if ping_sent:
|
||||
data = b"" # empty data means the socket disconnected
|
||||
else:
|
||||
ping_sent = True
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
|
||||
self.last_read = monotonic()
|
||||
ping_sent = False
|
||||
self.last_read = monotonic()
|
||||
ping_sent = False
|
||||
|
||||
try:
|
||||
lines = self.recv(data)
|
||||
except ServerDisconnectedException:
|
||||
self.disconnected = True
|
||||
raise
|
||||
try:
|
||||
lines = self.recv(data)
|
||||
except ServerDisconnectedException:
|
||||
self.disconnected = True
|
||||
raise
|
||||
|
||||
if lines:
|
||||
self._read_queue.extend(lines[1:])
|
||||
both = lines[0]
|
||||
break
|
||||
return both
|
||||
return lines
|
||||
|
||||
async def _line_or_wait(self, line_aw: Awaitable):
|
||||
async def _line_or_wait(self,
|
||||
line_aw: Awaitable
|
||||
) -> Optional[Tuple[WaitFor, Awaitable]]:
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
|
||||
|
@ -225,43 +221,34 @@ class Server(IServer):
|
|||
new_line_aw = list(pend)[0]
|
||||
return (await wait_for_fut), new_line_aw
|
||||
else:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
async def _read_lines(self):
|
||||
waited_reads: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
wait_for: Optional[WaitFor] = None
|
||||
wait_for_aw: Optional[Awaitable] = None
|
||||
|
||||
async def _line() -> Tuple[Line, Optional[Emit]]:
|
||||
both = await self._next_line()
|
||||
waited_reads.append(both)
|
||||
line, emit = both
|
||||
async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]:
|
||||
lines = await self._next_lines()
|
||||
for line, emit in lines:
|
||||
self.line_preread(line)
|
||||
return both
|
||||
|
||||
while True:
|
||||
if wait_for is not None:
|
||||
line, emit = await _line()
|
||||
if wait_for.response.match(self, line):
|
||||
for i, (wait_for, aw) in enumerate(self._wait_fors):
|
||||
if wait_for.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
|
||||
wait_for, wait_for_aw = await self._line_or_wait(
|
||||
wait_for_aw)
|
||||
else:
|
||||
if not waited_reads:
|
||||
await _line()
|
||||
if aw is not None:
|
||||
new_wait_for = await self._line_or_wait(aw)
|
||||
if new_wait_for is not None:
|
||||
self._wait_fors.append(new_wait_for)
|
||||
self._wait_fors.pop(i)
|
||||
break
|
||||
|
||||
while waited_reads:
|
||||
new_line, new_emit = waited_reads.popleft()
|
||||
line_aw = self._on_read(new_line, new_emit)
|
||||
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
|
||||
if wait_for is not None:
|
||||
break
|
||||
line_aw = self._on_read(line, emit)
|
||||
new_wait_for = await self._line_or_wait(line_aw)
|
||||
if new_wait_for is not None:
|
||||
self._wait_fors.append(new_wait_for)
|
||||
return lines
|
||||
|
||||
def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_line: Optional[SentLine]=None
|
||||
) -> Awaitable[Line]:
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_aw: Optional[Awaitable[SentLine]]=None
|
||||
) -> Line:
|
||||
|
||||
response_obj: IMatchResponse
|
||||
if isinstance(response, set):
|
||||
|
@ -269,17 +256,21 @@ class Server(IServer):
|
|||
else:
|
||||
response_obj = response
|
||||
|
||||
our_wait_for = WaitFor(response_obj)
|
||||
|
||||
wait_for_fut = self._wait_for_fut
|
||||
if wait_for_fut is not None:
|
||||
self._wait_for_fut = None
|
||||
wait_for_fut.set_result(our_wait_for)
|
||||
else:
|
||||
self._wait_fors.append((our_wait_for, None))
|
||||
|
||||
label: Optional[str] = None
|
||||
if sent_line is not None:
|
||||
label = str(sent_line.id)
|
||||
if sent_aw is not None:
|
||||
sent_line = await sent_aw
|
||||
label = str(sent_line.id)
|
||||
our_wait_for.with_label(label)
|
||||
|
||||
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
|
||||
return our_wait_for
|
||||
raise Exception()
|
||||
return (await our_wait_for)
|
||||
|
||||
async def _on_send_line(self, line: Line):
|
||||
if (line.command == "PRIVMSG" and
|
||||
|
@ -337,7 +328,6 @@ class Server(IServer):
|
|||
def send_nick(self, new_nick: str) -> Awaitable[bool]:
|
||||
fut = self.send(build("NICK", [new_nick]))
|
||||
async def _assure() -> bool:
|
||||
await fut
|
||||
line = await self.wait_for({
|
||||
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
|
||||
Responses([
|
||||
|
@ -350,7 +340,7 @@ class Server(IServer):
|
|||
ERR_ERRONEUSNICKNAME,
|
||||
ERR_UNAVAILRESOURCE
|
||||
], [ANY, Folded(new_nick)])
|
||||
})
|
||||
}, fut)
|
||||
return line.command == "NICK"
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
|
@ -368,9 +358,9 @@ class Server(IServer):
|
|||
fut = self.send(build("PART", [name]))
|
||||
|
||||
async def _assure():
|
||||
await fut
|
||||
line = await self.wait_for(
|
||||
Response("PART", [Folded(name)], source=MASK_SELF)
|
||||
Response("PART", [Folded(name)], source=MASK_SELF),
|
||||
fut
|
||||
)
|
||||
return
|
||||
return MaybeAwait(_assure)
|
||||
|
@ -388,8 +378,6 @@ class Server(IServer):
|
|||
fut = self.send(build("JOIN", [",".join(names)]+keys))
|
||||
|
||||
async def _assure():
|
||||
await fut
|
||||
|
||||
channels: List[Channel] = []
|
||||
|
||||
while folded_names:
|
||||
|
@ -398,7 +386,7 @@ class Server(IServer):
|
|||
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
|
||||
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
|
||||
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
|
||||
})
|
||||
}, fut)
|
||||
|
||||
chan: Optional[str] = None
|
||||
if line.command == RPL_CHANNELMODEIS:
|
||||
|
@ -429,9 +417,9 @@ class Server(IServer):
|
|||
) -> Awaitable[Optional[str]]:
|
||||
fut = self.send(build("PRIVMSG", [target, message]))
|
||||
async def _assure():
|
||||
await fut
|
||||
line = await self.wait_for(
|
||||
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF)
|
||||
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
|
||||
fut
|
||||
)
|
||||
if line.command == "PRIVMSG":
|
||||
return line.params[1]
|
||||
|
@ -449,7 +437,6 @@ class Server(IServer):
|
|||
|
||||
fut = self.send(build("WHOIS", args))
|
||||
async def _assure() -> Optional[Whois]:
|
||||
await fut
|
||||
params = [ANY, Folded(self.casefold(target))]
|
||||
obj = Whois()
|
||||
while True:
|
||||
|
@ -465,7 +452,7 @@ class Server(IServer):
|
|||
RPL_WHOISACCOUNT,
|
||||
RPL_WHOISSECURE,
|
||||
RPL_ENDOFWHOIS
|
||||
], params))
|
||||
], params), fut)
|
||||
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
|
||||
return None
|
||||
elif line.command == RPL_WHOISUSER:
|
||||
|
|
Loading…
Reference in a new issue