Universally quantified type variables

Implement universally quantified type variables, both explicitly given
by the user and inferred by the type inference algorithm.
This commit is contained in:
Griffin Smith 2021-03-14 16:43:47 -04:00
parent 7960c3270e
commit ecb4c0f803
17 changed files with 634 additions and 111 deletions

14
Cargo.lock generated
View file

@ -5,7 +5,9 @@ name = "achilles"
version = "0.1.0"
dependencies = [
"anyhow",
"bimap",
"clap",
"crate-root",
"derive_more",
"inkwell",
"itertools",
@ -57,6 +59,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "bimap"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f92b72b8f03128773278bf74418b9205f3d2a12c39a61f92395f47af390c32bf"
[[package]]
name = "bit-set"
version = "0.5.2"
@ -140,6 +148,12 @@ dependencies = [
"syn",
]
[[package]]
name = "crate-root"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59c6fe4622b269032d2c5140a592d67a9c409031d286174fcde172fbed86f0d3"
[[package]]
name = "derive_more"
version = "0.99.11"

View file

@ -6,6 +6,7 @@ edition = "2018"
[dependencies]
anyhow = "1.0.38"
bimap = "0.6.0"
clap = "3.0.0-beta.2"
derive_more = "0.99.11"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
@ -18,3 +19,6 @@ pratt = "0.3.0"
proptest = "1.0.0"
test-strategy = "0.1.1"
thiserror = "1.0.24"
[dev-dependencies]
crate-root = "0.1.3"

View file

@ -1,3 +1,3 @@
fn id x = x
fn plus x y = x + y
fn plus (x: int) (y: int) = x + y
fn main = plus (id 2) 7

View file

@ -222,6 +222,12 @@ pub enum Decl<'a, T> {
}
impl<'a, T> Decl<'a, T> {
pub fn type_(&self) -> &T {
match self {
Decl::Fun { type_, .. } => type_,
}
}
pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
where
F: Fn(T) -> Result<U, E> + Clone,

View file

@ -1,6 +1,7 @@
pub(crate) mod hir;
use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt::{self, Display, Formatter};
@ -126,7 +127,7 @@ impl<'a> Literal<'a> {
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<'a> {
pub ident: Ident<'a>,
pub type_: Option<Type>,
pub type_: Option<Type<'a>>,
pub body: Expr<'a>,
}
@ -134,7 +135,7 @@ impl<'a> Binding<'a> {
fn to_owned(&self) -> Binding<'static> {
Binding {
ident: self.ident.to_owned(),
type_: self.type_.clone(),
type_: self.type_.as_ref().map(|t| t.to_owned()),
body: self.body.to_owned(),
}
}
@ -177,7 +178,7 @@ pub enum Expr<'a> {
Ascription {
expr: Box<Expr<'a>>,
type_: Type,
type_: Type<'a>,
},
}
@ -215,20 +216,46 @@ impl<'a> Expr<'a> {
},
Expr::Ascription { expr, type_ } => Expr::Ascription {
expr: Box::new((**expr).to_owned()),
type_: type_.clone(),
type_: type_.to_owned(),
},
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Arg<'a> {
pub ident: Ident<'a>,
pub type_: Option<Type<'a>>,
}
impl<'a> Arg<'a> {
pub fn to_owned(&self) -> Arg<'static> {
Arg {
ident: self.ident.to_owned(),
type_: self.type_.as_ref().map(Type::to_owned),
}
}
}
impl<'a> TryFrom<&'a str> for Arg<'a> {
type Error = <Ident<'a> as TryFrom<&'a str>>::Error;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Ok(Arg {
ident: Ident::try_from(value)?,
type_: None,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Fun<'a> {
pub args: Vec<Ident<'a>>,
pub args: Vec<Arg<'a>>,
pub body: Expr<'a>,
}
impl<'a> Fun<'a> {
fn to_owned(&self) -> Fun<'static> {
pub fn to_owned(&self) -> Fun<'static> {
Fun {
args: self.args.iter().map(|arg| arg.to_owned()).collect(),
body: self.body.to_owned(),
@ -236,40 +263,147 @@ impl<'a> Fun<'a> {
}
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Decl<'a> {
Fun { name: Ident<'a>, body: Fun<'a> },
}
////
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FunctionType {
pub args: Vec<Type>,
pub ret: Box<Type>,
pub struct FunctionType<'a> {
pub args: Vec<Type<'a>>,
pub ret: Box<Type<'a>>,
}
impl Display for FunctionType {
impl<'a> FunctionType<'a> {
pub fn to_owned(&self) -> FunctionType<'static> {
FunctionType {
args: self.args.iter().map(|a| a.to_owned()).collect(),
ret: Box::new((*self.ret).to_owned()),
}
}
}
impl<'a> Display for FunctionType<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret)
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Type {
pub enum Type<'a> {
Int,
Float,
Bool,
CString,
Function(FunctionType),
Var(Ident<'a>),
Function(FunctionType<'a>),
}
impl Display for Type {
impl<'a> Type<'a> {
pub fn to_owned(&self) -> Type<'static> {
match self {
Type::Int => Type::Int,
Type::Float => Type::Float,
Type::Bool => Type::Bool,
Type::CString => Type::CString,
Type::Var(v) => Type::Var(v.to_owned()),
Type::Function(f) => Type::Function(f.to_owned()),
}
}
pub fn alpha_equiv(&self, other: &Self) -> bool {
fn do_alpha_equiv<'a>(
substs: &mut HashMap<&'a Ident<'a>, &'a Ident<'a>>,
lhs: &'a Type,
rhs: &'a Type,
) -> bool {
match (lhs, rhs) {
(Type::Var(v1), Type::Var(v2)) => substs.entry(v1).or_insert(v2) == &v2,
(
Type::Function(FunctionType {
args: args1,
ret: ret1,
}),
Type::Function(FunctionType {
args: args2,
ret: ret2,
}),
) => {
args1.len() == args2.len()
&& args1
.iter()
.zip(args2)
.all(|(a1, a2)| do_alpha_equiv(substs, a1, a2))
&& do_alpha_equiv(substs, ret1, ret2)
}
_ => lhs == rhs,
}
}
let mut substs = HashMap::new();
do_alpha_equiv(&mut substs, self, other)
}
}
impl<'a> Display for Type<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Int => f.write_str("int"),
Type::Float => f.write_str("float"),
Type::Bool => f.write_str("bool"),
Type::CString => f.write_str("cstring"),
Type::Var(v) => v.fmt(f),
Type::Function(ft) => ft.fmt(f),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn type_var(n: &str) -> Type<'static> {
Type::Var(Ident::try_from(n.to_owned()).unwrap())
}
mod alpha_equiv {
use super::*;
#[test]
fn trivial() {
assert!(Type::Int.alpha_equiv(&Type::Int));
assert!(!Type::Int.alpha_equiv(&Type::Bool));
}
#[test]
fn simple_type_var() {
assert!(type_var("a").alpha_equiv(&type_var("b")));
}
#[test]
fn function_with_type_vars_equiv() {
assert!(Type::Function(FunctionType {
args: vec![type_var("a")],
ret: Box::new(type_var("b")),
})
.alpha_equiv(&Type::Function(FunctionType {
args: vec![type_var("b")],
ret: Box::new(type_var("a")),
})))
}
#[test]
fn function_with_type_vars_non_equiv() {
assert!(!Type::Function(FunctionType {
args: vec![type_var("a")],
ret: Box::new(type_var("a")),
})
.alpha_equiv(&Type::Function(FunctionType {
args: vec![type_var("b")],
ret: Box::new(type_var("a")),
})))
}
}
}

View file

@ -15,13 +15,13 @@ pub struct Check {
expr: Option<String>,
}
fn run_expr(expr: String) -> Result<Type> {
fn run_expr(expr: String) -> Result<Type<'static>> {
let (_, parsed) = parser::expr(&expr)?;
let hir_expr = tc::typecheck_expr(parsed)?;
Ok(hir_expr.type_().clone())
Ok(hir_expr.type_().to_owned())
}
fn run_path(path: PathBuf) -> Result<Type> {
fn run_path(path: PathBuf) -> Result<Type<'static>> {
todo!()
}

View file

@ -1,4 +1,6 @@
pub(crate) mod env;
pub(crate) mod error;
pub(crate) mod namer;
pub use error::{Error, Result};
pub use namer::{Namer, NamerOf};

122
src/common/namer.rs Normal file
View file

@ -0,0 +1,122 @@
use std::fmt::Display;
use std::marker::PhantomData;
pub struct Namer<T, F> {
make_name: F,
counter: u64,
_phantom: PhantomData<T>,
}
impl<T, F> Namer<T, F> {
pub fn new(make_name: F) -> Self {
Namer {
make_name,
counter: 0,
_phantom: PhantomData,
}
}
}
impl Namer<String, Box<dyn Fn(u64) -> String>> {
pub fn with_prefix<T>(prefix: T) -> Self
where
T: Display + 'static,
{
Namer::new(move |i| format!("{}{}", prefix, i)).boxed()
}
pub fn with_suffix<T>(suffix: T) -> Self
where
T: Display + 'static,
{
Namer::new(move |i| format!("{}{}", i, suffix)).boxed()
}
pub fn alphabetic() -> Self {
Namer::new(|i| {
if i <= 26 {
std::char::from_u32((i + 96) as u32).unwrap().to_string()
} else {
format!(
"{}{}",
std::char::from_u32(((i % 26) + 96) as u32).unwrap(),
i - 26
)
}
})
.boxed()
}
}
impl<T, F> Namer<T, F>
where
F: Fn(u64) -> T,
{
pub fn make_name(&mut self) -> T {
self.counter += 1;
(self.make_name)(self.counter)
}
pub fn boxed(self) -> NamerOf<T>
where
F: 'static,
{
Namer {
make_name: Box::new(self.make_name),
counter: self.counter,
_phantom: self._phantom,
}
}
pub fn map<G, U>(self, f: G) -> NamerOf<U>
where
G: Fn(T) -> U + 'static,
T: 'static,
F: 'static,
{
Namer {
counter: self.counter,
make_name: Box::new(move |x| f((self.make_name)(x))),
_phantom: PhantomData,
}
}
}
pub type NamerOf<T> = Namer<T, Box<dyn Fn(u64) -> T>>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn prefix() {
let mut namer = Namer::with_prefix("t");
assert_eq!(namer.make_name(), "t1");
assert_eq!(namer.make_name(), "t2");
}
#[test]
fn suffix() {
let mut namer = Namer::with_suffix("t");
assert_eq!(namer.make_name(), "1t");
assert_eq!(namer.make_name(), "2t");
}
#[test]
fn alphabetic() {
let mut namer = Namer::alphabetic();
assert_eq!(namer.make_name(), "a");
assert_eq!(namer.make_name(), "b");
(0..25).for_each(|_| {
namer.make_name();
});
assert_eq!(namer.make_name(), "b2");
}
#[test]
fn custom_callback() {
let mut namer = Namer::new(|n| n + 1);
assert_eq!(namer.make_name(), 2);
assert_eq!(namer.make_name(), 3);
}
}

