Merge 066f819e9c
into 0ce3b9b0b0
This commit is contained in:
commit
87cf545db5
23 changed files with 674 additions and 491 deletions
|
@ -1,32 +1,36 @@
|
|||
import asyncio, re
|
||||
from argparse import ArgumentParser
|
||||
from typing import Dict, List, Optional
|
||||
import asyncio
|
||||
import re
|
||||
from argparse import ArgumentParser
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from irctokens import build, Line
|
||||
from ircrobots import Bot as BaseBot
|
||||
from ircrobots import Server as BaseServer
|
||||
from irctokens import Line, build
|
||||
|
||||
from ircrobots import Bot as BaseBot
|
||||
from ircrobots import ConnectionParams
|
||||
from ircrobots import Server as BaseServer
|
||||
|
||||
TRIGGER = "!"
|
||||
|
||||
|
||||
def _delims(s: str, delim: str):
|
||||
s_copy = list(s)
|
||||
while s_copy:
|
||||
char = s_copy.pop(0)
|
||||
if char == delim:
|
||||
if not s_copy:
|
||||
yield len(s)-(len(s_copy)+1)
|
||||
yield len(s) - (len(s_copy) + 1)
|
||||
elif not s_copy.pop(0) == delim:
|
||||
yield len(s)-(len(s_copy)+2)
|
||||
yield len(s) - (len(s_copy) + 2)
|
||||
|
||||
|
||||
def _sed(sed: str, s: str) -> Optional[str]:
|
||||
if len(sed) > 1:
|
||||
delim = sed[1]
|
||||
last = 0
|
||||
delim = sed[1]
|
||||
last = 0
|
||||
parts: List[str] = []
|
||||
for i in _delims(sed, delim):
|
||||
parts.append(sed[last:i])
|
||||
last = i+1
|
||||
last = i + 1
|
||||
if len(parts) == 4:
|
||||
break
|
||||
if last < (len(sed)):
|
||||
|
@ -36,10 +40,10 @@ def _sed(sed: str, s: str) -> Optional[str]:
|
|||
flags_s = (args or [""])[0]
|
||||
|
||||
flags = re.I if "i" in flags_s else 0
|
||||
count = 0 if "g" in flags_s else 1
|
||||
count = 0 if "g" in flags_s else 1
|
||||
|
||||
for i in reversed(list(_delims(replace, "&"))):
|
||||
replace = replace[:i] + "\\g<0>" + replace[i+1:]
|
||||
replace = replace[:i] + "\\g<0>" + replace[i + 1 :]
|
||||
|
||||
try:
|
||||
compiled = re.compile(pattern, flags)
|
||||
|
@ -49,18 +53,22 @@ def _sed(sed: str, s: str) -> Optional[str]:
|
|||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._settings: Dict[str, str] = {}
|
||||
|
||||
async def get(self, context: str, setting: str) -> Optional[str]:
|
||||
return self._settings.get(setting, None)
|
||||
|
||||
async def set(self, context: str, setting: str, value: str):
|
||||
self._settings[setting] = value
|
||||
|
||||
async def rem(self, context: str, setting: str):
|
||||
if setting in self._settings:
|
||||
del self._settings[setting]
|
||||
|
||||
|
||||
class Server(BaseServer):
|
||||
def __init__(self, bot: Bot, name: str, channel: str, database: Database):
|
||||
super().__init__(bot, name)
|
||||
|
@ -78,24 +86,24 @@ class Server(BaseServer):
|
|||
await self.send(build("JOIN", [self._channel]))
|
||||
|
||||
if (
|
||||
line.command == "PRIVMSG" and
|
||||
self.has_channel(line.params[0]) and
|
||||
not line.hostmask is None and
|
||||
not self.casefold(line.hostmask.nickname) == me and
|
||||
self.has_user(line.hostmask.nickname) and
|
||||
line.params[1].startswith(TRIGGER)):
|
||||
line.command == "PRIVMSG"
|
||||
and self.has_channel(line.params[0])
|
||||
and not line.hostmask is None
|
||||
and not self.casefold(line.hostmask.nickname) == me
|
||||
and self.has_user(line.hostmask.nickname)
|
||||
and line.params[1].startswith(TRIGGER)
|
||||
):
|
||||
|
||||
channel = self.channels[self.casefold(line.params[0])]
|
||||
user = self.users[self.casefold(line.hostmask.nickname)]
|
||||
cuser = channel.users[user.nickname_lower]
|
||||
text = line.params[1].replace(TRIGGER, "", 1)
|
||||
user = self.users[self.casefold(line.hostmask.nickname)]
|
||||
cuser = channel.users[user.nickname_lower]
|
||||
text = line.params[1].replace(TRIGGER, "", 1)
|
||||
db_context = f"{self.name}:{channel.name}"
|
||||
|
||||
name, _, text = text.partition(" ")
|
||||
name, _, text = text.partition(" ")
|
||||
action, _, text = text.partition(" ")
|
||||
name = name.lower()
|
||||
key = f"factoid-{name}"
|
||||
|
||||
key = f"factoid-{name}"
|
||||
|
||||
out = ""
|
||||
if not action or action == "@":
|
||||
|
@ -125,10 +133,8 @@ class Server(BaseServer):
|
|||
elif value:
|
||||
changed = _sed(value, current)
|
||||
if not changed is None:
|
||||
await self._database.set(
|
||||
db_context, key, changed)
|
||||
out = (f"{user.nickname}: "
|
||||
f"changed '{name}' factoid")
|
||||
await self._database.set(db_context, key, changed)
|
||||
out = f"{user.nickname}: " f"changed '{name}' factoid"
|
||||
else:
|
||||
out = f"{user.nickname}: invalid sed"
|
||||
else:
|
||||
|
@ -136,29 +142,28 @@ class Server(BaseServer):
|
|||
else:
|
||||
out = f"{user.nickname}: you are not an op"
|
||||
|
||||
|
||||
else:
|
||||
out = f"{user.nickname}: unknown action '{action}'"
|
||||
await self.send(build("PRIVMSG", [line.params[0], out]))
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
def __init__(self, channel: str):
|
||||
super().__init__()
|
||||
self._channel = channel
|
||||
|
||||
def create_server(self, name: str):
|
||||
return Server(self, name, self._channel, Database())
|
||||
|
||||
|
||||
async def main(hostname: str, channel: str, nickname: str):
|
||||
bot = Bot(channel)
|
||||
|
||||
params = ConnectionParams(
|
||||
nickname,
|
||||
hostname,
|
||||
6697
|
||||
)
|
||||
params = ConnectionParams(nickname, hostname, 6697)
|
||||
await bot.add_server("freenode", params)
|
||||
await bot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description="A simple IRC bot for factoids")
|
||||
parser.add_argument("hostname")
|
||||
|
|
|
@ -1,32 +1,37 @@
|
|||
import asyncio
|
||||
|
||||
from irctokens import build, Line
|
||||
from irctokens import Line, build
|
||||
|
||||
from ircrobots import SASLSCRAM
|
||||
from ircrobots import Bot as BaseBot
|
||||
from ircrobots import ConnectionParams, SASLUserPass
|
||||
from ircrobots import Server as BaseServer
|
||||
from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM
|
||||
|
||||
|
||||
class Server(BaseServer):
|
||||
async def line_read(self, line: Line):
|
||||
print(f"{self.name} < {line.format()}")
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
print(f"{self.name} > {line.format()}")
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
def create_server(self, name: str):
|
||||
return Server(self, name)
|
||||
|
||||
|
||||
async def main():
|
||||
bot = Bot()
|
||||
|
||||
sasl_params = SASLUserPass("myusername", "invalidpassword")
|
||||
params = ConnectionParams(
|
||||
"MyNickname",
|
||||
host = "chat.freenode.invalid",
|
||||
port = 6697,
|
||||
sasl = sasl_params)
|
||||
params = ConnectionParams(
|
||||
"MyNickname", host="chat.freenode.invalid", port=6697, sasl=sasl_params
|
||||
)
|
||||
|
||||
await bot.add_server("freenode", params)
|
||||
await bot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import asyncio
|
||||
|
||||
from irctokens import build, Line
|
||||
from ircrobots import Bot as BaseBot
|
||||
from ircrobots import Server as BaseServer
|
||||
from ircrobots import ConnectionParams
|
||||
from irctokens import Line, build
|
||||
|
||||
from ircrobots import Bot as BaseBot
|
||||
from ircrobots import ConnectionParams
|
||||
from ircrobots import Server as BaseServer
|
||||
|
||||
SERVERS = [("freenode", "chat.freenode.invalid")]
|
||||
|
||||
SERVERS = [
|
||||
("freenode", "chat.freenode.invalid")
|
||||
]
|
||||
|
||||
class Server(BaseServer):
|
||||
async def line_read(self, line: Line):
|
||||
|
@ -15,13 +15,16 @@ class Server(BaseServer):
|
|||
if line.command == "001":
|
||||
print(f"connected to {self.isupport.network}")
|
||||
await self.send(build("JOIN", ["#testchannel"]))
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
print(f"{self.name} > {line.format()}")
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
def create_server(self, name: str):
|
||||
return Server(self, name)
|
||||
|
||||
|
||||
async def main():
|
||||
bot = Bot()
|
||||
for name, host in SERVERS:
|
||||
|
@ -30,5 +33,6 @@ async def main():
|
|||
|
||||
await bot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
from .bot import Bot
|
||||
from .bot import Bot
|
||||
from .ircv3 import Capability
|
||||
from .params import (
|
||||
SASLSCRAM,
|
||||
ConnectionParams,
|
||||
ResumePolicy,
|
||||
SASLExternal,
|
||||
SASLUserPass,
|
||||
STSPolicy,
|
||||
)
|
||||
from .server import Server
|
||||
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
||||
STSPolicy, ResumePolicy)
|
||||
from .ircv3 import Capability
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
from asyncio import Future
|
||||
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
|
||||
TypeVar)
|
||||
from asyncio import Future
|
||||
from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar
|
||||
|
||||
from irctokens import Line
|
||||
|
||||
from irctokens import Line
|
||||
from .matching import IMatchResponse
|
||||
from .interface import IServer
|
||||
from .ircv3 import TAG_LABEL
|
||||
from .ircv3 import TAG_LABEL
|
||||
from .matching import IMatchResponse
|
||||
|
||||
TEvent = TypeVar("TEvent")
|
||||
|
||||
|
||||
class MaybeAwait(Generic[TEvent]):
|
||||
def __init__(self, func: Callable[[], Awaitable[TEvent]]):
|
||||
self._func = func
|
||||
|
@ -16,13 +18,12 @@ class MaybeAwait(Generic[TEvent]):
|
|||
coro = self._func()
|
||||
return coro.__await__()
|
||||
|
||||
|
||||
class WaitFor(object):
|
||||
def __init__(self,
|
||||
response: IMatchResponse,
|
||||
deadline: float):
|
||||
def __init__(self, response: IMatchResponse, deadline: float):
|
||||
self.response = response
|
||||
self.deadline = deadline
|
||||
self._label: Optional[str] = None
|
||||
self._label: Optional[str] = None
|
||||
self._our_fut: "Future[Line]" = Future()
|
||||
|
||||
def __await__(self) -> Generator[Any, None, Line]:
|
||||
|
@ -32,11 +33,9 @@ class WaitFor(object):
|
|||
self._label = label
|
||||
|
||||
def match(self, server: IServer, line: Line):
|
||||
if (self._label is not None and
|
||||
line.tags is not None):
|
||||
if self._label is not None and line.tags is not None:
|
||||
label = TAG_LABEL.get(line.tags)
|
||||
if (label is not None and
|
||||
label == self._label):
|
||||
if label is not None and label == self._label:
|
||||
return True
|
||||
return self.response.match(server, line)
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import asyncio, traceback
|
||||
import anyio
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
import anyio
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
|
||||
from .server import ConnectionParams, Server
|
||||
from .transport import TCPTransport
|
||||
from .interface import IBot, IServer, ITCPTransport
|
||||
from .server import ConnectionParams, Server
|
||||
from .transport import TCPTransport
|
||||
|
||||
|
||||
class Bot(IBot):
|
||||
def __init__(self):
|
||||
|
@ -17,9 +19,11 @@ class Bot(IBot):
|
|||
return Server(self, name)
|
||||
|
||||
async def disconnected(self, server: IServer):
|
||||
if (server.name in self.servers and
|
||||
server.params is not None and
|
||||
server.disconnected):
|
||||
if (
|
||||
server.name in self.servers
|
||||
and server.params is not None
|
||||
and server.disconnected
|
||||
):
|
||||
|
||||
reconnect = server.params.reconnect
|
||||
|
||||
|
@ -30,7 +34,7 @@ class Bot(IBot):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# let's try again, exponential backoff up to 5 mins
|
||||
reconnect = min(reconnect*2, 300)
|
||||
reconnect = min(reconnect * 2, 300)
|
||||
else:
|
||||
break
|
||||
|
||||
|
@ -38,10 +42,12 @@ class Bot(IBot):
|
|||
del self.servers[server.name]
|
||||
await server.disconnect()
|
||||
|
||||
async def add_server(self,
|
||||
name: str,
|
||||
params: ConnectionParams,
|
||||
transport: ITCPTransport = TCPTransport()) -> Server:
|
||||
async def add_server(
|
||||
self,
|
||||
name: str,
|
||||
params: ConnectionParams,
|
||||
transport: ITCPTransport = TCPTransport(),
|
||||
) -> Server:
|
||||
server = self.create_server(name)
|
||||
self.servers[name] = server
|
||||
await server.connect(transport, params)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .interface import IServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerContext(object):
|
||||
server: IServer
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
from typing import List
|
||||
|
||||
BOLD = "\x02"
|
||||
COLOR = "\x03"
|
||||
INVERT = "\x16"
|
||||
ITALIC = "\x1D"
|
||||
BOLD = "\x02"
|
||||
COLOR = "\x03"
|
||||
INVERT = "\x16"
|
||||
ITALIC = "\x1D"
|
||||
UNDERLINE = "\x1F"
|
||||
RESET = "\x0F"
|
||||
RESET = "\x0F"
|
||||
|
||||
FORMATTERS = [BOLD, INVERT, ITALIC, UNDERLINE, RESET]
|
||||
|
||||
FORMATTERS = [
|
||||
BOLD,
|
||||
INVERT,
|
||||
ITALIC,
|
||||
UNDERLINE,
|
||||
RESET
|
||||
]
|
||||
|
||||
def tokens(s: str) -> List[str]:
|
||||
tokens: List[str] = []
|
||||
|
@ -25,9 +20,7 @@ def tokens(s: str) -> List[str]:
|
|||
for i in range(2):
|
||||
if s_copy and s_copy[0].isdigit():
|
||||
token += s_copy.pop(0)
|
||||
if (len(s_copy) > 1 and
|
||||
s_copy[0] == "," and
|
||||
s_copy[1].isdigit()):
|
||||
if len(s_copy) > 1 and s_copy[0] == "," and s_copy[1].isdigit():
|
||||
token += s_copy.pop(0)
|
||||
token += s_copy.pop(0)
|
||||
if s_copy and s_copy[0].isdigit():
|
||||
|
@ -38,6 +31,7 @@ def tokens(s: str) -> List[str]:
|
|||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
|
||||
def strip(s: str):
|
||||
for token in tokens(s):
|
||||
s = s.replace(token, "", 1)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
def collapse(pattern: str) -> str:
|
||||
out = ""
|
||||
i = 0
|
||||
|
@ -15,9 +14,10 @@ def collapse(pattern: str) -> str:
|
|||
|
||||
if pattern[i:]:
|
||||
out += pattern[i]
|
||||
i += 1
|
||||
i += 1
|
||||
return out
|
||||
|
||||
|
||||
def _match(pattern: str, s: str):
|
||||
i, j = 0, 0
|
||||
|
||||
|
@ -45,10 +45,14 @@ def _match(pattern: str, s: str):
|
|||
|
||||
return i == len(pattern)
|
||||
|
||||
|
||||
class Glob(object):
|
||||
def __init__(self, pattern: str):
|
||||
self._pattern = pattern
|
||||
|
||||
def match(self, s: str) -> bool:
|
||||
return _match(self._pattern, s)
|
||||
|
||||
|
||||
def compile(pattern: str) -> Glob:
|
||||
return Glob(collapse(pattern))
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
from asyncio import Future
|
||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
|
||||
from enum import IntEnum
|
||||
from enum import IntEnum
|
||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from ircstates import Server, Emit
|
||||
from irctokens import Line, Hostmask
|
||||
from ircstates import Emit, Server
|
||||
from irctokens import Hostmask, Line
|
||||
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||
from .params import ConnectionParams, ResumePolicy, SASLParams, STSPolicy
|
||||
from .security import TLS
|
||||
|
||||
|
||||
class ITCPReader(object):
|
||||
async def read(self, byte_count: int):
|
||||
pass
|
||||
|
||||
|
||||
class ITCPWriter(object):
|
||||
def write(self, data: bytes):
|
||||
pass
|
||||
|
@ -20,37 +23,40 @@ class ITCPWriter(object):
|
|||
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class ITCPTransport(object):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
async def connect(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str] = None,
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
pass
|
||||
|
||||
|
||||
class SendPriority(IntEnum):
|
||||
HIGH = 0
|
||||
HIGH = 0
|
||||
MEDIUM = 10
|
||||
LOW = 20
|
||||
LOW = 20
|
||||
DEFAULT = MEDIUM
|
||||
|
||||
|
||||
class SentLine(object):
|
||||
def __init__(self,
|
||||
id: int,
|
||||
priority: int,
|
||||
line: Line):
|
||||
self.id = id
|
||||
self.priority = priority
|
||||
self.line = line
|
||||
def __init__(self, id: int, priority: int, line: Line):
|
||||
self.id = id
|
||||
self.priority = priority
|
||||
self.line = line
|
||||
self.future: "Future[SentLine]" = Future()
|
||||
|
||||
def __lt__(self, other: "SentLine") -> bool:
|
||||
return self.priority < other.priority
|
||||
|
||||
|
||||
class ICapability(object):
|
||||
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||
pass
|
||||
|
@ -61,38 +67,46 @@ class ICapability(object):
|
|||
def copy(self) -> "ICapability":
|
||||
pass
|
||||
|
||||
|
||||
class IMatchResponse(object):
|
||||
def match(self, server: "IServer", line: Line) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class IMatchResponseParam(object):
|
||||
def match(self, server: "IServer", arg: str) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class IMatchResponseValueParam(IMatchResponseParam):
|
||||
def value(self, server: "IServer"):
|
||||
pass
|
||||
|
||||
def set_value(self, value: str):
|
||||
pass
|
||||
|
||||
|
||||
class IMatchResponseHostmask(object):
|
||||
def match(self, server: "IServer", hostmask: Hostmask) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class IServer(Server):
|
||||
bot: "IBot"
|
||||
bot: "IBot"
|
||||
disconnected: bool
|
||||
params: ConnectionParams
|
||||
params: ConnectionParams
|
||||
desired_caps: Set[ICapability]
|
||||
last_read: float
|
||||
last_read: float
|
||||
|
||||
def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
pass
|
||||
def send(self, line: Line, priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
|
||||
pass
|
||||
|
||||
def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
) -> Awaitable[Line]:
|
||||
def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
|
||||
pass
|
||||
|
||||
def wait_for(
|
||||
self, response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
) -> Awaitable[Line]:
|
||||
pass
|
||||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
|
@ -101,37 +115,44 @@ class IServer(Server):
|
|||
def server_address(self) -> Tuple[str, int]:
|
||||
pass
|
||||
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
async def connect(self, transport: ITCPTransport, params: ConnectionParams):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def line_preread(self, line: Line):
|
||||
pass
|
||||
|
||||
def line_presend(self, line: Line):
|
||||
pass
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
|
||||
async def sts_policy(self, sts: STSPolicy):
|
||||
pass
|
||||
|
||||
async def resume_policy(self, resume: ResumePolicy):
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
pass
|
||||
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
pass
|
||||
|
||||
async def sasl_auth(self, sasl: SASLParams) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class IBot(object):
|
||||
def create_server(self, name: str) -> IServer:
|
||||
pass
|
||||
|
||||
async def disconnected(self, server: IServer):
|
||||
pass
|
||||
|
||||
|
|
|
@ -1,22 +1,26 @@
|
|||
from time import time
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from irctokens import build
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
from time import time
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ANY
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
from irctokens import build
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .interface import ICapability
|
||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||
from .security import TLS_VERIFYCHAIN
|
||||
from .matching import ANY, Response
|
||||
from .params import ConnectionParams, ResumePolicy, STSPolicy
|
||||
from .security import TLS_VERIFYCHAIN
|
||||
|
||||
|
||||
class Capability(ICapability):
|
||||
def __init__(self,
|
||||
ratified_name: Optional[str],
|
||||
draft_name: Optional[str]=None,
|
||||
alias: Optional[str]=None,
|
||||
depends_on: List[str]=[]):
|
||||
self.name = ratified_name
|
||||
def __init__(
|
||||
self,
|
||||
ratified_name: Optional[str],
|
||||
draft_name: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
depends_on: List[str] = [],
|
||||
):
|
||||
self.name = ratified_name
|
||||
self.draft = draft_name
|
||||
self.alias = alias or ratified_name
|
||||
self.depends_on = depends_on.copy()
|
||||
|
@ -26,8 +30,7 @@ class Capability(ICapability):
|
|||
def match(self, capability: str) -> bool:
|
||||
return capability in self._caps
|
||||
|
||||
def available(self, capabilities: Iterable[str]
|
||||
) -> Optional[str]:
|
||||
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||
for cap in self._caps:
|
||||
if not cap is None and cap in capabilities:
|
||||
return cap
|
||||
|
@ -36,16 +39,13 @@ class Capability(ICapability):
|
|||
|
||||
def copy(self):
|
||||
return Capability(
|
||||
self.name,
|
||||
self.draft,
|
||||
alias=self.alias,
|
||||
depends_on=self.depends_on[:])
|
||||
self.name, self.draft, alias=self.alias, depends_on=self.depends_on[:]
|
||||
)
|
||||
|
||||
|
||||
class MessageTag(object):
|
||||
def __init__(self,
|
||||
name: Optional[str],
|
||||
draft_name: Optional[str]=None):
|
||||
self.name = name
|
||||
def __init__(self, name: Optional[str], draft_name: Optional[str] = None):
|
||||
self.name = name
|
||||
self.draft = draft_name
|
||||
self._tags = [self.name, self.draft]
|
||||
|
||||
|
@ -63,37 +63,36 @@ class MessageTag(object):
|
|||
else:
|
||||
return None
|
||||
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
|
||||
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
TAG_LABEL = MessageTag("label", "draft/label")
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
TAG_LABEL = MessageTag("label", "draft/label")
|
||||
LABEL_TAG_MAP = {
|
||||
"draft/labeled-response-0.2": "draft/label",
|
||||
"labeled-response": "label"
|
||||
"labeled-response": "label",
|
||||
}
|
||||
|
||||
CAPS: List[ICapability] = [
|
||||
Capability("multi-prefix"),
|
||||
Capability("chghost"),
|
||||
Capability("away-notify"),
|
||||
|
||||
Capability("invite-notify"),
|
||||
Capability("account-tag"),
|
||||
Capability("account-notify"),
|
||||
Capability("extended-join"),
|
||||
|
||||
Capability("message-tags", "draft/message-tags-0.2"),
|
||||
Capability("cap-notify"),
|
||||
Capability("batch"),
|
||||
|
||||
Capability(None, "draft/rename", alias="rename"),
|
||||
Capability("setname", "draft/setname"),
|
||||
CAP_RESUME
|
||||
CAP_RESUME,
|
||||
]
|
||||
|
||||
|
||||
def _cap_dict(s: str) -> Dict[str, str]:
|
||||
d: Dict[str, str] = {}
|
||||
for token in s.split(","):
|
||||
|
@ -101,41 +100,44 @@ def _cap_dict(s: str) -> Dict[str, str]:
|
|||
d[key] = value
|
||||
return d
|
||||
|
||||
|
||||
async def sts_transmute(params: ConnectionParams):
|
||||
if not params.sts is None and params.tls is None:
|
||||
now = time()
|
||||
since = (now-params.sts.created)
|
||||
now = time()
|
||||
since = now - params.sts.created
|
||||
if since <= params.sts.duration:
|
||||
params.port = params.sts.port
|
||||
params.tls = TLS_VERIFYCHAIN
|
||||
params.tls = TLS_VERIFYCHAIN
|
||||
|
||||
|
||||
async def resume_transmute(params: ConnectionParams):
|
||||
if params.resume is not None:
|
||||
params.host = params.resume.address
|
||||
|
||||
|
||||
class HandshakeCancel(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CAPContext(ServerContext):
|
||||
async def on_ls(self, tokens: Dict[str, str]):
|
||||
await self._sts(tokens)
|
||||
|
||||
caps = list(self.server.desired_caps)+CAPS
|
||||
caps = list(self.server.desired_caps) + CAPS
|
||||
|
||||
if (not self.server.params.sasl is None and
|
||||
not CAP_SASL in caps):
|
||||
if not self.server.params.sasl is None and not CAP_SASL in caps:
|
||||
caps.append(CAP_SASL)
|
||||
|
||||
matched = (c.available(tokens) for c in caps)
|
||||
matched = (c.available(tokens) for c in caps)
|
||||
cap_names = [name for name in matched if not name is None]
|
||||
|
||||
if cap_names:
|
||||
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
|
||||
|
||||
while cap_names:
|
||||
line = await self.server.wait_for({
|
||||
Response("CAP", [ANY, "ACK"]),
|
||||
Response("CAP", [ANY, "NAK"])
|
||||
})
|
||||
line = await self.server.wait_for(
|
||||
{Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "NAK"])}
|
||||
)
|
||||
|
||||
current_caps = line.params[2].split(" ")
|
||||
for cap in current_caps:
|
||||
|
@ -144,8 +146,7 @@ class CAPContext(ServerContext):
|
|||
if CAP_RESUME.available(current_caps):
|
||||
await self.resume_token()
|
||||
|
||||
if (self.server.cap_agreed(CAP_SASL) and
|
||||
not self.server.params.sasl is None):
|
||||
if self.server.cap_agreed(CAP_SASL) and not self.server.params.sasl is None:
|
||||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
async def resume_token(self):
|
||||
|
@ -160,10 +161,9 @@ class CAPContext(ServerContext):
|
|||
|
||||
if previous_policy is not None and not self.server.registered:
|
||||
await self.server.send(build("RESUME", [previous_policy.token]))
|
||||
line = await self.server.wait_for({
|
||||
Response("RESUME", ["SUCCESS"]),
|
||||
Response("FAIL", ["RESUME"])
|
||||
})
|
||||
line = await self.server.wait_for(
|
||||
{Response("RESUME", ["SUCCESS"]), Response("FAIL", ["RESUME"])}
|
||||
)
|
||||
if line.command == "RESUME":
|
||||
raise HandshakeCancel()
|
||||
|
||||
|
@ -179,11 +179,11 @@ class CAPContext(ServerContext):
|
|||
cap_sts = CAP_STS.available(tokens)
|
||||
if not cap_sts is None:
|
||||
sts_dict = _cap_dict(tokens[cap_sts])
|
||||
params = self.server.params
|
||||
params = self.server.params
|
||||
if not params.tls:
|
||||
if "port" in sts_dict:
|
||||
params.port = int(sts_dict["port"])
|
||||
params.tls = TLS_VERIFYCHAIN
|
||||
params.tls = TLS_VERIFYCHAIN
|
||||
|
||||
await self.server.bot.disconnect(self.server)
|
||||
await self.server.bot.add_server(self.server.name, params)
|
||||
|
@ -194,6 +194,6 @@ class CAPContext(ServerContext):
|
|||
int(time()),
|
||||
params.port,
|
||||
int(sts_dict["duration"]),
|
||||
"preload" in sts_dict)
|
||||
"preload" in sts_dict,
|
||||
)
|
||||
await self.server.sts_policy(policy)
|
||||
|
||||
|
|
|
@ -1,3 +1,2 @@
|
|||
|
||||
from .params import *
|
||||
from .responses import *
|
||||
from .params import *
|
||||
|
|
|
@ -1,16 +1,27 @@
|
|||
from re import compile as re_compile
|
||||
from typing import Optional, Pattern, Union
|
||||
from irctokens import Hostmask
|
||||
from ..interface import (IMatchResponseParam, IMatchResponseValueParam,
|
||||
IMatchResponseHostmask, IServer)
|
||||
from ..glob import Glob, compile as glob_compile
|
||||
from re import compile as re_compile
|
||||
from typing import Optional, Pattern, Union
|
||||
|
||||
from irctokens import Hostmask
|
||||
|
||||
from .. import formatting
|
||||
from ..glob import Glob
|
||||
from ..glob import compile as glob_compile
|
||||
from ..interface import (
|
||||
IMatchResponseHostmask,
|
||||
IMatchResponseParam,
|
||||
IMatchResponseValueParam,
|
||||
IServer,
|
||||
)
|
||||
|
||||
|
||||
class Any(IMatchResponseParam):
|
||||
def __repr__(self) -> str:
|
||||
return "Any()"
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
ANY = Any()
|
||||
|
||||
# NOT
|
||||
|
@ -18,107 +29,142 @@ ANY = Any()
|
|||
# REGEX
|
||||
# LITERAL
|
||||
|
||||
|
||||
class Literal(IMatchResponseValueParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self._value!r}"
|
||||
|
||||
def value(self, server: IServer) -> str:
|
||||
return self._value
|
||||
|
||||
def set_value(self, value: str):
|
||||
self._value = value
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return arg == self._value
|
||||
|
||||
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
|
||||
|
||||
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
|
||||
TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam]
|
||||
|
||||
|
||||
def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam:
|
||||
if isinstance(value, str):
|
||||
return Literal(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class Not(IMatchResponseParam):
|
||||
def __init__(self, param: IMatchResponseParam):
|
||||
self._param = param
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Not({self._param!r})"
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return not self._param.match(server, arg)
|
||||
|
||||
|
||||
class ParamValuePassthrough(IMatchResponseValueParam):
|
||||
_value: IMatchResponseValueParam
|
||||
|
||||
def value(self, server: IServer):
|
||||
return self._value.value(server)
|
||||
|
||||
def set_value(self, value: str):
|
||||
self._value.set_value(value)
|
||||
|
||||
|
||||
class Folded(ParamValuePassthrough):
|
||||
def __init__(self, value: TYPE_MAYBELIT_VALUE):
|
||||
self._value = _assure_lit(value)
|
||||
self._folded = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Folded({self._value!r})"
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
if not self._folded:
|
||||
value = self.value(server)
|
||||
value = self.value(server)
|
||||
folded = server.casefold(value)
|
||||
self.set_value(folded)
|
||||
self._folded = True
|
||||
|
||||
return self._value.match(server, server.casefold(arg))
|
||||
|
||||
|
||||
class Formatless(IMatchResponseParam):
|
||||
def __init__(self, value: TYPE_MAYBELIT_VALUE):
|
||||
self._value = _assure_lit(value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
brepr = super().__repr__()
|
||||
return f"Formatless({brepr})"
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
strip = formatting.strip(arg)
|
||||
return self._value.match(server, strip)
|
||||
|
||||
|
||||
class Regex(IMatchResponseParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
self._pattern: Optional[Pattern] = None
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
if self._pattern is None:
|
||||
self._pattern = re_compile(self._value)
|
||||
return bool(self._pattern.search(arg))
|
||||
|
||||
|
||||
class Self(IMatchResponseParam):
|
||||
def __repr__(self) -> str:
|
||||
return "Self()"
|
||||
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return server.casefold(arg) == server.nickname_lower
|
||||
|
||||
|
||||
SELF = Self()
|
||||
|
||||
|
||||
class MaskSelf(IMatchResponseHostmask):
|
||||
def __repr__(self) -> str:
|
||||
return "MaskSelf()"
|
||||
|
||||
def match(self, server: IServer, hostmask: Hostmask):
|
||||
return server.casefold(hostmask.nickname) == server.nickname_lower
|
||||
|
||||
|
||||
MASK_SELF = MaskSelf()
|
||||
|
||||
|
||||
class Nick(IMatchResponseHostmask):
|
||||
def __init__(self, nickname: str):
|
||||
self._nickname = nickname
|
||||
self._folded: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Nick({self._nickname!r})"
|
||||
|
||||
def match(self, server: IServer, hostmask: Hostmask):
|
||||
if self._folded is None:
|
||||
self._folded = server.casefold(self._nickname)
|
||||
return self._folded == server.casefold(hostmask.nickname)
|
||||
|
||||
|
||||
class Mask(IMatchResponseHostmask):
|
||||
def __init__(self, mask: str):
|
||||
self._mask = mask
|
||||
self._compiled: Optional[Glob]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Mask({self._mask!r})"
|
||||
|
||||
def match(self, server: IServer, hostmask: Hostmask):
|
||||
if self._compiled is None:
|
||||
self._compiled = glob_compile(self._mask)
|
||||
|
|
|
@ -1,17 +1,27 @@
|
|||
from typing import List, Optional, Sequence, Union
|
||||
from irctokens import Line
|
||||
from ..interface import (IServer, IMatchResponse, IMatchResponseParam,
|
||||
IMatchResponseHostmask)
|
||||
from .params import *
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from irctokens import Line
|
||||
|
||||
from ..interface import (
|
||||
IMatchResponse,
|
||||
IMatchResponseHostmask,
|
||||
IMatchResponseParam,
|
||||
IServer,
|
||||
)
|
||||
from .params import *
|
||||
|
||||
TYPE_PARAM = Union[str, IMatchResponseParam]
|
||||
|
||||
|
||||
class Responses(IMatchResponse):
|
||||
def __init__(self,
|
||||
commands: Sequence[str],
|
||||
params: Sequence[TYPE_PARAM]=[],
|
||||
source: Optional[IMatchResponseHostmask]=None):
|
||||
def __init__(
|
||||
self,
|
||||
commands: Sequence[str],
|
||||
params: Sequence[TYPE_PARAM] = [],
|
||||
source: Optional[IMatchResponseHostmask] = None,
|
||||
):
|
||||
self._commands = commands
|
||||
self._source = source
|
||||
self._source = source
|
||||
|
||||
self._params: Sequence[IMatchResponseParam] = []
|
||||
for param in params:
|
||||
|
@ -25,36 +35,43 @@ class Responses(IMatchResponse):
|
|||
|
||||
def match(self, server: IServer, line: Line) -> bool:
|
||||
for command in self._commands:
|
||||
if (line.command == command and (
|
||||
self._source is None or (
|
||||
line.hostmask is not None and
|
||||
self._source.match(server, line.hostmask)
|
||||
))):
|
||||
if line.command == command and (
|
||||
self._source is None
|
||||
or (
|
||||
line.hostmask is not None
|
||||
and self._source.match(server, line.hostmask)
|
||||
)
|
||||
):
|
||||
|
||||
for i, param in enumerate(self._params):
|
||||
if (i >= len(line.params) or
|
||||
not param.match(server, line.params[i])):
|
||||
if i >= len(line.params) or not param.match(server, line.params[i]):
|
||||
break
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class Response(Responses):
|
||||
def __init__(self,
|
||||
command: str,
|
||||
params: Sequence[TYPE_PARAM]=[],
|
||||
source: Optional[IMatchResponseHostmask]=None):
|
||||
def __init__(
|
||||
self,
|
||||
command: str,
|
||||
params: Sequence[TYPE_PARAM] = [],
|
||||
source: Optional[IMatchResponseHostmask] = None,
|
||||
):
|
||||
super().__init__([command], params, source=source)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Response({self._commands[0]}: {self._params!r})"
|
||||
|
||||
|
||||
class ResponseOr(IMatchResponse):
|
||||
def __init__(self, *responses: IMatchResponse):
|
||||
self._responses = responses
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ResponseOr({self._responses!r})"
|
||||
|
||||
def match(self, server: IServer, line: Line) -> bool:
|
||||
for response in self._responses:
|
||||
if response.match(server, line):
|
||||
|
|
|
@ -1,74 +1,80 @@
|
|||
from re import compile as re_compile
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from re import compile as re_compile
|
||||
from typing import List, Optional
|
||||
|
||||
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
|
||||
|
||||
|
||||
class SASLParams(object):
|
||||
mechanism: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SASLUserPass(SASLParams):
|
||||
username: str
|
||||
password: str
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class SASLUserPass(_SASLUserPass):
|
||||
mechanism = "USERPASS"
|
||||
|
||||
|
||||
class SASLSCRAM(_SASLUserPass):
|
||||
mechanism = "SCRAM"
|
||||
|
||||
|
||||
class SASLExternal(SASLParams):
|
||||
mechanism = "EXTERNAL"
|
||||
|
||||
|
||||
@dataclass
|
||||
class STSPolicy(object):
|
||||
created: int
|
||||
port: int
|
||||
created: int
|
||||
port: int
|
||||
duration: int
|
||||
preload: bool
|
||||
preload: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumePolicy(object):
|
||||
address: str
|
||||
token: str
|
||||
token: str
|
||||
|
||||
|
||||
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
|
||||
|
||||
_TLS_TYPES = {
|
||||
"+": TLS_VERIFYCHAIN,
|
||||
"~": TLS_NOVERIFY
|
||||
}
|
||||
_TLS_TYPES = {"+": TLS_VERIFYCHAIN, "~": TLS_NOVERIFY}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams(object):
|
||||
nickname: str
|
||||
host: str
|
||||
port: int
|
||||
tls: Optional[TLS] = TLS_VERIFYCHAIN
|
||||
host: str
|
||||
port: int
|
||||
tls: Optional[TLS] = TLS_VERIFYCHAIN
|
||||
|
||||
username: Optional[str] = None
|
||||
realname: Optional[str] = None
|
||||
bindhost: Optional[str] = None
|
||||
|
||||
password: Optional[str] = None
|
||||
sasl: Optional[SASLParams] = None
|
||||
password: Optional[str] = None
|
||||
sasl: Optional[SASLParams] = None
|
||||
|
||||
sts: Optional[STSPolicy] = None
|
||||
sts: Optional[STSPolicy] = None
|
||||
resume: Optional[ResumePolicy] = None
|
||||
|
||||
reconnect: int = 10 # seconds
|
||||
reconnect: int = 10 # seconds
|
||||
alt_nicknames: List[str] = field(default_factory=list)
|
||||
|
||||
autojoin: List[str] = field(default_factory=list)
|
||||
autojoin: List[str] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_hoststring(
|
||||
nickname: str,
|
||||
hoststring: str
|
||||
) -> "ConnectionParams":
|
||||
def from_hoststring(nickname: str, hoststring: str) -> "ConnectionParams":
|
||||
|
||||
ipv6host = RE_IPV6HOST.search(hoststring)
|
||||
if ipv6host is not None and ipv6host.start() == 0:
|
||||
host = ipv6host.group(1)
|
||||
port_s = hoststring[ipv6host.end()+1:]
|
||||
port_s = hoststring[ipv6host.end() + 1 :]
|
||||
else:
|
||||
host, _, port_s = hoststring.strip().partition(":")
|
||||
|
||||
|
|
|
@ -1,52 +1,63 @@
|
|||
from typing import List
|
||||
from enum import Enum
|
||||
from base64 import b64decode, b64encode
|
||||
from irctokens import build
|
||||
from ircstates.numerics import *
|
||||
from base64 import b64decode, b64encode
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from ircstates.numerics import *
|
||||
from irctokens import build
|
||||
|
||||
from .matching import Responses, Response, ANY
|
||||
from .contexts import ServerContext
|
||||
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
||||
from .scram import SCRAMContext, SCRAMAlgorithm
|
||||
from .matching import ANY, Response, Responses
|
||||
from .params import SASLSCRAM, SASLExternal, SASLParams, SASLUserPass
|
||||
from .scram import SCRAMAlgorithm, SCRAMContext
|
||||
|
||||
SASL_SCRAM_MECHANISMS = [
|
||||
"SCRAM-SHA-512",
|
||||
"SCRAM-SHA-256",
|
||||
"SCRAM-SHA-1",
|
||||
]
|
||||
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"]
|
||||
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS + ["PLAIN"]
|
||||
|
||||
|
||||
class SASLResult(Enum):
|
||||
NONE = 0
|
||||
NONE = 0
|
||||
SUCCESS = 1
|
||||
FAILURE = 2
|
||||
ALREADY = 3
|
||||
|
||||
|
||||
class SASLError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SASLUnknownMechanismError(SASLError):
|
||||
pass
|
||||
|
||||
|
||||
AUTH_BYTE_MAX = 400
|
||||
|
||||
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
|
||||
|
||||
NUMERICS_FAIL = Response(ERR_SASLFAIL)
|
||||
NUMERICS_INITIAL = Responses([
|
||||
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
|
||||
])
|
||||
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
|
||||
NUMERICS_FAIL = Response(ERR_SASLFAIL)
|
||||
NUMERICS_INITIAL = Responses(
|
||||
[ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED]
|
||||
)
|
||||
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
|
||||
|
||||
|
||||
def _b64e(s: str):
|
||||
return b64encode(s.encode("utf8")).decode("ascii")
|
||||
|
||||
|
||||
def _b64eb(s: bytes) -> str:
|
||||
# encode-from-bytes
|
||||
return b64encode(s).decode("ascii")
|
||||
|
||||
|
||||
def _b64db(s: str) -> bytes:
|
||||
# decode-to-bytes
|
||||
return b64decode(s)
|
||||
|
||||
|
||||
class SASLContext(ServerContext):
|
||||
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||
if isinstance(params, SASLUserPass):
|
||||
|
@ -57,15 +68,12 @@ class SASLContext(ServerContext):
|
|||
return await self.external()
|
||||
else:
|
||||
raise SASLUnknownMechanismError(
|
||||
"SASLParams given with unknown mechanism "
|
||||
f"{params.mechanism!r}")
|
||||
"SASLParams given with unknown mechanism " f"{params.mechanism!r}"
|
||||
)
|
||||
|
||||
async def external(self) -> SASLResult:
|
||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||
line = await self.server.wait_for({
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
})
|
||||
line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
@ -73,8 +81,8 @@ class SASLContext(ServerContext):
|
|||
elif line.command == "908":
|
||||
available = line.params[1].split(",")
|
||||
raise SASLUnknownMechanismError(
|
||||
"Server does not support SASL EXTERNAL "
|
||||
f"(it supports {available}")
|
||||
"Server does not support SASL EXTERNAL " f"(it supports {available}"
|
||||
)
|
||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
await self.server.send(build("AUTHENTICATE", ["+"]))
|
||||
|
||||
|
@ -89,11 +97,12 @@ class SASLContext(ServerContext):
|
|||
async def scram(self, username: str, password: str) -> SASLResult:
|
||||
return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)
|
||||
|
||||
async def userpass(self,
|
||||
username: str,
|
||||
password: str,
|
||||
mechanisms: List[str]=SASL_USERPASS_MECHANISMS
|
||||
) -> SASLResult:
|
||||
async def userpass(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
mechanisms: List[str] = SASL_USERPASS_MECHANISMS,
|
||||
) -> SASLResult:
|
||||
def _common(server_mechs) -> List[str]:
|
||||
mechs: List[str] = []
|
||||
for our_mech in mechanisms:
|
||||
|
@ -106,23 +115,21 @@ class SASLContext(ServerContext):
|
|||
raise SASLUnknownMechanismError(
|
||||
"No matching SASL mechanims. "
|
||||
f"(we want: {mechanisms} "
|
||||
f"server has: {server_mechs})")
|
||||
f"server has: {server_mechs})"
|
||||
)
|
||||
|
||||
if self.server.available_caps["sasl"]:
|
||||
# CAP v3.2 tells us what mechs it supports
|
||||
available = self.server.available_caps["sasl"].split(",")
|
||||
match = _common(available)
|
||||
match = _common(available)
|
||||
else:
|
||||
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
|
||||
# what mechanisms are supported
|
||||
match = mechanisms
|
||||
match = mechanisms
|
||||
|
||||
while match:
|
||||
await self.server.send(build("AUTHENTICATE", [match[0]]))
|
||||
line = await self.server.wait_for({
|
||||
AUTHENTICATE_ANY,
|
||||
NUMERICS_INITIAL
|
||||
})
|
||||
line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
@ -130,7 +137,7 @@ class SASLContext(ServerContext):
|
|||
elif line.command == "908":
|
||||
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
||||
available = line.params[1].split(",")
|
||||
match = _common(available)
|
||||
match = _common(available)
|
||||
await self.server.wait_for(NUMERICS_FAIL)
|
||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
auth_text = ""
|
||||
|
@ -138,8 +145,7 @@ class SASLContext(ServerContext):
|
|||
if match[0] == "PLAIN":
|
||||
auth_text = f"{username}\0{username}\0{password}"
|
||||
elif match[0].startswith("SCRAM-SHA-"):
|
||||
auth_text = await self._scram(
|
||||
match[0], username, password)
|
||||
auth_text = await self._scram(match[0], username, password)
|
||||
|
||||
if not auth_text == "+":
|
||||
auth_text = _b64e(auth_text)
|
||||
|
@ -148,7 +154,7 @@ class SASLContext(ServerContext):
|
|||
await self._send_auth_text(auth_text)
|
||||
|
||||
line = await self.server.wait_for(NUMERICS_LAST)
|
||||
if line.command == "903":
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
elif line.command == "904":
|
||||
match.pop(0)
|
||||
|
@ -157,11 +163,8 @@ class SASLContext(ServerContext):
|
|||
|
||||
return SASLResult.FAILURE
|
||||
|
||||
async def _scram(self, algo_str: str,
|
||||
username: str,
|
||||
password: str) -> str:
|
||||
algo_str_prep = algo_str.replace("SCRAM-", "", 1
|
||||
).replace("-", "").upper()
|
||||
async def _scram(self, algo_str: str, username: str, password: str) -> str:
|
||||
algo_str_prep = algo_str.replace("SCRAM-", "", 1).replace("-", "").upper()
|
||||
try:
|
||||
algo = SCRAMAlgorithm(algo_str_prep)
|
||||
except ValueError:
|
||||
|
@ -179,15 +182,15 @@ class SASLContext(ServerContext):
|
|||
line = await self.server.wait_for(AUTHENTICATE_ANY)
|
||||
|
||||
server_final = _b64db(line.params[0])
|
||||
verified = scram.server_final(server_final)
|
||||
#TODO PANIC if verified is false!
|
||||
verified = scram.server_final(server_final)
|
||||
# TODO PANIC if verified is false!
|
||||
return "+"
|
||||
else:
|
||||
return ""
|
||||
|
||||
async def _send_auth_text(self, text: str):
|
||||
n = AUTH_BYTE_MAX
|
||||
chunks = [text[i:i+n] for i in range(0, len(text), n)]
|
||||
chunks = [text[i : i + n] for i in range(0, len(text), n)]
|
||||
if len(chunks[-1]) == 400:
|
||||
chunks.append("+")
|
||||
|
||||
|
|
|
@ -1,57 +1,70 @@
|
|||
import base64, hashlib, hmac, os
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
|
||||
# IANA Hash Function Textual Names
|
||||
# https://tools.ietf.org/html/rfc5802#section-4
|
||||
# https://www.iana.org/assignments/hash-function-text-names/
|
||||
# MD2 has been removed as it's unacceptably weak
|
||||
class SCRAMAlgorithm(Enum):
|
||||
MD5 = "MD5"
|
||||
SHA_1 = "SHA1"
|
||||
MD5 = "MD5"
|
||||
SHA_1 = "SHA1"
|
||||
SHA_224 = "SHA224"
|
||||
SHA_256 = "SHA256"
|
||||
SHA_384 = "SHA384"
|
||||
SHA_512 = "SHA512"
|
||||
|
||||
|
||||
SCRAM_ERRORS = [
|
||||
"invalid-encoding",
|
||||
"extensions-not-supported", # unrecognized 'm' value
|
||||
"extensions-not-supported", # unrecognized 'm' value
|
||||
"invalid-proof",
|
||||
"channel-bindings-dont-match",
|
||||
"server-does-support-channel-binding",
|
||||
"channel-binding-not-supported",
|
||||
"unsupported-channel-binding-type",
|
||||
"unknown-user",
|
||||
"invalid-username-encoding", # invalid utf8 or bad SASLprep
|
||||
"no-resources"
|
||||
"invalid-username-encoding", # invalid utf8 or bad SASLprep
|
||||
"no-resources",
|
||||
]
|
||||
|
||||
|
||||
def _scram_nonce() -> bytes:
|
||||
return base64.b64encode(os.urandom(32))
|
||||
|
||||
|
||||
def _scram_escape(s: bytes) -> bytes:
|
||||
return s.replace(b"=", b"=3D").replace(b",", b"=2C")
|
||||
|
||||
|
||||
def _scram_unescape(s: bytes) -> bytes:
|
||||
return s.replace(b"=3D", b"=").replace(b"=2C", b",")
|
||||
|
||||
|
||||
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
|
||||
return bytes(a ^ b for a, b in zip(s1, s2))
|
||||
|
||||
|
||||
class SCRAMState(Enum):
|
||||
NONE = 0
|
||||
CLIENT_FIRST = 1
|
||||
CLIENT_FINAL = 2
|
||||
SUCCESS = 3
|
||||
FAILURE = 4
|
||||
NONE = 0
|
||||
CLIENT_FIRST = 1
|
||||
CLIENT_FINAL = 2
|
||||
SUCCESS = 3
|
||||
FAILURE = 4
|
||||
VERIFY_FAILURE = 5
|
||||
|
||||
|
||||
class SCRAMError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SCRAMContext(object):
|
||||
def __init__(self, algo: SCRAMAlgorithm,
|
||||
username: str,
|
||||
password: str):
|
||||
self._algo = algo
|
||||
def __init__(self, algo: SCRAMAlgorithm, username: str, password: str):
|
||||
self._algo = algo
|
||||
self._username = username.encode("utf8")
|
||||
self._password = password.encode("utf8")
|
||||
|
||||
|
@ -59,11 +72,11 @@ class SCRAMContext(object):
|
|||
self.error = ""
|
||||
self.raw_error = ""
|
||||
|
||||
self._client_first = b""
|
||||
self._client_nonce = b""
|
||||
self._client_first = b""
|
||||
self._client_nonce = b""
|
||||
|
||||
self._salted_password = b""
|
||||
self._auth_message = b""
|
||||
self._auth_message = b""
|
||||
|
||||
def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]:
|
||||
pieces = (piece.split(b"=", 1) for piece in data.split(b","))
|
||||
|
@ -71,6 +84,7 @@ class SCRAMContext(object):
|
|||
|
||||
def _hmac(self, key: bytes, msg: bytes) -> bytes:
|
||||
return hmac.new(key, msg, self._algo.value).digest()
|
||||
|
||||
def _hash(self, msg: bytes) -> bytes:
|
||||
return hashlib.new(self._algo.value, msg).digest()
|
||||
|
||||
|
@ -89,7 +103,9 @@ class SCRAMContext(object):
|
|||
self.state = SCRAMState.CLIENT_FIRST
|
||||
self._client_nonce = _scram_nonce()
|
||||
self._client_first = b"n=%s,r=%s" % (
|
||||
_scram_escape(self._username), self._client_nonce)
|
||||
_scram_escape(self._username),
|
||||
self._client_nonce,
|
||||
)
|
||||
|
||||
# n,,n=<username>,r=<nonce>
|
||||
return b"n,,%s" % self._client_first
|
||||
|
@ -109,17 +125,17 @@ class SCRAMContext(object):
|
|||
if self._assert_error(pieces):
|
||||
return b""
|
||||
|
||||
nonce = pieces[b"r"] # server combines your nonce with it's own
|
||||
if (not nonce.startswith(self._client_nonce) or
|
||||
nonce == self._client_nonce):
|
||||
nonce = pieces[b"r"] # server combines your nonce with it's own
|
||||
if not nonce.startswith(self._client_nonce) or nonce == self._client_nonce:
|
||||
self._fail("nonce-unacceptable")
|
||||
return b""
|
||||
|
||||
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
|
||||
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
|
||||
iterations = int(pieces[b"i"])
|
||||
|
||||
salted_password = hashlib.pbkdf2_hmac(self._algo.value,
|
||||
self._password, salt, iterations, dklen=None)
|
||||
salted_password = hashlib.pbkdf2_hmac(
|
||||
self._algo.value, self._password, salt, iterations, dklen=None
|
||||
)
|
||||
self._salted_password = salted_password
|
||||
|
||||
client_key = self._hmac(salted_password, b"Client Key")
|
||||
|
|
|
@ -1,26 +1,35 @@
|
|||
import ssl
|
||||
|
||||
|
||||
class TLS:
|
||||
pass
|
||||
|
||||
|
||||
# tls without verification
|
||||
class TLSNoVerify(TLS):
|
||||
pass
|
||||
|
||||
|
||||
TLS_NOVERIFY = TLSNoVerify()
|
||||
|
||||
# verify via CAs
|
||||
class TLSVerifyChain(TLS):
|
||||
pass
|
||||
|
||||
|
||||
TLS_VERIFYCHAIN = TLSVerifyChain()
|
||||
|
||||
# verify by a pinned hash
|
||||
class TLSVerifyHash(TLSNoVerify):
|
||||
def __init__(self, sum: str):
|
||||
self.sum = sum.lower()
|
||||
|
||||
|
||||
class TLSVerifySHA512(TLSVerifyHash):
|
||||
pass
|
||||
|
||||
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
||||
|
||||
def tls_context(verify: bool = True) -> ssl.SSLContext:
|
||||
ctx = ssl.create_default_context()
|
||||
if not verify:
|
||||
ctx.check_hostname = False
|
||||
|
|
|
@ -1,36 +1,60 @@
|
|||
import asyncio
|
||||
from asyncio import Future, PriorityQueue
|
||||
from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List,
|
||||
Optional, Set, Tuple, Union)
|
||||
from asyncio import Future, PriorityQueue
|
||||
from collections import deque
|
||||
from time import monotonic
|
||||
from time import monotonic
|
||||
from typing import (
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
Deque,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from asyncio_rlock import RLock
|
||||
from asyncio_throttle import Throttler
|
||||
from async_timeout import timeout as timeout_
|
||||
from ircstates import Emit, Channel, ChannelUser
|
||||
from async_timeout import timeout as timeout_
|
||||
from asyncio_rlock import RLock
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Channel, ChannelUser, Emit
|
||||
from ircstates.names import Name
|
||||
from ircstates.numerics import *
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
from ircstates.names import Name
|
||||
from irctokens import build, Line, tokenise
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
from irctokens import Line, build, tokenise
|
||||
|
||||
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
|
||||
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
|
||||
Folded)
|
||||
from .asyncs import MaybeAwait, WaitFor
|
||||
from .struct import Whois
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
||||
IMatchResponse)
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
from .asyncs import MaybeAwait, WaitFor
|
||||
from .interface import (
|
||||
IBot,
|
||||
ICapability,
|
||||
IMatchResponse,
|
||||
IServer,
|
||||
ITCPReader,
|
||||
ITCPTransport,
|
||||
ITCPWriter,
|
||||
SendPriority,
|
||||
SentLine,
|
||||
)
|
||||
from .ircv3 import (
|
||||
CAP_ECHO,
|
||||
CAP_LABEL,
|
||||
CAP_SASL,
|
||||
LABEL_TAG_MAP,
|
||||
CAPContext,
|
||||
resume_transmute,
|
||||
sts_transmute,
|
||||
)
|
||||
from .matching import ANY, MASK_SELF, SELF, Folded, Response, ResponseOr, Responses
|
||||
from .params import ConnectionParams, ResumePolicy, SASLParams, STSPolicy
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .struct import Whois
|
||||
|
||||
THROTTLE_RATE = 4 # lines
|
||||
THROTTLE_TIME = 2 # seconds
|
||||
PING_TIMEOUT = 60 # seconds
|
||||
WAIT_TIMEOUT = 20 # seconds
|
||||
PING_TIMEOUT = 60 # seconds
|
||||
WAIT_TIMEOUT = 20 # seconds
|
||||
|
||||
JOIN_ERR_FIRST = [
|
||||
ERR_NOSUCHCHANNEL,
|
||||
|
@ -41,13 +65,14 @@ JOIN_ERR_FIRST = [
|
|||
ERR_INVITEONLYCHAN,
|
||||
ERR_BADCHANNELKEY,
|
||||
ERR_NEEDREGGEDNICK,
|
||||
ERR_THROTTLE
|
||||
ERR_THROTTLE,
|
||||
]
|
||||
|
||||
|
||||
class Server(IServer):
|
||||
_reader: ITCPReader
|
||||
_writer: ITCPWriter
|
||||
params: ConnectionParams
|
||||
params: ConnectionParams
|
||||
|
||||
def __init__(self, bot: IBot, name: str):
|
||||
super().__init__(name)
|
||||
|
@ -58,23 +83,23 @@ class Server(IServer):
|
|||
self.throttle = Throttler(rate_limit=100, period=1)
|
||||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
self.last_read = monotonic()
|
||||
self.last_read = monotonic()
|
||||
|
||||
self._sent_count: int = 0
|
||||
self._sent_count: int = 0
|
||||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
|
||||
self._ping_sent = False
|
||||
self._ping_sent = False
|
||||
self._read_lguard = RLock()
|
||||
self.read_lock = self._read_lguard
|
||||
self._read_lwork = asyncio.Lock()
|
||||
self._wait_for = asyncio.Event()
|
||||
self.read_lock = self._read_lguard
|
||||
self._read_lwork = asyncio.Lock()
|
||||
self._wait_for = asyncio.Event()
|
||||
|
||||
self._pending_who: Deque[str] = deque()
|
||||
self._alt_nicks: List[str] = []
|
||||
self._alt_nicks: List[str] = []
|
||||
|
||||
def hostmask(self) -> str:
|
||||
hostmask = self.nickname
|
||||
|
@ -84,13 +109,10 @@ class Server(IServer):
|
|||
hostmask += f"@{self.hostname}"
|
||||
return hostmask
|
||||
|
||||
def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
|
||||
return self.send(tokenise(line), priority)
|
||||
def send(self,
|
||||
line: Line,
|
||||
priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
|
||||
def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
|
||||
|
||||
self.line_presend(line)
|
||||
sent_line = SentLine(self._sent_count, priority, line)
|
||||
|
@ -110,28 +132,25 @@ class Server(IServer):
|
|||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
self.throttle.rate_limit = rate
|
||||
self.throttle.period = time
|
||||
self.throttle.period = time
|
||||
|
||||
def server_address(self) -> Tuple[str, int]:
|
||||
return self._writer.get_peer()
|
||||
|
||||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
async def connect(self, transport: ITCPTransport, params: ConnectionParams):
|
||||
await sts_transmute(params)
|
||||
await resume_transmute(params)
|
||||
|
||||
reader, writer = await transport.connect(
|
||||
params.host,
|
||||
params.port,
|
||||
tls =params.tls,
|
||||
bindhost =params.bindhost)
|
||||
params.host, params.port, tls=params.tls, bindhost=params.bindhost
|
||||
)
|
||||
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
||||
self.params = params
|
||||
await self.handshake()
|
||||
|
||||
async def disconnect(self):
|
||||
if not self._writer is None:
|
||||
await self._writer.close()
|
||||
|
@ -145,29 +164,35 @@ class Server(IServer):
|
|||
|
||||
alt_nicks = self.params.alt_nicknames
|
||||
if not alt_nicks:
|
||||
alt_nicks = [nickname+"_"*i for i in range(1, 4)]
|
||||
self._alt_nicks = alt_nicks
|
||||
alt_nicks = [nickname + "_" * i for i in range(1, 4)]
|
||||
self._alt_nicks = alt_nicks
|
||||
|
||||
# these must remain non-awaited; reading hasn't started yet
|
||||
if not self.params.password is None:
|
||||
self.send(build("PASS", [self.params.password]))
|
||||
self.send(build("CAP", ["LS", "302"]))
|
||||
self.send(build("CAP", ["LS", "302"]))
|
||||
self.send(build("NICK", [nickname]))
|
||||
self.send(build("USER", [username, "0", "*", realname]))
|
||||
|
||||
# to be overridden
|
||||
def line_preread(self, line: Line):
|
||||
pass
|
||||
|
||||
def line_presend(self, line: Line):
|
||||
pass
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
|
||||
async def sts_policy(self, sts: STSPolicy):
|
||||
pass
|
||||
|
||||
async def resume_policy(self, resume: ResumePolicy):
|
||||
pass
|
||||
|
||||
# /to be overriden
|
||||
|
||||
async def _on_read(self, line: Line, emit: Optional[Emit]):
|
||||
|
@ -176,13 +201,14 @@ class Server(IServer):
|
|||
|
||||
elif line.command == RPL_ENDOFWHO:
|
||||
chan = self.casefold(line.params[1])
|
||||
if (self._pending_who and
|
||||
self._pending_who[0] == chan):
|
||||
if self._pending_who and self._pending_who[0] == chan:
|
||||
self._pending_who.popleft()
|
||||
await self._next_who()
|
||||
elif (line.command in {
|
||||
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
|
||||
} and not self.registered):
|
||||
elif (
|
||||
line.command
|
||||
in {ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE}
|
||||
and not self.registered
|
||||
):
|
||||
if self._alt_nicks:
|
||||
nick = self._alt_nicks.pop(0)
|
||||
await self.send(build("NICK", [nick]))
|
||||
|
@ -203,8 +229,7 @@ class Server(IServer):
|
|||
await self._check_regain([line.params[1]])
|
||||
elif line.command == RPL_MONOFFLINE:
|
||||
await self._check_regain(line.params[1].split(","))
|
||||
elif (line.command in ["NICK", "QUIT"] and
|
||||
line.source is not None):
|
||||
elif line.command in ["NICK", "QUIT"] and line.source is not None:
|
||||
await self._check_regain([line.hostmask.nickname])
|
||||
|
||||
elif emit is not None:
|
||||
|
@ -216,10 +241,9 @@ class Server(IServer):
|
|||
await self._batch_joins(self.params.autojoin)
|
||||
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_ls(emit)
|
||||
elif (emit.subcommand == "LS" and
|
||||
emit.finished):
|
||||
elif emit.subcommand == "LS" and emit.finished:
|
||||
if not self.registered:
|
||||
await CAPContext(self).handshake()
|
||||
else:
|
||||
|
@ -227,7 +251,7 @@ class Server(IServer):
|
|||
|
||||
elif emit.command == "JOIN":
|
||||
if emit.self and not emit.channel is None:
|
||||
chan = emit.channel.name_lower
|
||||
chan = emit.channel.name_lower
|
||||
await self.send(build("MODE", [chan]))
|
||||
|
||||
modes = "".join(self.isupport.chanmodes.a_modes)
|
||||
|
@ -241,18 +265,18 @@ class Server(IServer):
|
|||
|
||||
async def _check_regain(self, nicks: List[str]):
|
||||
for nick in nicks:
|
||||
if (self.casefold_equals(nick, self.params.nickname) and
|
||||
not self.nickname == self.params.nickname):
|
||||
if (
|
||||
self.casefold_equals(nick, self.params.nickname)
|
||||
and not self.nickname == self.params.nickname
|
||||
):
|
||||
await self.send(build("NICK", [self.params.nickname]))
|
||||
|
||||
async def _batch_joins(self,
|
||||
channels: List[str],
|
||||
batch_n: int=10):
|
||||
#TODO: do as many JOINs in one line as we can fit
|
||||
#TODO: channel keys
|
||||
async def _batch_joins(self, channels: List[str], batch_n: int = 10):
|
||||
# TODO: do as many JOINs in one line as we can fit
|
||||
# TODO: channel keys
|
||||
|
||||
for i in range(0, len(channels), batch_n):
|
||||
batch = channels[i:i+batch_n]
|
||||
batch = channels[i : i + batch_n]
|
||||
await self.send(build("JOIN", [",".join(batch)]))
|
||||
|
||||
async def _next_who(self):
|
||||
|
@ -275,7 +299,7 @@ class Server(IServer):
|
|||
return None
|
||||
|
||||
self.last_read = monotonic()
|
||||
lines = self.recv(data)
|
||||
lines = self.recv(data)
|
||||
for line in lines:
|
||||
self.line_preread(line)
|
||||
self._read_queue.append(line)
|
||||
|
@ -287,10 +311,10 @@ class Server(IServer):
|
|||
|
||||
if not self._process_queue:
|
||||
async with self._read_lwork:
|
||||
read_aw = self._read_line(PING_TIMEOUT)
|
||||
read_aw = self._read_line(PING_TIMEOUT)
|
||||
dones, notdones = await asyncio.wait(
|
||||
[read_aw, self._wait_for.wait()],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
self._wait_for.clear()
|
||||
|
||||
|
@ -314,11 +338,12 @@ class Server(IServer):
|
|||
line, emit = self._process_queue.popleft()
|
||||
await self._on_read(line, emit)
|
||||
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_aw: Optional[Awaitable[SentLine]]=None,
|
||||
timeout: float=WAIT_TIMEOUT
|
||||
) -> Line:
|
||||
async def wait_for(
|
||||
self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_aw: Optional[Awaitable[SentLine]] = None,
|
||||
timeout: float = WAIT_TIMEOUT,
|
||||
) -> Line:
|
||||
|
||||
response_obj: IMatchResponse
|
||||
if isinstance(response, set):
|
||||
|
@ -340,8 +365,9 @@ class Server(IServer):
|
|||
return line
|
||||
|
||||
async def _on_send_line(self, line: Line):
|
||||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||
not self.cap_agreed(CAP_ECHO)):
|
||||
if line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and not self.cap_agreed(
|
||||
CAP_ECHO
|
||||
):
|
||||
new_line = line.with_source(self.hostmask())
|
||||
self._read_queue.append(new_line)
|
||||
|
||||
|
@ -349,15 +375,13 @@ class Server(IServer):
|
|||
while True:
|
||||
lines: List[SentLine] = []
|
||||
|
||||
while (not lines or
|
||||
(len(lines) < 5 and self._send_queue.qsize() > 0)):
|
||||
while not lines or (len(lines) < 5 and self._send_queue.qsize() > 0):
|
||||
prio_line = await self._send_queue.get()
|
||||
lines.append(prio_line)
|
||||
|
||||
for line in lines:
|
||||
async with self.throttle:
|
||||
self._writer.write(
|
||||
f"{line.line.format()}\r\n".encode("utf8"))
|
||||
self._writer.write(f"{line.line.format()}\r\n".encode("utf8"))
|
||||
|
||||
await self._writer.drain()
|
||||
|
||||
|
@ -369,6 +393,7 @@ class Server(IServer):
|
|||
# CAP-related
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
return bool(self.cap_available(capability))
|
||||
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
return capability.available(self.agreed_caps)
|
||||
|
||||
|
@ -381,78 +406,81 @@ class Server(IServer):
|
|||
await CAPContext(self).on_ls(tokens)
|
||||
|
||||
async def sasl_auth(self, params: SASLParams) -> bool:
|
||||
if (self.sasl_state == SASLResult.NONE and
|
||||
self.cap_agreed(CAP_SASL)):
|
||||
if self.sasl_state == SASLResult.NONE and self.cap_agreed(CAP_SASL):
|
||||
|
||||
res = await SASLContext(self).from_params(params)
|
||||
self.sasl_state = res
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# /CAP-related
|
||||
|
||||
def send_nick(self, new_nick: str) -> Awaitable[bool]:
|
||||
fut = self.send(build("NICK", [new_nick]))
|
||||
|
||||
async def _assure() -> bool:
|
||||
line = await self.wait_for({
|
||||
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
|
||||
Responses([
|
||||
ERR_BANNICKCHANGE,
|
||||
ERR_NICKTOOFAST,
|
||||
ERR_CANTCHANGENICK
|
||||
], [ANY]),
|
||||
Responses([
|
||||
ERR_NICKNAMEINUSE,
|
||||
ERR_ERRONEUSNICKNAME,
|
||||
ERR_UNAVAILRESOURCE
|
||||
], [ANY, Folded(new_nick)])
|
||||
}, fut)
|
||||
line = await self.wait_for(
|
||||
{
|
||||
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
|
||||
Responses(
|
||||
[ERR_BANNICKCHANGE, ERR_NICKTOOFAST, ERR_CANTCHANGENICK], [ANY]
|
||||
),
|
||||
Responses(
|
||||
[ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE],
|
||||
[ANY, Folded(new_nick)],
|
||||
),
|
||||
},
|
||||
fut,
|
||||
)
|
||||
return line.command == "NICK"
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
def send_join(self,
|
||||
name: str,
|
||||
key: Optional[str]=None
|
||||
) -> Awaitable[Channel]:
|
||||
def send_join(self, name: str, key: Optional[str] = None) -> Awaitable[Channel]:
|
||||
fut = self.send_joins([name], [] if key is None else [key])
|
||||
|
||||
async def _assure():
|
||||
channels = await fut
|
||||
return channels[0]
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
def send_part(self, name: str):
|
||||
fut = self.send(build("PART", [name]))
|
||||
|
||||
async def _assure():
|
||||
line = await self.wait_for(
|
||||
Response("PART", [Folded(name)], source=MASK_SELF),
|
||||
fut
|
||||
Response("PART", [Folded(name)], source=MASK_SELF), fut
|
||||
)
|
||||
return
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
def send_joins(self,
|
||||
names: List[str],
|
||||
keys: List[str]=[]
|
||||
) -> Awaitable[List[Channel]]:
|
||||
def send_joins(
|
||||
self, names: List[str], keys: List[str] = []
|
||||
) -> Awaitable[List[Channel]]:
|
||||
|
||||
folded_names = [self.casefold(name) for name in names]
|
||||
|
||||
if not keys:
|
||||
fut = self.send(build("JOIN", [",".join(names)]))
|
||||
else:
|
||||
fut = self.send(build("JOIN", [",".join(names)]+keys))
|
||||
fut = self.send(build("JOIN", [",".join(names)] + keys))
|
||||
|
||||
async def _assure():
|
||||
channels: List[Channel] = []
|
||||
|
||||
while folded_names:
|
||||
line = await self.wait_for({
|
||||
Response(RPL_CHANNELMODEIS, [ANY, ANY]),
|
||||
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
|
||||
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
|
||||
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
|
||||
}, fut)
|
||||
line = await self.wait_for(
|
||||
{
|
||||
Response(RPL_CHANNELMODEIS, [ANY, ANY]),
|
||||
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
|
||||
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
|
||||
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]),
|
||||
},
|
||||
fut,
|
||||
)
|
||||
|
||||
chan: Optional[str] = None
|
||||
if line.command == RPL_CHANNELMODEIS:
|
||||
|
@ -462,7 +490,7 @@ class Server(IServer):
|
|||
elif line.command == ERR_USERONCHANNEL:
|
||||
chan = line.params[2]
|
||||
elif line.command == ERR_LINKCHANNEL:
|
||||
#XXX i dont like this
|
||||
# XXX i dont like this
|
||||
chan = line.params[2]
|
||||
await self.wait_for(
|
||||
Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)])
|
||||
|
@ -477,51 +505,58 @@ class Server(IServer):
|
|||
channels.append(self.channels[folded])
|
||||
|
||||
return channels
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
def send_message(self, target: str, message: str
|
||||
) -> Awaitable[Optional[str]]:
|
||||
def send_message(self, target: str, message: str) -> Awaitable[Optional[str]]:
|
||||
fut = self.send(build("PRIVMSG", [target, message]))
|
||||
|
||||
async def _assure():
|
||||
line = await self.wait_for(
|
||||
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
|
||||
fut
|
||||
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), fut
|
||||
)
|
||||
if line.command == "PRIVMSG":
|
||||
return line.params[1]
|
||||
else:
|
||||
return None
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
||||
def send_whois(self,
|
||||
target: str,
|
||||
remote: bool=False
|
||||
) -> Awaitable[Optional[Whois]]:
|
||||
def send_whois(
|
||||
self, target: str, remote: bool = False
|
||||
) -> Awaitable[Optional[Whois]]:
|
||||
args = [target]
|
||||
if remote:
|
||||
args.append(target)
|
||||
|
||||
fut = self.send(build("WHOIS", args))
|
||||
|
||||
async def _assure() -> Optional[Whois]:
|
||||
folded = self.casefold(target)
|
||||
params = [ANY, Folded(folded)]
|
||||
|
||||
obj = Whois()
|
||||
while True:
|
||||
line = await self.wait_for(Responses([
|
||||
ERR_NOSUCHNICK,
|
||||
ERR_NOSUCHSERVER,
|
||||
RPL_WHOISUSER,
|
||||
RPL_WHOISSERVER,
|
||||
RPL_WHOISOPERATOR,
|
||||
RPL_WHOISIDLE,
|
||||
RPL_WHOISCHANNELS,
|
||||
RPL_WHOISHOST,
|
||||
RPL_WHOISACCOUNT,
|
||||
RPL_WHOISSECURE,
|
||||
RPL_ENDOFWHOIS
|
||||
], params), fut)
|
||||
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
|
||||
line = await self.wait_for(
|
||||
Responses(
|
||||
[
|
||||
ERR_NOSUCHNICK,
|
||||
ERR_NOSUCHSERVER,
|
||||
RPL_WHOISUSER,
|
||||
RPL_WHOISSERVER,
|
||||
RPL_WHOISOPERATOR,
|
||||
RPL_WHOISIDLE,
|
||||
RPL_WHOISCHANNELS,
|
||||
RPL_WHOISHOST,
|
||||
RPL_WHOISACCOUNT,
|
||||
RPL_WHOISSECURE,
|
||||
RPL_ENDOFWHOIS,
|
||||
],
|
||||
params,
|
||||
),
|
||||
fut,
|
||||
)
|
||||
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
|
||||
return None
|
||||
elif line.command == RPL_WHOISUSER:
|
||||
nick, user, host, _, real = line.params[1:]
|
||||
|
@ -531,7 +566,7 @@ class Server(IServer):
|
|||
obj.realname = real
|
||||
elif line.command == RPL_WHOISIDLE:
|
||||
idle, signon, _ = line.params[2:]
|
||||
obj.idle = int(idle)
|
||||
obj.idle = int(idle)
|
||||
obj.signon = int(signon)
|
||||
elif line.command == RPL_WHOISACCOUNT:
|
||||
obj.account = line.params[2]
|
||||
|
@ -544,11 +579,11 @@ class Server(IServer):
|
|||
symbols = ""
|
||||
while channel[0] in self.isupport.prefix.prefixes:
|
||||
symbols += channel[0]
|
||||
channel = channel[1:]
|
||||
channel = channel[1:]
|
||||
|
||||
channel_user = ChannelUser(
|
||||
Name(obj.nickname, folded),
|
||||
Name(channel, self.casefold(channel))
|
||||
Name(channel, self.casefold(channel)),
|
||||
)
|
||||
for symbol in symbols:
|
||||
mode = self.isupport.prefix.from_prefix(symbol)
|
||||
|
@ -558,4 +593,5 @@ class Server(IServer):
|
|||
obj.channels.append(channel_user)
|
||||
elif line.command == RPL_ENDOFWHOIS:
|
||||
return obj
|
||||
|
||||
return MaybeAwait(_assure)
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from ircstates import ChannelUser
|
||||
|
||||
|
||||
class Whois(object):
|
||||
server: Optional[str] = None
|
||||
server_info: Optional[str] = None
|
||||
operator: bool = False
|
||||
server: Optional[str] = None
|
||||
server_info: Optional[str] = None
|
||||
operator: bool = False
|
||||
|
||||
secure: bool = False
|
||||
secure: bool = False
|
||||
|
||||
signon: Optional[int] = None
|
||||
idle: Optional[int] = None
|
||||
signon: Optional[int] = None
|
||||
idle: Optional[int] = None
|
||||
|
||||
channels: Optional[List[ChannelUser]] = None
|
||||
channels: Optional[List[ChannelUser]] = None
|
||||
|
||||
nickname: str = ""
|
||||
username: str = ""
|
||||
hostname: str = ""
|
||||
realname: str = ""
|
||||
account: Optional[str] = None
|
||||
|
||||
account: Optional[str] = None
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from hashlib import sha512
|
||||
from ssl import SSLContext
|
||||
from typing import Optional, Tuple
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from hashlib import sha512
|
||||
from ssl import SSLContext
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from async_stagger import open_connection
|
||||
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
|
||||
TLSVerifySHA512)
|
||||
from .interface import ITCPReader, ITCPTransport, ITCPWriter
|
||||
from .security import TLS, TLSNoVerify, TLSVerifyHash, TLSVerifySHA512, tls_context
|
||||
|
||||
|
||||
class TCPReader(ITCPReader):
|
||||
def __init__(self, reader: StreamReader):
|
||||
|
@ -14,6 +15,8 @@ class TCPReader(ITCPReader):
|
|||
|
||||
async def read(self, byte_count: int) -> bytes:
|
||||
return await self._reader.read(byte_count)
|
||||
|
||||
|
||||
class TCPWriter(ITCPWriter):
|
||||
def __init__(self, writer: StreamWriter):
|
||||
self._writer = writer
|
||||
|
@ -32,13 +35,15 @@ class TCPWriter(ITCPWriter):
|
|||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
|
||||
|
||||
class TCPTransport(ITCPTransport):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
async def connect(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str] = None,
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
|
||||
cur_ssl: Optional[SSLContext] = None
|
||||
if tls is not None:
|
||||
|
@ -54,22 +59,20 @@ class TCPTransport(ITCPTransport):
|
|||
hostname,
|
||||
port,
|
||||
server_hostname=server_hostname,
|
||||
ssl =cur_ssl,
|
||||
local_addr =local_addr)
|
||||
ssl=cur_ssl,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
|
||||
if isinstance(tls, TLSVerifyHash):
|
||||
cert: bytes = writer.transport.get_extra_info(
|
||||
"ssl_object"
|
||||
).getpeercert(True)
|
||||
cert: bytes = writer.transport.get_extra_info("ssl_object").getpeercert(
|
||||
True
|
||||
)
|
||||
if isinstance(tls, TLSVerifySHA512):
|
||||
sum = sha512(cert).hexdigest()
|
||||
else:
|
||||
raise ValueError(f"unknown hash pinning {type(tls)}")
|
||||
|
||||
if not sum == tls.sum:
|
||||
raise ValueError(
|
||||
f"pinned hash for {hostname} does not match ({sum})"
|
||||
)
|
||||
raise ValueError(f"pinned hash for {hostname} does not match ({sum})")
|
||||
|
||||
return (TCPReader(reader), TCPWriter(writer))
|
||||
|
||||
|
|
6
setup.py
6
setup.py
|
@ -24,8 +24,8 @@ setup(
|
|||
"Operating System :: OS Independent",
|
||||
"Operating System :: POSIX",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Topic :: Communications :: Chat :: Internet Relay Chat"
|
||||
"Topic :: Communications :: Chat :: Internet Relay Chat",
|
||||
],
|
||||
python_requires='>=3.7',
|
||||
install_requires=install_requires
|
||||
python_requires=">=3.7",
|
||||
install_requires=install_requires,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import unittest
|
||||
|
||||
from ircrobots import glob
|
||||
|
||||
|
||||
class GlobTestCollapse(unittest.TestCase):
|
||||
def test(self):
|
||||
c1 = glob.collapse("**?*")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue