tvl-depot/users/Profpatsch/netencode/netencode.rs
Profpatsch 7d9c30ab3d feat(users/Profpatsch/netencode): add dec::RecordDot
`dec::RecordDot` accesses a specific field of a netencode record.

Implementing this would otherwise require introducing a type-level
string, but in all honesty this kind of type-level circlejerking never
leads anywhere, so let’s change the trait to take `&self` after all.
Usage is pretty much the same, and is actually more like what you’d expect.

Change-Id: I5a7f1a3f587256c50df1b65c2969e5a7194bba70
Reviewed-on: https://cl.tvl.fyi/c/depot/+/2494
Tested-by: BuildkiteCI
Reviewed-by: Profpatsch <mail@profpatsch.de>
2021-02-09 01:36:27 +00:00

689 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

extern crate nom;
extern crate exec_helpers;
use std::collections::HashMap;
use std::io::{Write, Read};
/// An owned netencode value.
///
/// The integer variants are named after their netencode type tags:
/// `n3`/`i3` hold 8-bit, `n6`/`i6` 64-bit and `n7`/`i7` 128-bit numbers,
/// while `n1` is a boolean.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum T {
    /// The unit value (`u,`).
    Unit,
    /// Boolean (`n1:0,` / `n1:1,`).
    N1(bool),
    /// 8-bit natural number.
    N3(u8),
    /// 64-bit natural number.
    N6(u64),
    /// 128-bit natural number.
    N7(u128),
    /// 8-bit integer.
    I3(i8),
    /// 64-bit integer.
    I6(i64),
    /// 128-bit integer.
    I7(i128),
    /// UTF-8 text.
    // TODO: make into &str
    Text(String),
    /// Arbitrary bytes.
    Binary(Vec<u8>),
    /// A tagged value (one constructor of a sum type) and its payload.
    // TODO: make into &str
    Sum(Tag<String, T>),
    /// A record mapping field names to values.
    // TODO: make into &str
    Record(HashMap<String, T>),
    /// A list of values.
    List(Vec<T>),
}
impl T {
    /// Borrow this value as a `U` whose strings and byte slices point into `self`.
    fn to_u<'a>(&'a self) -> U<'a> {
        match self {
            T::Unit => U::Unit,
            T::N1(b) => U::N1(*b),
            T::N3(u) => U::N3(*u),
            T::N6(u) => U::N6(*u),
            T::N7(u) => U::N7(*u),
            T::I3(i) => U::I3(*i),
            T::I6(i) => U::I6(*i),
            T::I7(i) => U::I7(*i),
            T::Text(t) => U::Text(t.as_str()),
            T::Binary(v) => U::Binary(v.as_slice()),
            T::Sum(Tag { tag, val }) => U::Sum(Tag {
                tag: tag.as_str(),
                val: Box::new(val.to_u()),
            }),
            T::Record(map) => U::Record(
                map.iter().map(|(k, v)| (k.as_str(), v.to_u())).collect(),
            ),
            T::List(l) => U::List(l.iter().map(T::to_u).collect()),
        }
    }

    /// Encode this value as netencode bytes.
    pub fn encode(&self) -> Vec<u8> {
        // TODO: dont go via U, inefficient
        self.to_u().encode()
    }
}
/// A netencode value that borrows its text, binary data, tags and record
/// keys for the lifetime `'a` instead of owning them (e.g. straight out of
/// the buffer it was parsed from).
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum U<'a> {
    /// The unit value (`u,`).
    Unit,
    /// Boolean (`n1:0,` / `n1:1,`).
    N1(bool),
    /// 8-bit natural number.
    N3(u8),
    /// 64-bit natural number.
    N6(u64),
    /// 128-bit natural number.
    N7(u128),
    /// 8-bit integer.
    I3(i8),
    /// 64-bit integer.
    I6(i64),
    /// 128-bit integer.
    I7(i128),
    /// Borrowed UTF-8 text.
    Text(&'a str),
    /// Borrowed bytes.
    Binary(&'a [u8]),
    /// A tagged value (one constructor of a sum type) and its payload.
    Sum(Tag<&'a str, U<'a>>),
    /// A record mapping borrowed field names to values.
    Record(HashMap<&'a str, U<'a>>),
    /// A list of values.
    List(Vec<U<'a>>),
}
impl<'a> U<'a> {
    /// Encode this value into a freshly allocated netencode byte buffer.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        // `Write for Vec<u8>` never reports an I/O error, so this cannot
        // fail; make that explicit instead of dropping the `Result`.
        encode(&mut buf, self).expect("encoding netencode into a Vec<u8> cannot fail");
        buf
    }
}
/// A tagged value: the constructor name `tag` of a sum type together with
/// its boxed payload `val`.
///
/// `S` is the type used for the tag text (`String` for owned values,
/// `&str` for borrowed ones).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Tag<S, A> {
    /// The constructor name.
    pub tag: S,
    /// The payload carried by this constructor.
    pub val: Box<A>,
}
impl<S, A> Tag<S, A> {
    /// Transform the payload with `f`, keeping the tag unchanged.
    ///
    /// `f` is called exactly once, so any `FnOnce` closure is accepted
    /// (this includes every `Fn` and `FnMut`).
    fn map<F, B>(self, f: F) -> Tag<S, B>
    where
        F: FnOnce(A) -> B,
    {
        Tag {
            tag: self.tag,
            val: Box::new(f(*self.val)),
        }
    }
}
/// Write a tagged value as `<len:tag|` followed by the encoding of `val`,
/// where `len` is the tag's length in bytes.
fn encode_tag<W: Write>(w: &mut W, tag: &str, val: &U) -> std::io::Result<()> {
    let tag_len = tag.len();
    write!(w, "<{}:{}|", tag_len, tag)?;
    encode(w, val)
}
/// Encode the netencode value `u` into the writer `w`.
///
/// Records and lists are first encoded into an in-memory buffer, because
/// their byte length has to be written before their contents.
///
/// # Errors
/// Returns any I/O error reported by `w`.
pub fn encode<W: Write>(w: &mut W, u: &U) -> std::io::Result<()> {
    match u {
        U::Unit => write!(w, "u,"),
        U::N1(b) => write!(w, "n1:{},", u8::from(*b)),
        U::N3(n) => write!(w, "n3:{},", n),
        U::N6(n) => write!(w, "n6:{},", n),
        U::N7(n) => write!(w, "n7:{},", n),
        U::I3(i) => write!(w, "i3:{},", i),
        U::I6(i) => write!(w, "i6:{},", i),
        U::I7(i) => write!(w, "i7:{},", i),
        U::Text(s) => {
            write!(w, "t{}:", s.len())?;
            // `write` may write only part of the buffer; `write_all` loops.
            w.write_all(s.as_bytes())?;
            write!(w, ",")
        }
        U::Binary(s) => {
            write!(w, "b{}:", s.len())?;
            w.write_all(s)?;
            write!(w, ",")
        }
        U::Sum(Tag { tag, val }) => encode_tag(w, tag, val),
        U::Record(m) => {
            let mut buf: Vec<u8> = Vec::new();
            for (k, v) in m {
                encode_tag(&mut buf, k, v)?;
            }
            write!(w, "{{{}:", buf.len())?;
            w.write_all(&buf)?;
            write!(w, "}}")
        }
        U::List(l) => {
            let mut buf: Vec<u8> = Vec::new();
            for u in l {
                encode(&mut buf, u)?;
            }
            write!(w, "[{}:", buf.len())?;
            w.write_all(&buf)?;
            write!(w, "]")
        }
    }
}
/// Wrap an owned string as a netencode text value.
pub fn text(s: String) -> T {
    T::Text(s)
}
/// Read all of stdin into `stdin_buf` and parse it as exactly one netencode
/// value that borrows from that buffer.
///
/// If stdin cannot be read, cannot be parsed, or has bytes left over after
/// the value, this reports via `exec_helpers::die_user_error` tagged with
/// `prog_name`. That function must diverge, since it stands in for a `U`
/// in the match arms below.
pub fn u_from_stdin_or_die_user_error<'a>(prog_name: &'_ str, stdin_buf: &'a mut Vec<u8>) -> U<'a> {
    // Don't ignore a failed read: it would otherwise surface as a confusing
    // parse error on a partially filled buffer.
    if let Err(err) = std::io::stdin().lock().read_to_end(stdin_buf) {
        exec_helpers::die_user_error(prog_name, format!("unable to read stdin: {:?}", err));
    }
    match parse::u_u(stdin_buf) {
        Ok((b"", u)) => u,
        Ok((rest, _)) => exec_helpers::die_user_error(prog_name, format!("stdin contained some soup after netencode value: {:?}", rest)),
        Err(err) => exec_helpers::die_user_error(prog_name, format!("unable to parse netencode from stdin: {:?}", err)),
    }
}
pub mod parse {
use super::{T, Tag, U};
use std::str::FromStr;
use std::ops::Neg;
use std::collections::HashMap;
use nom::{IResult};
use nom::bytes::complete::{tag, take};
use nom::branch::{alt};
use nom::character::complete::{digit1, char};
use nom::sequence::{tuple};
use nom::combinator::{map, map_res, flat_map, map_parser, opt};
use nom::error::{context, ErrorKind, ParseError};
fn unit_t(s: &[u8]) -> IResult<&[u8], ()> {
let (s, _) = context("unit", tag("u,"))(s)?;
Ok((s, ()))
}
/// parse an unsigned decimal length prefix (the `32` in `32foo`)
fn usize_t(s: &[u8]) -> IResult<&[u8], usize> {
    let decimal = map_res(digit1, |n| std::str::from_utf8(n));
    let length = map_res(decimal, |n: &str| n.parse::<usize>());
    context("usize", length)(s)
}
fn sized(begin: char, end: char) -> impl Fn(&[u8]) -> IResult<&[u8], &[u8]> {
move |s: &[u8]| {
let (s, (_, len, _)) = tuple((
char(begin),
usize_t,
char(':')
))(s)?;
let (s, (res, _)) = tuple((
take(len),
char(end)
))(s)?;
Ok((s, res))
}
}
/// parse an unsigned integer scalar with type tag `t` (e.g. `n3:42,`);
/// values that don't fit into `I` fail with `ErrorKind::MapRes`
fn uint_t<'a, I: FromStr + 'a>(t: &'static str) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], I> {
    move |input: &'a [u8]| {
        let decimal = map_res(digit1, |n: &[u8]| std::str::from_utf8(n));
        let number = map_res(decimal, |n| n.parse::<I>());
        let (input, _) = tag(t.as_bytes())(input)?;
        let (input, _) = char(':')(input)?;
        let (input, value) = number(input)?;
        let (input, _) = char(',')(input)?;
        Ok((input, value))
    }
}
/// parse a boolean (`n1:0,` is false, `n1:1,` is true)
fn bool_t<'a>() -> impl Fn(&'a [u8]) -> IResult<&'a [u8], bool> {
    let falsy = map(tag("n1:0,"), |_| false);
    let truthy = map(tag("n1:1,"), |_| true);
    context("bool", alt((falsy, truthy)))
}
fn int_t<'a, I: FromStr + Neg<Output=I>>(t: &'static str) -> impl Fn(&'a [u8]) -> IResult<&[u8], I> {
context(
t,
move |s: &'a [u8]| {
let (s, (_, _, neg, int, _)) = tuple((
tag(t.as_bytes()),
char(':'),
opt(char('-')),
map_res(
map_res(digit1, |n: &[u8]| std::str::from_utf8(n)),
|s| s.parse::<I>()
),
char(',')
))(s)?;
let res = match neg {
Some(_) => -int,
None => int,
};
Ok((s, res))
}
)
}
/// parse a tagged `T` value (`<3:foo|u,`), copying the tag into a `String`
fn tag_t(s: &[u8]) -> IResult<&[u8], Tag<String, T>> {
    // recurses into the main parser
    let (rest, Tag { tag, val }) = tag_g(t_t)(s)?;
    Ok((rest, Tag { tag: tag.to_owned(), val }))
}
/// parse a tagged value (`<3:foo|u,`), parsing the payload with `inner`.
///
/// A tag that is not valid UTF-8 yields a non-recoverable `Failure`
/// (reported at the input remaining after the payload).
fn tag_g<'a, P, O>(inner: P) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], Tag<&'a str, O>>
where
    P: Fn(&'a [u8]) -> IResult<&'a [u8], O>,
{
    move |input: &[u8]| {
        let (after_tag, raw_tag) = sized('<', '|')(input)?;
        let (rest, val) = inner(after_tag)?;
        match std::str::from_utf8(raw_tag) {
            Ok(tag) => Ok((rest, Tag { tag, val: Box::new(val) })),
            Err(_) => Err(nom::Err::Failure((rest, ErrorKind::Char))),
        }
    }
}
/// parse text scalar (`t5:hello,`)
fn text(s: &[u8]) -> IResult<&[u8], T> {
let (s, res) = text_g(s)?;
Ok((s, T::Text(res.to_string())))
}
/// parse the payload of a text scalar (`t5:hello,`) as a borrowed `&str`;
/// invalid UTF-8 is a non-recoverable `Failure`
fn text_g(s: &[u8]) -> IResult<&[u8], &str> {
    let (rest, bytes) = sized('t', ',')(s)?;
    std::str::from_utf8(bytes)
        .map(|text| (rest, text))
        .map_err(|_| nom::Err::Failure((rest, ErrorKind::Char)))
}
/// parse binary scalar (`b5:hello,`) into an owned `T::Binary`
fn binary<'a>() -> impl Fn(&'a [u8]) -> IResult<&'a [u8], T> {
    map(binary_g(), |bytes| T::Binary(bytes.to_vec()))
}
/// parse the payload of a binary scalar (`b5:hello,`) as a borrowed slice
fn binary_g() -> impl Fn(&[u8]) -> IResult<&[u8], &[u8]> {
    sized('b', ',')
}
/// parse a list of `T` values (`[6:u,u,u,]`)
fn list_t(s: &[u8]) -> IResult<&[u8], Vec<T>> {
    list_g(t_t)(s)
}
/// parse a length-prefixed list whose elements are parsed by `inner`.
///
/// The whole payload between `[<len>:` and `]` must consist of elements.
/// Bytes that `inner` cannot parse make the list invalid instead of being
/// silently dropped (`map_parser` ignores whatever its inner parser leaves
/// over, so e.g. `[5:u,u,u]` used to decode as two units).
fn list_g<'a, P, O>(inner: P) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], Vec<O>>
where
    O: Clone,
    P: Fn(&'a [u8]) -> IResult<&'a [u8], O>
{
    map_parser(
        sized('[', ']'),
        nom::combinator::all_consuming(nom::multi::many0(inner))
    )
}
/// parse a record of `T` values (`{21:<1:a|u,<1:b|u,<1:c|u,}`), copying
/// the keys into owned `String`s
fn record_t<'a>(s: &'a [u8]) -> IResult<&'a [u8], HashMap<String, T>> {
    let (s, r) = record_g(t_t)(s)?;
    Ok((s,
        r.into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect::<HashMap<_, _>>()))
}
/// parse a length-prefixed, non-empty record of tagged values whose
/// payloads are parsed by `inner`.
///
/// The whole payload between `{<len>:` and `}` must consist of tagged
/// fields; leftover bytes make the record invalid instead of being
/// silently dropped by `map_parser`.
fn record_g<'a, P, O>(inner: P) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], HashMap<&'a str, O>>
where
    O: Clone,
    P: Fn(&'a [u8]) -> IResult<&'a [u8], O>
{
    map_parser(
        sized('{', '}'),
        nom::combinator::all_consuming(nom::multi::fold_many1(
            tag_g(inner),
            HashMap::new(),
            |mut acc: HashMap<_, _>, Tag { tag, val }| {
                // ignore duplicated tag names that appear later
                // according to netencode spec (first occurrence wins)
                acc.entry(tag).or_insert(*val);
                acc
            }
        ))
    )
}
/// parse any netencode value into a `U` borrowing from `s`.
///
/// Alternatives are tried in order, with the common encodings first. The
/// less common integer widths are stored in the next variant that can hold
/// them (`n2` in `N3`, `n4`/`n5` in `N6`, `i1`/`i2` in `I3`, `i4`/`i5` in `I6`).
pub fn u_u(s: &[u8]) -> IResult<&[u8], U> {
    alt((
        map(text_g, U::Text),
        map(binary_g(), U::Binary),
        map(unit_t, |_| U::Unit),
        map(tag_g(u_u), U::Sum),
        map(list_g(u_u), U::List),
        map(record_g(u_u), U::Record),
        map(bool_t(), U::N1),
        map(uint_t("n3"), U::N3),
        map(uint_t("n6"), U::N6),
        map(uint_t("n7"), U::N7),
        map(int_t("i3"), U::I3),
        map(int_t("i6"), U::I6),
        map(int_t("i7"), U::I7),
        // less common
        map(uint_t("n2"), U::N3),
        map(uint_t("n4"), U::N6),
        map(uint_t("n5"), U::N6),
        map(int_t("i1"), U::I3),
        map(int_t("i2"), U::I3),
        map(int_t("i4"), U::I6),
        map(int_t("i5"), U::I6),
        // TODO: 8, 9 not supported
    ))(s)
}
/// parse any netencode value into an owned `T`.
///
/// Alternatives are tried in order, with the common encodings first. The
/// less common integer widths are stored in the next variant that can hold
/// them (`n2` in `N3`, `n4`/`n5` in `N6`, `i1`/`i2` in `I3`, `i4`/`i5` in `I6`).
pub fn t_t(s: &[u8]) -> IResult<&[u8], T> {
    alt((
        text,
        binary(),
        map(unit_t, |_| T::Unit),
        map(tag_t, T::Sum),
        map(list_t, T::List),
        map(record_t, T::Record),
        map(bool_t(), T::N1),
        // 8, 64 and 128 bit
        map(uint_t("n3"), T::N3),
        map(uint_t("n6"), T::N6),
        map(uint_t("n7"), T::N7),
        map(int_t("i3"), T::I3),
        map(int_t("i6"), T::I6),
        map(int_t("i7"), T::I7),
        // less common
        map(uint_t("n2"), T::N3),
        map(uint_t("n4"), T::N6),
        map(uint_t("n5"), T::N6),
        map(int_t("i1"), T::I3),
        map(int_t("i2"), T::I3),
        map(int_t("i4"), T::I6),
        map(int_t("i5"), T::I6),
        // TODO: 8, 9 not supported
    ))(s)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// View a string literal as the byte slice the parsers consume.
    fn b(s: &str) -> &[u8] {
        s.as_bytes()
    }

    /// Collect `(key, value)` pairs into the owned map `record_t` produces.
    fn owned_record(pairs: Vec<(&str, T)>) -> HashMap<String, T> {
        pairs.into_iter().map(|(k, v)| (k.to_owned(), v)).collect()
    }

    #[test]
    fn test_parse_unit_t() {
        assert_eq!(unit_t(b("u,")), Ok((b(""), ())));
    }

    #[test]
    fn test_parse_bool_t() {
        assert_eq!(bool_t()(b("n1:0,")), Ok((b(""), false)));
        assert_eq!(bool_t()(b("n1:1,")), Ok((b(""), true)));
    }

    #[test]
    fn test_parse_usize_t() {
        assert_eq!(usize_t(b("32foo")), Ok((b("foo"), 32)));
    }

    #[test]
    fn test_parse_int_t() {
        assert_eq!(uint_t::<u8>("n3")(b("n3:42,abc")), Ok((b("abc"), 42)));
        // 1024 does not fit into a u8
        assert_eq!(
            uint_t::<u8>("n3")(b("n3:1024,abc")),
            Err(nom::Err::Error((b("1024,abc"), nom::error::ErrorKind::MapRes)))
        );
        assert_eq!(int_t::<i64>("i6")(b("i6:-23,abc")), Ok((b("abc"), -23)));
        assert_eq!(int_t::<i128>("i3")(b("i3:0,:abc")), Ok((b(":abc"), 0)));
        // leading zeroes are accepted
        assert_eq!(uint_t::<u8>("n7")(b("n7:09,")), Ok((b(""), 9)));
        // TODO: also check that a missing length (`c`, `:`) is rejected with
        // `ErrorKind::Digit`; the earlier sketch of these checks targeted a
        // `length` parser that is not defined in this module.
    }

    #[test]
    fn test_parse_text() {
        assert_eq!(text(b("t5:hello,")), Ok((b(""), T::Text("hello".to_owned()))));
        assert_eq!(
            text(b("t4:fo,")),
            // TODO: way better parse error messages
            Err(nom::Err::Error((b("fo,"), nom::error::ErrorKind::Eof)))
        );
        assert_eq!(text(b("t9:今日は,")), Ok((b(""), T::Text("今日は".to_owned()))));
    }

    #[test]
    fn test_parse_binary() {
        assert_eq!(binary()(b("b5:hello,")), Ok((b(""), T::Binary(b"hello".to_vec()))));
        assert_eq!(
            binary()(b("b4:fo,")),
            // TODO: way better parse error messages
            Err(nom::Err::Error((b("fo,"), nom::error::ErrorKind::Eof)))
        );
        assert_eq!(
            binary()(b("b9:今日は,")),
            Ok((b(""), T::Binary("今日は".as_bytes().to_vec())))
        );
    }

    #[test]
    fn test_list() {
        assert_eq!(list_t(b("[0:]")), Ok((b(""), vec![])));
        assert_eq!(list_t(b("[6:u,u,u,]")), Ok((b(""), vec![T::Unit; 3])));
        assert_eq!(
            list_t(b("[15:u,[7:t3:foo,]u,]")),
            Ok((b(""), vec![
                T::Unit,
                T::List(vec![T::Text("foo".to_owned())]),
                T::Unit,
            ]))
        );
    }

    #[test]
    fn test_record() {
        assert_eq!(
            record_t(b("{21:<1:a|u,<1:b|u,<1:c|u,}")),
            Ok((b(""), owned_record(vec![("a", T::Unit), ("b", T::Unit), ("c", T::Unit)])))
        );
        // duplicated keys are ignored (first is taken)
        assert_eq!(
            record_t(b("{25:<1:a|u,<1:b|u,<1:a|i1:-1,}")),
            Ok((b(""), owned_record(vec![("a", T::Unit), ("b", T::Unit)])))
        );
    }

    #[test]
    fn test_parse() {
        let sum = |tag: &str, val: T| T::Sum(Tag { tag: tag.to_owned(), val: Box::new(val) });
        assert_eq!(t_t(b("n3:255,")), Ok((b(""), T::N3(255))));
        assert_eq!(t_t(b("t6:halloo,")), Ok((b(""), T::Text("halloo".to_owned()))));
        assert_eq!(
            t_t(b("<3:foo|t6:halloo,")),
            Ok((b(""), sum("foo", T::Text("halloo".to_owned()))))
        );
        // { a: Unit
        // , foo: List <A: Unit | B: List i3> }
        assert_eq!(
            t_t(b("{52:<1:a|u,<3:foo|[33:<1:A|u,<1:A|n1:1,<1:B|[7:i3:127,]]}")),
            Ok((b(""), T::Record(owned_record(vec![
                ("a", T::Unit),
                ("foo", T::List(vec![
                    sum("A", T::Unit),
                    sum("A", T::N1(true)),
                    sum("B", T::List(vec![T::I3(127)])),
                ])),
            ]))))
        );
    }
}
}
pub mod dec {
use super::*;
use std::collections::HashMap;
/// Error returned by a decoder when a value does not have the expected shape.
///
/// The payload is a human-readable description of the mismatch. `Debug` is
/// derived so that `Result<_, DecodeError>::unwrap()`/`expect()` compile.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodeError(pub String);

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for DecodeError {}
/// Decodes a parsed netencode value `U<'a>` into some Rust type `Self::A`.
///
/// Decoders are values (`dec` takes `&self`), so they can carry
/// configuration, such as the field name a record accessor selects.
pub trait Decoder<'a> {
    /// The type produced by a successful decode.
    type A;
    /// Decode `u`, or describe why it does not have the expected shape.
    fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError>;
}
/// Intended to decode any value into an owned `T`; it has no `Decoder`
/// impl yet (see the sketch below).
#[derive(Clone, Copy)]
pub struct AnyT;