View file

@ -10,7 +10,10 @@ pub enum Error {
UndefinedVariable(Ident<'static>),
#[error("Unexpected type {actual}, expected type {expected}")]
InvalidType { actual: Type, expected: Type },
InvalidType {
actual: Type<'static>,
expected: Type<'static>,
},
}
pub type Result<T> = result::Result<T, Error>;

View file

@ -115,7 +115,7 @@ impl<'a> Interpreter<'a> {
}
}
pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
let mut interpreter = Interpreter::new();
interpreter.eval(expr)
}
@ -128,7 +128,7 @@ mod tests {
use super::*;
use BinaryOperator::*;
fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
fn int_lit(i: u64) -> Box<Expr<'static, Type<'static>>> {
Box::new(Expr::Literal(Literal::Int(i), Type::Int))
}
@ -168,6 +168,7 @@ mod tests {
}
#[test]
#[ignore]
fn function_call() {
let res = do_eval::<i64>("let id = fn x = x in id 1");
assert_eq!(res, 1);

View file

@ -13,9 +13,9 @@ use crate::ast::{FunctionType, Ident, Type};
#[derive(Debug, Clone)]
pub struct Function<'a> {
pub type_: FunctionType,
pub type_: FunctionType<'a>,
pub args: Vec<Ident<'a>>,
pub body: Expr<'a, Type>,
pub body: Expr<'a, Type<'a>>,
}
#[derive(From, TryInto)]
@ -100,7 +100,7 @@ impl<'a> Val<'a> {
&'b T: TryFrom<&'b Self>,
{
<&T>::try_from(self).map_err(|_| Error::InvalidType {
actual: self.type_(),
actual: self.type_().to_owned(),
expected: <T as TypeOf>::type_of(),
})
}
@ -109,8 +109,8 @@ impl<'a> Val<'a> {
match self {
Val::Function(f) if f.type_ == function_type => Ok(&f),
_ => Err(Error::InvalidType {
actual: self.type_(),
expected: Type::Function(function_type),
actual: self.type_().to_owned(),
expected: Type::Function(function_type.to_owned()),
}),
}
}
@ -175,29 +175,29 @@ impl<'a> Div for Value<'a> {
}
pub trait TypeOf {
fn type_of() -> Type;
fn type_of() -> Type<'static>;
}
impl TypeOf for i64 {
fn type_of() -> Type {
fn type_of() -> Type<'static> {
Type::Int
}
}
impl TypeOf for bool {
fn type_of() -> Type {
fn type_of() -> Type<'static> {
Type::Bool
}
}
impl TypeOf for f64 {
fn type_of() -> Type {
fn type_of() -> Type<'static> {
Type::Float
}
}
impl TypeOf for String {
fn type_of() -> Type {
fn type_of() -> Type<'static> {
Type::CString
}
}

View file

@ -1,4 +1,5 @@
#![feature(str_split_once)]
#![feature(or_insert_with_key)]
use clap::Clap;

View file

@ -9,7 +9,7 @@ use nom::{
use pratt::{Affix, Associativity, PrattParser, Precedence};
use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, UnaryOperator};
use crate::parser::{ident, type_};
use crate::parser::{arg, ident, type_};
#[derive(Debug)]
enum TokenTree<'a> {
@ -274,7 +274,7 @@ named!(no_arg_call(&str) -> Expr, do_parse!(
named!(fun_expr(&str) -> Expr, do_parse!(
tag!("fn")
>> multispace1
>> args: separated_list0!(multispace1, ident)
>> args: separated_list0!(multispace1, arg)
>> multispace0
>> char!('=')
>> multispace0
@ -285,7 +285,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
})))
));
named!(arg(&str) -> Expr, alt!(
named!(fn_arg(&str) -> Expr, alt!(
ident_expr |
literal_expr |
paren_expr
@ -294,7 +294,7 @@ named!(arg(&str) -> Expr, alt!(
named!(call_with_args(&str) -> Expr, do_parse!(
fun: funcref
>> multispace1
>> args: separated_list1!(multispace1, arg)
>> args: separated_list1!(multispace1, fn_arg)
>> (Expr::Call {
fun: Box::new(fun),
args
@ -326,7 +326,7 @@ named!(pub expr(&str) -> Expr, alt!(
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::ast::{Ident, Type};
use crate::ast::{Arg, Ident, Type};
use std::convert::TryFrom;
use BinaryOperator::*;
use Expr::{BinaryOp, If, Let, UnaryOp};
@ -549,7 +549,7 @@ pub(crate) mod tests {
ident: Ident::try_from("id").unwrap(),
type_: None,
body: Expr::Fun(Box::new(Fun {
args: vec![Ident::try_from("x").unwrap()],
args: vec![Arg::try_from("x").unwrap()],
body: *ident_expr("x")
}))
}],
@ -586,7 +586,7 @@ pub(crate) mod tests {
ident: Ident::try_from("const_1").unwrap(),
type_: None,
body: Expr::Fun(Box::new(Fun {
args: vec![Ident::try_from("x").unwrap()],
args: vec![Arg::try_from("x").unwrap()],
body: Expr::Ascription {
expr: Box::new(Expr::Literal(Literal::Int(1))),
type_: Type::Int,

View file

@ -7,7 +7,7 @@ mod macros;
mod expr;
mod type_;
use crate::ast::{Decl, Fun, Ident};
use crate::ast::{Arg, Decl, Fun, Ident};
pub use expr::expr;
pub use type_::type_;
@ -58,12 +58,33 @@ where
}
}
named!(ascripted_arg(&str) -> Arg, do_parse!(
complete!(char!('(')) >>
multispace0 >>
ident: ident >>
multispace0 >>
complete!(char!(':')) >>
multispace0 >>
type_: type_ >>
multispace0 >>
complete!(char!(')')) >>
(Arg {
ident,
type_: Some(type_)
})
));
named!(arg(&str) -> Arg, alt!(
ident => { |ident| Arg {ident, type_: None}} |
ascripted_arg
));
named!(fun_decl(&str) -> Decl, do_parse!(
complete!(tag!("fn"))
>> multispace0
>> name: ident
>> multispace1
>> args: separated_list0!(multispace1, ident)
>> args: separated_list0!(multispace1, arg)
>> multispace0
>> char!('=')
>> multispace0
@ -87,6 +108,8 @@ named!(pub toplevel(&str) -> Vec<Decl>, terminated!(many0!(decl), multispace0));
mod tests {
use std::convert::TryInto;
use crate::ast::{BinaryOperator, Expr, Literal, Type};
use super::*;
use expr::tests::ident_expr;
@ -105,6 +128,29 @@ mod tests {
)
}
#[test]
fn ascripted_fn_args() {
test_parse!(ascripted_arg, "(x : int)");
let res = test_parse!(decl, "fn plus1 (x : int) = x + 1");
assert_eq!(
res,
Decl::Fun {
name: "plus1".try_into().unwrap(),
body: Fun {
args: vec![Arg {
ident: "x".try_into().unwrap(),
type_: Some(Type::Int),
}],
body: Expr::BinaryOp {
lhs: ident_expr("x"),
op: BinaryOperator::Add,
rhs: Box::new(Expr::Literal(Literal::Int(1))),
}
}
}
);
}
#[test]
fn multiple_decls() {
let res = test_parse!(

View file

@ -1,6 +1,7 @@
use nom::character::complete::{multispace0, multispace1};
use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple};
use super::ident;
use crate::ast::{FunctionType, Type};
named!(function_type(&str) -> Type, do_parse!(
@ -29,6 +30,7 @@ named!(pub type_(&str) -> Type, alt!(
tag!("bool") => { |_| Type::Bool } |
tag!("cstring") => { |_| Type::CString } |
function_type |
ident => { |id| Type::Var(id) } |
delimited!(
tuple!(tag!("("), multispace0),
type_,
@ -38,7 +40,10 @@ named!(pub type_(&str) -> Type, alt!(
#[cfg(test)]
mod tests {
use std::convert::TryFrom;
use super::*;
use crate::ast::Ident;
#[test]
fn simple_types() {
@ -103,4 +108,18 @@ mod tests {
})
)
}
#[test]
fn type_vars() {
assert_eq!(
test_parse!(type_, "fn x, y -> x"),
Type::Function(FunctionType {
args: vec![
Type::Var(Ident::try_from("x").unwrap()),
Type::Var(Ident::try_from("y").unwrap()),
],
ret: Box::new(Type::Var(Ident::try_from("x").unwrap())),
})
)
}
}

View file

@ -1,13 +1,16 @@
use bimap::BiMap;
use derive_more::From;
use itertools::Itertools;
use std::cell::RefCell;
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::fmt::{self, Display};
use std::result;
use std::{mem, result};
use thiserror::Error;
use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
use crate::common::env::Env;
use crate::common::{Namer, NamerOf};
#[derive(Debug, Error)]
pub enum Error {
@ -52,7 +55,7 @@ pub enum PrimType {
CString,
}
impl From<PrimType> for ast::Type {
impl<'a> From<PrimType> for ast::Type<'a> {
fn from(pr: PrimType) -> Self {
match pr {
PrimType::Int => ast::Type::Int,
@ -88,22 +91,7 @@ pub enum Type {
},
}
impl PartialEq<ast::Type> for Type {
fn eq(&self, other: &ast::Type) -> bool {
match (self, other) {
(Type::Univ(_), _) => todo!(),
(Type::Exist(_), _) => false,
(Type::Nullary(_), _) => todo!(),
(Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
(Type::Fun { args, ret }, ast::Type::Function(ft)) => {
*args == ft.args && (**ret).eq(&*ft.ret)
}
(Type::Fun { .. }, _) => false,
}
}
}
impl TryFrom<Type> for ast::Type {
impl<'a> TryFrom<Type> for ast::Type<'a> {
type Error = Type;
fn try_from(value: Type) -> result::Result<Self, Self::Error> {
@ -142,33 +130,29 @@ impl Display for Type {
}
}
impl From<ast::Type> for Type {
fn from(type_: ast::Type) -> Self {
match type_ {
ast::Type::Int => INT,
ast::Type::Float => FLOAT,
ast::Type::Bool => BOOL,
ast::Type::CString => CSTRING,
ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
args: args.into_iter().map(Self::from).collect(),
ret: Box::new(Self::from(*ret)),
},
}
}
}
struct Typechecker<'ast> {
ty_var_counter: u64,
ty_var_namer: NamerOf<TyVar>,
ctx: HashMap<TyVar, Type>,
env: Env<Ident<'ast>, Type>,
/// AST type var -> type
instantiations: Env<Ident<'ast>, Type>,
/// AST type-var -> universal TyVar
type_vars: RefCell<(BiMap<Ident<'ast>, TyVar>, NamerOf<Ident<'static>>)>,
}
impl<'ast> Typechecker<'ast> {
fn new() -> Self {
Self {
ty_var_counter: 0,
ty_var_namer: Namer::new(TyVar).boxed(),
type_vars: RefCell::new((
Default::default(),
Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()),
)),
ctx: Default::default(),
env: Default::default(),
instantiations: Default::default(),
}
}
@ -224,7 +208,8 @@ impl<'ast> Typechecker<'ast> {
|ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
let body = self.tc_expr(body)?;
if let Some(type_) = type_ {
self.unify(body.type_(), &type_.into())?;
let type_ = self.type_from_ast_type(type_);
self.unify(body.type_(), &type_)?;
}
self.env.set(ident.clone(), body.type_().clone());
Ok(hir::Binding {
@ -265,19 +250,22 @@ impl<'ast> Typechecker<'ast> {
self.env.push();
let args: Vec<_> = args
.into_iter()
.map(|id| {
let ty = self.fresh_ex();
self.env.set(id.clone(), ty.clone());
(id, ty)
.map(|Arg { ident, type_ }| {
let ty = match type_ {
Some(t) => self.type_from_ast_type(t),
None => self.fresh_ex(),
};
self.env.set(ident.clone(), ty.clone());
(ident, ty)
})
.collect();
let body = self.tc_expr(body)?;
self.env.pop();
Ok(hir::Expr::Fun {
type_: Type::Fun {
args: args.iter().map(|(_, ty)| ty.clone()).collect(),
ret: Box::new(body.type_().clone()),
},
type_: self.universalize(
args.iter().map(|(_, ty)| ty.clone()).collect(),
body.type_().clone(),
),
args,
body: Box::new(body),
})
@ -290,6 +278,7 @@ impl<'ast> Typechecker<'ast> {
ret: Box::new(ret_ty.clone()),
};
let fun = self.tc_expr(*fun)?;
self.instantiations.push();
self.unify(&ft, fun.type_())?;
let args = args
.into_iter()
@ -300,6 +289,7 @@ impl<'ast> Typechecker<'ast> {
Ok(arg)
})
.try_collect()?;
self.commit_instantiations();
Ok(hir::Expr::Call {
fun: Box::new(fun),
args,
@ -308,7 +298,8 @@ impl<'ast> Typechecker<'ast> {
}
ast::Expr::Ascription { expr, type_ } => {
let expr = self.tc_expr(*expr)?;
self.unify(expr.type_(), &type_.into())?;
let type_ = self.type_from_ast_type(type_);
self.unify(expr.type_(), &type_)?;
Ok(expr)
}
}
@ -334,8 +325,7 @@ impl<'ast> Typechecker<'ast> {
}
fn fresh_tv(&mut self) -> TyVar {
self.ty_var_counter += 1;
TyVar(self.ty_var_counter)
self.ty_var_namer.make_name()
}
fn fresh_ex(&mut self) -> Type {
@ -343,29 +333,69 @@ impl<'ast> Typechecker<'ast> {
}
fn fresh_univ(&mut self) -> Type {
Type::Exist(self.fresh_tv())
Type::Univ(self.fresh_tv())
}
fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
// TODO
expr
#[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903
fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type {
let mut vars = HashMap::new();
let mut universalize_type = move |ty| match ty {
Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars
.entry(tv)
.or_insert_with_key(|tv| {
let ty = self.fresh_univ();
self.ctx.insert(*tv, ty.clone());
ty
})
.clone(),
_ => ty,
};
Type::Fun {
args: args.into_iter().map(|t| universalize_type(t)).collect(),
ret: Box::new(universalize_type(ret)),
}
}
fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
match (ty1, ty2) {
(Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
Some(existing_ty) => Err(Error::TypeMismatch {
expected: ty.clone(),
actual: existing_ty.into(),
}),
Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
Some(var @ ast::Type::Var(_)) => {
let var = self.type_from_ast_type(var);
self.unify(&var, ty)
}
Some(existing_ty) => match ty {
Type::Exist(_) => {
let rhs = self.type_from_ast_type(existing_ty);
self.unify(ty, &rhs)
}
_ => Err(Error::TypeMismatch {
expected: ty.clone(),
actual: self.type_from_ast_type(existing_ty),
}),
},
None => match self.ctx.insert(*tv, ty.clone()) {
Some(existing) => self.unify(&existing, ty),
None => Ok(ty.clone()),
},
},
(Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
(Type::Univ(u), ty) | (ty, Type::Univ(u)) => {
let ident = self.name_univ(*u);
match self.instantiations.resolve(&ident) {
Some(existing_ty) if ty == existing_ty => Ok(ty.clone()),
Some(existing_ty) => Err(Error::TypeMismatch {
expected: ty.clone(),
actual: existing_ty.clone(),
}),
None => {
self.instantiations.set(ident, ty.clone());
Ok(ty.clone())
}
}
}
(Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
(
Type::Fun {
args: args1,
@ -395,18 +425,24 @@ impl<'ast> Typechecker<'ast> {
}
}
fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
fn finalize_expr(
&self,
expr: hir::Expr<'ast, Type>,
) -> Result<hir::Expr<'ast, ast::Type<'ast>>> {
expr.traverse_type(|ty| self.finalize_type(ty))
}
fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
fn finalize_decl(
&self,
decl: hir::Decl<'ast, Type>,
) -> Result<hir::Decl<'ast, ast::Type<'ast>>> {
decl.traverse_type(|ty| self.finalize_type(ty))
}
fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
match ty {
fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
let ret = match ty {
Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
Type::Univ(tv) => todo!(),
Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
Type::Nullary(_) => todo!(),
Type::Prim(pr) => Ok(pr.into()),
Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
@ -416,23 +452,105 @@ impl<'ast> Typechecker<'ast> {
.try_collect()?,
ret: Box::new(self.finalize_type(*ret)?),
})),
}
};
ret
}
fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
let mut res = &Type::Exist(tv);
loop {
match res {
Type::Exist(tv) => {
res = self.ctx.get(tv)?;
}
Type::Univ(_) => todo!(),
Type::Univ(tv) => {
let ident = self.name_univ(*tv);
if let Some(r) = self.instantiations.resolve(&ident) {
res = r;
} else {
break Some(ast::Type::Var(ident));
}
}
Type::Nullary(_) => todo!(),
Type::Prim(pr) => break Some((*pr).into()),
Type::Fun { args, ret } => todo!(),
}
}
}
fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
match ast_type {
ast::Type::Int => INT,
ast::Type::Float => FLOAT,
ast::Type::Bool => BOOL,
ast::Type::CString => CSTRING,
ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
args: args
.into_iter()
.map(|t| self.type_from_ast_type(t))
.collect(),
ret: Box::new(self.type_from_ast_type(*ret)),
},
ast::Type::Var(id) => Type::Univ({
let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() };
opt_tv.unwrap_or_else(|| {
let tv = self.fresh_tv();
self.type_vars
.borrow_mut()
.0
.insert_no_overwrite(id, tv)
.unwrap();
tv
})
}),
}
}
fn name_univ(&self, tv: TyVar) -> Ident<'static> {
let mut vars = self.type_vars.borrow_mut();
vars.0
.get_by_right(&tv)
.map(Ident::to_owned)
.unwrap_or_else(|| {
let name = vars.1.make_name();
vars.0.insert_no_overwrite(name.clone(), tv).unwrap();
name
})
}
fn commit_instantiations(&mut self) {
let mut ctx = mem::take(&mut self.ctx);
for (_, v) in ctx.iter_mut() {
if let Type::Univ(tv) = v {
if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) {
*v = concrete.clone();
}
}
}
self.ctx = ctx;
self.instantiations.pop();
}
fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool {
match (type_, ast_type) {
(Type::Univ(u), ast::Type::Var(v)) => {
Some(u) == self.type_vars.borrow().0.get_by_left(v)
}
(Type::Univ(_), _) => false,
(Type::Exist(_), _) => false,
(Type::Nullary(_), _) => todo!(),
(Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
(Type::Fun { args, ret }, ast::Type::Function(ft)) => {
args.len() == ft.args.len()
&& args
.iter()
.zip(&ft.args)
.all(|(a1, a2)| self.types_match(a1, &a2))
&& self.types_match(&*ret, &*ft.ret)
}
(Type::Fun { .. }, _) => false,
}
}
}
pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
@ -446,8 +564,10 @@ pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Ty
decls
.into_iter()
.map(|decl| {
let decl = typechecker.tc_decl(decl)?;
typechecker.finalize_decl(decl)
let hir_decl = typechecker.tc_decl(decl)?;
let res = typechecker.finalize_decl(hir_decl)?;
typechecker.ctx.clear();
Ok(res)
})
.try_collect()
}
@ -462,7 +582,13 @@ mod tests {
let parsed_expr = test_parse!(expr, $expr);
let parsed_type = test_parse!(type_, $type);
let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
assert_eq!(res.type_(), &parsed_type);
assert!(
res.type_().alpha_equiv(&parsed_type),
"{} inferred type {}, but expected {}",
$expr,
res.type_(),
$type
);
};
}
@ -501,9 +627,8 @@ mod tests {
}
#[test]
#[ignore]
fn generic_function() {
assert_type!("fn x = x", "fn x, y -> x");
assert_type!("fn x = x", "fn x -> x");
}
#[test]
@ -517,6 +642,11 @@ mod tests {
assert_type!("fn x = x + 1", "fn int -> int");
}
#[test]
fn arg_ascriptions() {
assert_type!("fn (x: int) = x", "fn int -> int");
}
#[test]
fn call_concrete_function() {
assert_type!("(fn x = x + 1) 2", "int");

41
tests/compile.rs Normal file
View file

@ -0,0 +1,41 @@
use std::process::Command;
use crate_root::root;
const FIXTURES: &[(&str, i32)] = &[("simple", 5), ("functions", 9)];
#[test]
fn compile_and_run_files() {
let ach = root().unwrap().join("ach");
for (fixture, exit_code) in FIXTURES {
println!(">>> Testing: {}", fixture);
println!(" Running: `make {}`", fixture);
assert!(
Command::new("make")
.arg(fixture)
.current_dir(&ach)
.spawn()
.unwrap()
.wait()
.unwrap()
.success(),
"make failed"
);
let out_path = ach.join(fixture);
println!(" Running: `{}`", out_path.to_str().unwrap());
assert_eq!(
Command::new(out_path)
.spawn()
.unwrap()
.wait()
.unwrap()
.code()
.unwrap(),
*exit_code,
);
println!(" OK");
}
}