feat(grfn/achilles): Implement tuples, and tuple patterns

Implement tuple expressions, types, and patterns, all the way through
the parser down to the typechecker. In LLVM, these are implemented as
anonymous structs, using an `extract` instruction when they're pattern
matched on to get out the individual fields.

Currently the only limitation here is patterns aren't supported in
function argument position, but you can still do something like

    fn xy = let (x, y) = xy in x + y

Change-Id: I357f17e9d4052e741eda8605b6662822f331efde
Reviewed-on: https://cl.tvl.fyi/c/depot/+/3027
Reviewed-by: grfn <grfn@gws.fyi>
Tested-by: BuildkiteCI
This commit is contained in:
Griffin Smith 2021-04-17 08:28:24 +02:00 committed by grfn
parent e1c45be3f5
commit 48098f83c1
12 changed files with 413 additions and 54 deletions

View file

@ -4,10 +4,43 @@ use itertools::Itertools;
use super::{BinaryOperator, Ident, Literal, UnaryOperator};
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern<'a, T> {
Id(Ident<'a>, T),
Tuple(Vec<Pattern<'a, T>>),
}
impl<'a, T> Pattern<'a, T> {
pub fn to_owned(&self) -> Pattern<'static, T>
where
T: Clone,
{
match self {
Pattern::Id(id, t) => Pattern::Id(id.to_owned(), t.clone()),
Pattern::Tuple(pats) => {
Pattern::Tuple(pats.into_iter().map(Pattern::to_owned).collect())
}
}
}
pub fn traverse_type<F, U, E>(self, f: F) -> Result<Pattern<'a, U>, E>
where
F: Fn(T) -> Result<U, E> + Clone,
{
match self {
Pattern::Id(id, t) => Ok(Pattern::Id(id, f(t)?)),
Pattern::Tuple(pats) => Ok(Pattern::Tuple(
pats.into_iter()
.map(|pat| pat.traverse_type(f.clone()))
.collect::<Result<Vec<_>, _>>()?,
)),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<'a, T> {
pub ident: Ident<'a>,
pub type_: T,
pub pat: Pattern<'a, T>,
pub body: Expr<'a, T>,
}
@ -17,8 +50,7 @@ impl<'a, T> Binding<'a, T> {
T: Clone,
{
Binding {
ident: self.ident.to_owned(),
type_: self.type_.clone(),
pat: self.pat.to_owned(),
body: self.body.to_owned(),
}
}
@ -30,6 +62,8 @@ pub enum Expr<'a, T> {
Literal(Literal<'a>, T),
Tuple(Vec<Expr<'a, T>>, T),
UnaryOp {
op: UnaryOperator,
rhs: Box<Expr<'a, T>>,
@ -76,6 +110,7 @@ impl<'a, T> Expr<'a, T> {
match self {
Expr::Ident(_, t) => t,
Expr::Literal(_, t) => t,
Expr::Tuple(_, t) => t,
Expr::UnaryOp { type_, .. } => type_,
Expr::BinaryOp { type_, .. } => type_,
Expr::Let { type_, .. } => type_,
@ -115,10 +150,9 @@ impl<'a, T> Expr<'a, T> {
} => Ok(Expr::Let {
bindings: bindings
.into_iter()
.map(|Binding { ident, type_, body }| {
.map(|Binding { pat, body }| {
Ok(Binding {
ident,
type_: f(type_)?,
pat: pat.traverse_type(f.clone())?,
body: body.traverse_type(f.clone())?,
})
})
@ -168,6 +202,13 @@ impl<'a, T> Expr<'a, T> {
.collect::<Result<Vec<_>, E>>()?,
type_: f(type_)?,
}),
Expr::Tuple(members, t) => Ok(Expr::Tuple(
members
.into_iter()
.map(|t| t.traverse_type(f.clone()))
.try_collect()?,
f(t)?,
)),
}
}
@ -242,6 +283,9 @@ impl<'a, T> Expr<'a, T> {
args: args.iter().map(|e| e.to_owned()).collect(),
type_: type_.clone(),
},
Expr::Tuple(members, t) => {
Expr::Tuple(members.into_iter().map(Expr::to_owned).collect(), t.clone())
}
}
}
}

View file

@ -127,9 +127,24 @@ impl<'a> Literal<'a> {
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern<'a> {
Id(Ident<'a>),
Tuple(Vec<Pattern<'a>>),
}
impl<'a> Pattern<'a> {
pub fn to_owned(&self) -> Pattern<'static> {
match self {
Pattern::Id(id) => Pattern::Id(id.to_owned()),
Pattern::Tuple(pats) => Pattern::Tuple(pats.iter().map(Pattern::to_owned).collect()),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<'a> {
pub ident: Ident<'a>,
pub pat: Pattern<'a>,
pub type_: Option<Type<'a>>,
pub body: Expr<'a>,
}
@ -137,7 +152,7 @@ pub struct Binding<'a> {
impl<'a> Binding<'a> {
fn to_owned(&self) -> Binding<'static> {
Binding {
ident: self.ident.to_owned(),
pat: self.pat.to_owned(),
type_: self.type_.as_ref().map(|t| t.to_owned()),
body: self.body.to_owned(),
}
@ -179,6 +194,8 @@ pub enum Expr<'a> {
args: Vec<Expr<'a>>,
},
Tuple(Vec<Expr<'a>>),
Ascription {
expr: Box<Expr<'a>>,
type_: Type<'a>,
@ -190,6 +207,9 @@ impl<'a> Expr<'a> {
match self {
Expr::Ident(ref id) => Expr::Ident(id.to_owned()),
Expr::Literal(ref lit) => Expr::Literal(lit.to_owned()),
Expr::Tuple(ref members) => {
Expr::Tuple(members.into_iter().map(Expr::to_owned).collect())
}
Expr::UnaryOp { op, rhs } => Expr::UnaryOp {
op: *op,
rhs: Box::new((**rhs).to_owned()),
@ -312,6 +332,7 @@ pub enum Type<'a> {
Bool,
CString,
Unit,
Tuple(Vec<Type<'a>>),
Var(Ident<'a>),
Function(FunctionType<'a>),
}
@ -326,6 +347,7 @@ impl<'a> Type<'a> {
Type::Unit => Type::Unit,
Type::Var(v) => Type::Var(v.to_owned()),
Type::Function(f) => Type::Function(f.to_owned()),
Type::Tuple(members) => Type::Tuple(members.iter().map(Type::to_owned).collect()),
}
}
@ -379,9 +401,23 @@ impl<'a> Type<'a> {
Type::Float => Type::Float,
Type::Bool => Type::Bool,
Type::CString => Type::CString,
Type::Tuple(members) => Type::Tuple(
members
.into_iter()
.map(|t| t.traverse_type_vars(f.clone()))
.collect(),
),
Type::Unit => Type::Unit,
}
}
pub fn as_tuple(&self) -> Option<&Vec<Type<'a>>> {
if let Self::Tuple(v) = self {
Some(v)
} else {
None
}
}
}
impl<'a> Display for Type<'a> {
@ -394,6 +430,7 @@ impl<'a> Display for Type<'a> {
Type::Unit => f.write_str("()"),
Type::Var(v) => v.fmt(f),
Type::Function(ft) => ft.fmt(f),
Type::Tuple(ms) => write!(f, "({})", ms.iter().join(", ")),
}
}
}

View file

@ -7,12 +7,13 @@ use inkwell::builder::Builder;
pub use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::support::LLVMString;
use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType, StructType};
use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue, StructValue};
use inkwell::{AddressSpace, IntPredicate};
use itertools::Itertools;
use thiserror::Error;
use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::hir::{Binding, Decl, Expr, Pattern};
use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
use crate::common::env::Env;
@ -82,6 +83,25 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
.append_basic_block(*self.function_stack.last().unwrap(), name)
}
fn bind_pattern(&mut self, pat: &'ast Pattern<'ast, Type>, val: AnyValueEnum<'ctx>) {
match pat {
Pattern::Id(id, _) => self.env.set(id, val),
Pattern::Tuple(pats) => {
for (i, pat) in pats.iter().enumerate() {
let member = self
.builder
.build_extract_value(
StructValue::try_from(val).unwrap(),
i as _,
"pat_bind",
)
.unwrap();
self.bind_pattern(pat, member.into());
}
}
}
}
pub fn codegen_expr(
&mut self,
expr: &'ast Expr<'ast, Type>,
@ -164,9 +184,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
}
Expr::Let { bindings, body, .. } => {
self.env.push();
for Binding { ident, body, .. } in bindings {
for Binding { pat, body, .. } in bindings {
if let Some(val) = self.codegen_expr(body)? {
self.env.set(ident, val);
self.bind_pattern(pat, val);
}
}
let res = self.codegen_expr(body);
@ -244,6 +264,19 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.env.restore(env);
Ok(Some(function.into()))
}
Expr::Tuple(members, ty) => {
let values = members
.into_iter()
.map(|expr| self.codegen_expr(expr))
.collect::<Result<Vec<_>>>()?
.into_iter()
.filter_map(|x| x)
.map(|x| x.try_into().unwrap())
.collect_vec();
let field_types = ty.as_tuple().unwrap();
let tuple_type = self.codegen_tuple_type(field_types);
Ok(Some(tuple_type.const_named_struct(&values).into()))
}
}
}
@ -341,9 +374,20 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
Type::Function(_) => todo!(),
Type::Var(_) => unreachable!(),
Type::Unit => None,
Type::Tuple(ts) => Some(self.codegen_tuple_type(ts).into()),
}
}
fn codegen_tuple_type(&self, ts: &'ast [Type]) -> StructType<'ctx> {
self.context.struct_type(
ts.iter()
.filter_map(|t| self.codegen_type(t))
.collect_vec()
.as_slice(),
false,
)
}
fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
// TODO
self.context.i64_type()
@ -433,4 +477,10 @@ mod tests {
let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap();
assert_eq!(res, 1);
}
#[test]
fn bind_tuple_pattern() {
let res = jit_eval::<i64>("let (x, y) = (1, 2) in x + y").unwrap();
assert_eq!(res, 3);
}
}

View file

@ -1,9 +1,12 @@
mod error;
mod value;
use itertools::Itertools;
use value::Val;
pub use self::error::{Error, Result};
pub use self::value::{Function, Value};
use crate::ast::hir::{Binding, Expr};
use crate::ast::hir::{Binding, Expr, Pattern};
use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
use crate::common::env::Env;
@ -24,6 +27,17 @@ impl<'a> Interpreter<'a> {
.ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
}
fn bind_pattern(&mut self, pattern: &'a Pattern<'a, Type>, value: Value<'a>) {
match pattern {
Pattern::Id(id, _) => self.env.set(id, value),
Pattern::Tuple(pats) => {
for (pat, val) in pats.iter().zip(value.as_tuple().unwrap().clone()) {
self.bind_pattern(pat, val);
}
}
}
}
pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
let res = match expr {
Expr::Ident(id, _) => self.resolve(id),
@ -53,9 +67,9 @@ impl<'a> Interpreter<'a> {
}
Expr::Let { bindings, body, .. } => {
self.env.push();
for Binding { ident, body, .. } in bindings {
for Binding { pat, body, .. } in bindings {
let val = self.eval(body)?;
self.env.set(ident, val);
self.bind_pattern(pat, val);
}
let res = self.eval(body)?;
self.env.pop();
@ -115,6 +129,13 @@ impl<'a> Interpreter<'a> {
body: (**body).to_owned(),
}))
}
Expr::Tuple(members, _) => Ok(Val::Tuple(
members
.into_iter()
.map(|expr| self.eval(expr))
.try_collect()?,
)
.into()),
}?;
debug_assert_eq!(&res.type_(), expr.type_());
Ok(res)

View file

@ -6,6 +6,7 @@ use std::rc::Rc;
use std::result;
use derive_more::{Deref, From, TryInto};
use itertools::Itertools;
use super::{Error, Result};
use crate::ast::hir::Expr;
@ -25,6 +26,7 @@ pub enum Val<'a> {
Float(f64),
Bool(bool),
String(Cow<'a, str>),
Tuple(Vec<Value<'a>>),
Function(Function<'a>),
}
@ -49,6 +51,7 @@ impl<'a> fmt::Debug for Val<'a> {
Val::Function(Function { type_, .. }) => {
f.debug_struct("Function").field("type_", type_).finish()
}
Val::Tuple(members) => f.debug_tuple("Tuple").field(members).finish(),
}
}
}
@ -79,6 +82,7 @@ impl<'a> Display for Val<'a> {
Val::Bool(x) => x.fmt(f),
Val::String(s) => write!(f, "{:?}", s),
Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_),
Val::Tuple(members) => write!(f, "({})", members.iter().join(", ")),
}
}
}
@ -91,6 +95,7 @@ impl<'a> Val<'a> {
Val::Bool(_) => Type::Bool,
Val::String(_) => Type::CString,
Val::Function(Function { type_, .. }) => Type::Function(type_.clone()),
Val::Tuple(members) => Type::Tuple(members.iter().map(|expr| expr.type_()).collect()),
}
}
@ -114,6 +119,22 @@ impl<'a> Val<'a> {
}),
}
}
pub fn as_tuple(&self) -> Option<&Vec<Value<'a>>> {
if let Self::Tuple(v) = self {
Some(v)
} else {
None
}
}
pub fn try_into_tuple(self) -> result::Result<Vec<Value<'a>>, Self> {
if let Self::Tuple(v) = self {
Ok(v)
} else {
Err(self)
}
}
}
#[derive(Debug, PartialEq, Clone, Deref)]