/// Decoder that accepts every value and hands the `U` back unchanged.
#[derive(Clone, Copy)]
pub struct AnyU;

// TODO: implement `Decoder` for `AnyT`; sketch against the current trait:
// impl<'a> Decoder<'a> for AnyT {
//     type A = T;
//     fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError> {
//         // TODO: implement
//         parse::u_into_t(u)
//     }
// }

impl<'a> Decoder<'a> for AnyU {
    type A = U<'a>;

    fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError> {
        Ok(u)
    }
}
/// Decoder that renders a scalar as raw bytes.
///
/// Integers become their decimal representation; text and binary data are
/// copied. Every other value (unit, booleans, sums, records, lists) is
/// rejected.
#[derive(Clone, Copy)]
pub struct ScalarAsBytes;

impl<'a> Decoder<'a> for ScalarAsBytes {
    type A = Vec<u8>;

    fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError> {
        let bytes = match u {
            U::N3(n) => n.to_string().into_bytes(),
            U::N6(n) => n.to_string().into_bytes(),
            U::N7(n) => n.to_string().into_bytes(),
            U::I3(i) => i.to_string().into_bytes(),
            U::I6(i) => i.to_string().into_bytes(),
            U::I7(i) => i.to_string().into_bytes(),
            U::Text(t) => t.as_bytes().to_vec(),
            U::Binary(b) => b.to_vec(),
            other => return Err(DecodeError(format!("Cannot decode {:?} into scalar", other))),
        };
        Ok(bytes)
    }
}
/// Decoder for a record in which every field value is decoded with the
/// wrapped decoder; fails on the first field that does not decode.
#[derive(Clone, Copy)]
pub struct Record<T>(pub T);

impl<'a, Inner: Decoder<'a>> Decoder<'a> for Record<Inner> {
    type A = HashMap<&'a str, Inner::A>;

    fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError> {
        let fields = match u {
            U::Record(fields) => fields,
            o => return Err(DecodeError(format!("Cannot decode {:?} into record", o))),
        };
        let mut decoded = HashMap::new();
        for (key, value) in fields {
            decoded.insert(key, self.0.dec(value)?);
        }
        Ok(decoded)
    }
}
#[derive(Clone, Copy)]
pub struct RecordDot<'a, T> {
field: &'a str,
inner: T
}
impl <'a, Inner: Decoder<'a> + Copy> Decoder<'a> for RecordDot<'_, Inner> {
type A = Inner::A;
fn dec(&self, u: U<'a>) -> Result<Self::A, DecodeError> {
match Record(self.inner).dec(u) {
Ok(mut map) => match map.remove(self.field) {
Some(inner) => Ok(inner),
None => Err(DecodeError(format!("Cannot find `{}` in record map", self.field))),
},
Err(err) => Err(err),
}
}
}
/// Parse `b` as exactly one netencode value; trailing bytes are an error.
fn dec_u(b: &[u8]) -> Result<U, DecodeError> {
    match parse::u_u(b) {
        Ok((b"", u)) => Ok(u),
        Ok((rest, _)) => Err(DecodeError(format!(
            "Cannot decode nested U, it contains trailing bytes: {:?}",
            rest
        ))),
        Err(err) => Err(DecodeError(format!("Cannot decode nested U bytes: {:?}", err))),
    }
}
}