View file

@ -8,7 +8,8 @@ use nom::{
};
use pratt::{Affix, Associativity, PrattParser, Precedence};
use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, UnaryOperator};
use super::util::comma;
use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, Pattern, UnaryOperator};
use crate::parser::{arg, ident, type_};
#[derive(Debug)]
@ -192,9 +193,45 @@ named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit));
named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
named!(tuple(&str) -> Expr, do_parse!(
complete!(tag!("("))
>> multispace0
>> fst: expr
>> comma
>> rest: separated_list0!(
comma,
expr
)
>> multispace0
>> tag!(")")
>> ({
let mut members = Vec::with_capacity(rest.len() + 1);
members.push(fst);
members.append(&mut rest.clone());
Expr::Tuple(members)
})
));
named!(tuple_pattern(&str) -> Pattern, do_parse!(
complete!(tag!("("))
>> multispace0
>> pats: separated_list0!(
comma,
pattern
)
>> multispace0
>> tag!(")")
>> (Pattern::Tuple(pats))
));
named!(pattern(&str) -> Pattern, alt!(
ident => { |id| Pattern::Id(id) } |
tuple_pattern
));
named!(binding(&str) -> Binding, do_parse!(
multispace0
>> ident: ident
>> pat: pattern
>> multispace0
>> type_: opt!(preceded!(tuple!(tag!(":"), multispace0), type_))
>> multispace0
@ -202,7 +239,7 @@ named!(binding(&str) -> Binding, do_parse!(
>> multispace0
>> body: expr
>> (Binding {
ident,
pat,
type_,
body
})
@ -267,6 +304,7 @@ named!(paren_expr(&str) -> Expr,
named!(funcref(&str) -> Expr, alt!(
ident_expr |
tuple |
paren_expr
));
@ -296,6 +334,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
named!(fn_arg(&str) -> Expr, alt!(
ident_expr |
literal_expr |
tuple |
paren_expr
));
@ -314,7 +353,8 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!(
if_ |
fun_expr |
literal_expr |
ident_expr
ident_expr |
tuple
));
named!(simple_expr(&str) -> Expr, alt!(
@ -334,7 +374,7 @@ named!(pub expr(&str) -> Expr, alt!(
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::ast::{Arg, Ident, Type};
use crate::ast::{Arg, Ident, Pattern, Type};
use std::convert::TryFrom;
use BinaryOperator::*;
use Expr::{BinaryOp, If, Let, UnaryOp};
@ -449,6 +489,17 @@ pub(crate) mod tests {
);
}
#[test]
fn tuple() {
assert_eq!(
test_parse!(expr, "(1, \"seven\")"),
Expr::Tuple(vec![
Expr::Literal(Literal::Int(1)),
Expr::Literal(Literal::String(Cow::Borrowed("seven")))
])
)
}
#[test]
fn simple_string_lit() {
assert_eq!(
@ -465,12 +516,12 @@ pub(crate) mod tests {
Let {
bindings: vec![
Binding {
ident: Ident::try_from("x").unwrap(),
pat: Pattern::Id(Ident::try_from("x").unwrap()),
type_: None,
body: Expr::Literal(Literal::Int(1))
},
Binding {
ident: Ident::try_from("y").unwrap(),
pat: Pattern::Id(Ident::try_from("y").unwrap()),
type_: None,
body: Expr::BinaryOp {
lhs: ident_expr("x"),
@ -553,7 +604,7 @@ pub(crate) mod tests {
Expr::Call {
fun: Box::new(Expr::Let {
bindings: vec![Binding {
ident: Ident::try_from("x").unwrap(),
pat: Pattern::Id(Ident::try_from("x").unwrap()),
type_: None,
body: Expr::Literal(Literal::Int(1))
}],
@ -571,7 +622,7 @@ pub(crate) mod tests {
res,
Expr::Let {
bindings: vec![Binding {
ident: Ident::try_from("id").unwrap(),
pat: Pattern::Id(Ident::try_from("id").unwrap()),
type_: None,
body: Expr::Fun(Box::new(Fun {
args: vec![Arg::try_from("x").unwrap()],
@ -586,6 +637,28 @@ pub(crate) mod tests {
);
}
#[test]
fn tuple_binding() {
let res = test_parse!(expr, "let (x, y) = (1, 2) in x");
assert_eq!(
res,
Expr::Let {
bindings: vec![Binding {
pat: Pattern::Tuple(vec![
Pattern::Id(Ident::from_str_unchecked("x")),
Pattern::Id(Ident::from_str_unchecked("y"))
]),
body: Expr::Tuple(vec![
Expr::Literal(Literal::Int(1)),
Expr::Literal(Literal::Int(2))
]),
type_: None
}],
body: Box::new(Expr::Ident(Ident::from_str_unchecked("x")))
}
)
}
mod ascriptions {
use super::*;
@ -608,7 +681,7 @@ pub(crate) mod tests {
res,
Expr::Let {
bindings: vec![Binding {
ident: Ident::try_from("const_1").unwrap(),
pat: Pattern::Id(Ident::try_from("const_1").unwrap()),
type_: None,
body: Expr::Fun(Box::new(Fun {
args: vec![Arg::try_from("x").unwrap()],
@ -633,7 +706,7 @@ pub(crate) mod tests {
res,
Expr::Let {
bindings: vec![Binding {
ident: Ident::try_from("x").unwrap(),
pat: Pattern::Id(Ident::try_from("x").unwrap()),
type_: Some(Type::Int),
body: Expr::Literal(Literal::Int(1))
}],

View file

@ -6,6 +6,7 @@ use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag
pub(crate) mod macros;
mod expr;
mod type_;
mod util;
use crate::ast::{Arg, Decl, Fun, Ident};
pub use expr::expr;

View file

@ -2,17 +2,14 @@ use nom::character::complete::{multispace0, multispace1};
use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple};
use super::ident;
use super::util::comma;
use crate::ast::{FunctionType, Type};
named!(pub function_type(&str) -> FunctionType, do_parse!(
tag!("fn")
>> multispace1
>> args: map!(opt!(terminated!(separated_list0!(
tuple!(
multispace0,
tag!(","),
multispace0
),
comma,
type_
), multispace1)), |args| args.unwrap_or_default())
>> tag!("->")
@ -24,12 +21,32 @@ named!(pub function_type(&str) -> FunctionType, do_parse!(
})
));
named!(tuple_type(&str) -> Type, do_parse!(
tag!("(")
>> multispace0
>> fst: type_
>> comma
>> rest: separated_list0!(
comma,
type_
)
>> multispace0
>> tag!(")")
>> ({
let mut members = Vec::with_capacity(rest.len() + 1);
members.push(fst);
members.append(&mut rest.clone());
Type::Tuple(members)
})
));
named!(pub type_(&str) -> Type, alt!(
tag!("int") => { |_| Type::Int } |
tag!("float") => { |_| Type::Float } |
tag!("bool") => { |_| Type::Bool } |
tag!("cstring") => { |_| Type::CString } |
tag!("()") => { |_| Type::Unit } |
tuple_type |
function_type => { |ft| Type::Function(ft) }|
ident => { |id| Type::Var(id) } |
delimited!(
@ -111,6 +128,14 @@ mod tests {
)
}
#[test]
fn tuple() {
assert_eq!(
test_parse!(type_, "(int, int)"),
Type::Tuple(vec![Type::Int, Type::Int])
)
}
#[test]
fn type_vars() {
assert_eq!(

View file

@ -0,0 +1,8 @@
use nom::character::complete::multispace0;
use nom::{complete, map, named, tag, tuple};
named!(pub(crate) comma(&str) -> (), map!(tuple!(
multispace0,
complete!(tag!(",")),
multispace0
) ,|_| ()));

View file

@ -1,6 +1,6 @@
use std::collections::HashMap;
use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::hir::{Binding, Decl, Expr, Pattern};
use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
pub(crate) mod monomorphize;
@ -29,9 +29,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
Ok(())
}
fn visit_pattern(&mut self, _pat: &mut Pattern<'ast, T>) -> Result<(), Self::Error> {
Ok(())
}
fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> {
self.visit_ident(&mut binding.ident)?;
self.visit_type(&mut binding.type_)?;
self.visit_pattern(&mut binding.pat)?;
self.visit_expr(&mut binding.body)?;
Ok(())
}
@ -54,6 +57,13 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
Ok(())
}
fn visit_tuple(&mut self, members: &mut Vec<Expr<'ast, T>>) -> Result<(), Self::Error> {
for expr in members {
self.visit_expr(expr)?;
}
Ok(())
}
fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
Ok(())
}
@ -137,12 +147,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
self.visit_type(type_)?;
self.post_visit_call(fun, type_args, args)?;
}
Expr::Tuple(tup, type_) => {
self.visit_tuple(tup)?;
self.visit_type(type_)?;
}
}
Ok(())
}
fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> {
fn post_visit_decl(&mut self, _decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> {
Ok(())
}

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::mem;
use ast::hir::Binding;
use ast::hir::{Binding, Pattern};
use ast::Literal;
use void::{ResultVoidExt, Void};
@ -42,8 +42,10 @@ impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
bindings: extracted
.into_iter()
.map(|expr| Binding {
ident: Ident::from_str_unchecked("___discarded"),
type_: expr.type_().clone(),
pat: Pattern::Id(
Ident::from_str_unchecked("___discarded"),
expr.type_().clone(),
),
body: expr,
})
.collect(),

View file

@ -8,7 +8,7 @@ use std::fmt::{self, Display};
use std::{mem, result};
use thiserror::Error;
use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern};
use crate::common::env::Env;
use crate::common::{Namer, NamerOf};
@ -85,6 +85,7 @@ pub enum Type {
Exist(TyVar),
Nullary(NullaryType),
Prim(PrimType),
Tuple(Vec<Type>),
Unit,
Fun {
args: Vec<Type>,
@ -102,6 +103,9 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
Type::Exist(_) => Err(value),
Type::Nullary(_) => todo!(),
Type::Prim(p) => Ok(p.into()),
Type::Tuple(members) => Ok(ast::Type::Tuple(
members.into_iter().map(|ty| ty.try_into()).try_collect()?,
)),
Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
args: args
.clone()
@ -128,6 +132,7 @@ impl Display for Type {
Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")),
Type::Unit => write!(f, "()"),
}
}
@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> {
}
}
fn bind_pattern(
&mut self,
pat: Pattern<'ast>,
type_: Type,
) -> Result<hir::Pattern<'ast, Type>> {
match pat {
Pattern::Id(ident) => {
self.env.set(ident.clone(), type_.clone());
Ok(hir::Pattern::Id(ident, type_))
}
Pattern::Tuple(members) => {
let mut tys = Vec::with_capacity(members.len());
let mut hir_members = Vec::with_capacity(members.len());
for pat in members {
let ty = self.fresh_ex();
hir_members.push(self.bind_pattern(pat, ty.clone())?);
tys.push(ty);
}
let tuple_type = Type::Tuple(tys);
self.unify(&tuple_type, &type_)?;
Ok(hir::Pattern::Tuple(hir_members))
}
}
}
pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
match expr {
ast::Expr::Ident(ident) => {
@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> {
};
Ok(hir::Expr::Literal(lit.to_owned(), type_))
}
ast::Expr::Tuple(members) => {
let members = members
.into_iter()
.map(|expr| self.tc_expr(expr))
.collect::<Result<Vec<_>>>()?;
let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect());
Ok(hir::Expr::Tuple(members, type_))
}
ast::Expr::UnaryOp { op, rhs } => todo!(),
ast::Expr::BinaryOp { lhs, op, rhs } => {
let lhs = self.tc_expr(*lhs)?;
@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> {
let bindings = bindings
.into_iter()
.map(
|ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
|ast::Binding { pat, type_, body }| -> Result<hir::Binding<Type>> {
let body = self.tc_expr(body)?;
if let Some(type_) = type_ {
let type_ = self.type_from_ast_type(type_);
self.unify(body.type_(), &type_)?;
}
self.env.set(ident.clone(), body.type_().clone());
Ok(hir::Binding {
ident,
type_: body.type_().clone(),
body,
})
let pat = self.bind_pattern(pat, body.type_().clone())?;
Ok(hir::Binding { pat, body })
},
)
.collect::<Result<Vec<hir::Binding<Type>>>>()?;
@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> {
fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
match (ty1, ty2) {
(Type::Unit, Type::Unit) => Ok(Type::Unit),
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? {
Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
Some(var @ ast::Type::Var(_)) => {
let var = self.type_from_ast_type(var);
@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> {
}
}
(Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
(Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => {
let ts = t1
.iter()
.zip(t2.iter())
.map(|(t1, t2)| self.unify(t1, t2))
.try_collect()?;
Ok(Type::Tuple(ts))
}
(
Type::Fun {
args: args1,
@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> {
fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
let ret = match ty {
Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)),
Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
Type::Unit => Ok(ast::Type::Unit),
Type::Nullary(_) => todo!(),
Type::Prim(pr) => Ok(pr.into()),
Type::Tuple(members) => Ok(ast::Type::Tuple(
members
.into_iter()
.map(|ty| self.finalize_type(ty))
.try_collect()?,
)),
Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
args: args
.into_iter()
@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> {
ret
}
fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
fn resolve_tv(&self, tv: TyVar) -> Result<Option<ast::Type<'static>>> {
let mut res = &Type::Exist(tv);
loop {
Ok(loop {
match res {
Type::Exist(tv) => {
res = self.ctx.get(tv)?;
res = match self.ctx.get(tv) {
Some(r) => r,
None => return Ok(None),
};
}
Type::Univ(tv) => {
let ident = self.name_univ(*tv);
@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> {
Type::Prim(pr) => break Some((*pr).into()),
Type::Unit => break Some(ast::Type::Unit),
Type::Fun { args, ret } => todo!(),
Type::Tuple(_) => break Some(self.finalize_type(res.clone())?),
}
}
})
}
fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> {
ast::Type::Float => FLOAT,
ast::Type::Bool => BOOL,
ast::Type::CString => CSTRING,
ast::Type::Tuple(members) => Type::Tuple(
members
.into_iter()
.map(|ty| self.type_from_ast_type(ty))
.collect(),
),
ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
args: args
.into_iter()
@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> {
(Type::Unit, _) => false,
(Type::Nullary(_), _) => todo!(),
(Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
(Type::Tuple(members), ast::Type::Tuple(members2)) => members
.iter()
.zip(members2.iter())
.all(|(t1, t2)| self.types_match(t1, t2)),
(Type::Tuple(members), _) => false,
(Type::Fun { args, ret }, ast::Type::Function(ft)) => {
args.len() == ft.args.len()
&& args