Add the start of a Hindley-Milner typechecker

The beginning of a parse-don't-validate-based Hindley-Milner
typechecker, which on success returns an IR in which every AST node
trivially knows its own type. Codegen now uses those types to
determine LLVM types.
This commit is contained in:
Griffin Smith 2021-03-13 21:57:27 -05:00
parent f8beda81fb
commit 32a5c0ff0f
20 changed files with 980 additions and 78 deletions

1
Cargo.lock generated
View file

@ -9,6 +9,7 @@ dependencies = [
"derive_more", "derive_more",
"inkwell", "inkwell",
"itertools", "itertools",
"lazy_static",
"llvm-sys", "llvm-sys",
"nom", "nom",
"nom-trace", "nom-trace",

View file

@ -10,6 +10,7 @@ clap = "3.0.0-beta.2"
derive_more = "0.99.11" derive_more = "0.99.11"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
itertools = "0.10.0" itertools = "0.10.0"
lazy_static = "1.4.0"
llvm-sys = "110.0.1" llvm-sys = "110.0.1"
nom = "6.1.2" nom = "6.1.2"
nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" } nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" }

3
ach/.gitignore vendored
View file

@ -1,2 +1,5 @@
*.ll *.ll
*.o *.o
functions
simple

246
src/ast/hir.rs Normal file
View file

@ -0,0 +1,246 @@
use itertools::Itertools;
use super::{BinaryOperator, Ident, Literal, UnaryOperator};
/// A single `ident = body` binding whose nodes carry annotations of type `T`.
///
/// Bindings are stored in the `bindings` list of [`Expr::Let`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<'a, T> {
/// The name being bound.
pub ident: Ident<'a>,
/// The annotation attached to the binding itself.
pub type_: T,
/// The expression whose value is bound to `ident`.
pub body: Expr<'a, T>,
}
impl<'a, T> Binding<'a, T> {
    /// Builds a copy of this binding that no longer borrows from the source
    /// text: the identifier and body are converted to their `'static` forms
    /// and the annotation is cloned.
    fn to_owned(&self) -> Binding<'static, T>
    where
        T: Clone,
    {
        let Binding { ident, type_, body } = self;
        Binding {
            ident: ident.to_owned(),
            type_: type_.clone(),
            body: body.to_owned(),
        }
    }
}
/// An expression tree in which every node carries an annotation of type `T`
/// (retrievable through [`Expr::type_`]).
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Expr<'a, T> {
/// A reference to an identifier.
Ident(Ident<'a>, T),
/// A literal value.
Literal(Literal, T),
/// A unary operator applied to `rhs`.
UnaryOp {
op: UnaryOperator,
rhs: Box<Expr<'a, T>>,
type_: T,
},
/// A binary operator applied to `lhs` and `rhs`.
BinaryOp {
lhs: Box<Expr<'a, T>>,
op: BinaryOperator,
rhs: Box<Expr<'a, T>>,
type_: T,
},
/// A `let` expression: a list of bindings followed by a body.
Let {
bindings: Vec<Binding<'a, T>>,
body: Box<Expr<'a, T>>,
type_: T,
},
/// A conditional with both a `then` and an `else_` branch.
If {
condition: Box<Expr<'a, T>>,
then: Box<Expr<'a, T>>,
else_: Box<Expr<'a, T>>,
type_: T,
},
/// A function expression; each argument carries its own annotation in
/// addition to the annotation of the function as a whole.
Fun {
args: Vec<(Ident<'a>, T)>,
body: Box<Expr<'a, T>>,
type_: T,
},
/// Application of `fun` to `args`.
Call {
fun: Box<Expr<'a, T>>,
args: Vec<Expr<'a, T>>,
type_: T,
},
}
impl<'a, T> Expr<'a, T> {
    /// Returns the annotation attached to this node.
    pub fn type_(&self) -> &T {
        match self {
            Expr::Ident(_, type_)
            | Expr::Literal(_, type_)
            | Expr::UnaryOp { type_, .. }
            | Expr::BinaryOp { type_, .. }
            | Expr::Let { type_, .. }
            | Expr::If { type_, .. }
            | Expr::Fun { type_, .. }
            | Expr::Call { type_, .. } => type_,
        }
    }

    /// Rebuilds the tree with every annotation passed through `f`, turning an
    /// `Expr<T>` into an `Expr<U>`.
    ///
    /// Sub-expressions are converted in field order and a node's own
    /// annotation is converted last.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`; nothing after that point is
    /// converted.
    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Expr<'a, U>, E>
    where
        F: Fn(T) -> Result<U, E> + Clone,
    {
        match self {
            Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)),
            Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)),
            Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp {
                op,
                rhs: Box::new(rhs.traverse_type(f.clone())?),
                type_: f(type_)?,
            }),
            Expr::BinaryOp {
                lhs,
                op,
                rhs,
                type_,
            } => Ok(Expr::BinaryOp {
                lhs: Box::new(lhs.traverse_type(f.clone())?),
                op,
                rhs: Box::new(rhs.traverse_type(f.clone())?),
                type_: f(type_)?,
            }),
            Expr::Let {
                bindings,
                body,
                type_,
            } => Ok(Expr::Let {
                bindings: bindings
                    .into_iter()
                    .map(|Binding { ident, type_, body }| {
                        Ok(Binding {
                            ident,
                            type_: f(type_)?,
                            body: body.traverse_type(f.clone())?,
                        })
                    })
                    .collect::<Result<Vec<_>, E>>()?,
                body: Box::new(body.traverse_type(f.clone())?),
                type_: f(type_)?,
            }),
            Expr::If {
                condition,
                then,
                else_,
                type_,
            } => Ok(Expr::If {
                condition: Box::new(condition.traverse_type(f.clone())?),
                then: Box::new(then.traverse_type(f.clone())?),
                else_: Box::new(else_.traverse_type(f.clone())?),
                type_: f(type_)?,
            }),
            Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
                // `f` is only borrowed here, so it does not need cloning per
                // argument.
                args: args
                    .into_iter()
                    .map(|(id, t)| f(t).map(|t| (id, t)))
                    .collect::<Result<Vec<_>, E>>()?,
                body: Box::new(body.traverse_type(f.clone())?),
                type_: f(type_)?,
            }),
            Expr::Call { fun, args, type_ } => Ok(Expr::Call {
                fun: Box::new(fun.traverse_type(f.clone())?),
                args: args
                    .into_iter()
                    .map(|e| e.traverse_type(f.clone()))
                    .collect::<Result<Vec<_>, E>>()?,
                type_: f(type_)?,
            }),
        }
    }

    /// Builds a deep copy of this expression that no longer borrows from the
    /// source text (`Expr<'static, T>`); annotations are cloned.
    ///
    /// Note that this is an inherent method, not `ToOwned::to_owned`. Boxed
    /// children are dereferenced explicitly (`(**child).to_owned()`) so that
    /// the call resolves to this method rather than cloning the `Box`.
    pub fn to_owned(&self) -> Expr<'static, T>
    where
        T: Clone,
    {
        match self {
            Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()),
            Expr::Literal(lit, t) => Expr::Literal(lit.clone(), t.clone()),
            Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp {
                op: *op,
                rhs: Box::new((**rhs).to_owned()),
                type_: type_.clone(),
            },
            Expr::BinaryOp {
                lhs,
                op,
                rhs,
                type_,
            } => Expr::BinaryOp {
                lhs: Box::new((**lhs).to_owned()),
                op: *op,
                rhs: Box::new((**rhs).to_owned()),
                type_: type_.clone(),
            },
            Expr::Let {
                bindings,
                body,
                type_,
            } => Expr::Let {
                bindings: bindings.iter().map(|b| b.to_owned()).collect(),
                body: Box::new((**body).to_owned()),
                type_: type_.clone(),
            },
            Expr::If {
                condition,
                then,
                else_,
                type_,
            } => Expr::If {
                condition: Box::new((**condition).to_owned()),
                then: Box::new((**then).to_owned()),
                else_: Box::new((**else_).to_owned()),
                type_: type_.clone(),
            },
            Expr::Fun { args, body, type_ } => Expr::Fun {
                args: args
                    .iter()
                    .map(|(id, t)| (id.to_owned(), t.clone()))
                    .collect(),
                body: Box::new((**body).to_owned()),
                type_: type_.clone(),
            },
            Expr::Call { fun, args, type_ } => Expr::Call {
                fun: Box::new((**fun).to_owned()),
                args: args.iter().map(|e| e.to_owned()).collect(),
                type_: type_.clone(),
            },
        }
    }
}
/// A declaration whose nodes carry annotations of type `T`.
///
/// Derives the same traits as [`Expr`] and [`Binding`] so that declarations
/// can be printed, compared and cloned like the expressions they contain.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Decl<'a, T> {
    /// A named function declaration.
    Fun {
        /// The function's name.
        name: Ident<'a>,
        /// The arguments, each paired with its own annotation.
        args: Vec<(Ident<'a>, T)>,
        /// The function body.
        body: Box<Expr<'a, T>>,
        /// The annotation of the declaration as a whole.
        type_: T,
    },
}
impl<'a, T> Decl<'a, T> {
    /// Rebuilds the declaration with every annotation passed through `f`,
    /// turning a `Decl<T>` into a `Decl<U>`.
    ///
    /// Argument annotations are converted first, then the body, then the
    /// declaration's own annotation; the first error from `f` is returned.
    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
    where
        F: Fn(T) -> Result<U, E> + Clone,
    {
        let Decl::Fun {
            name,
            args,
            body,
            type_,
        } = self;

        let args = args
            .into_iter()
            .map(|(id, t)| f(t).map(|converted| (id, converted)))
            .try_collect()?;
        let body = Box::new(body.traverse_type(f.clone())?);
        let type_ = f(type_)?;

        Ok(Decl::Fun {
            name,
            args,
            body,
            type_,
        })
    }
}

View file

@ -1,3 +1,5 @@
pub(crate) mod hir;
use std::borrow::Cow; use std::borrow::Cow;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
@ -107,6 +109,7 @@ pub enum UnaryOperator {
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
pub enum Literal { pub enum Literal {
Int(u64), Int(u64),
Bool(bool),
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]

View file

@ -7,12 +7,13 @@ use inkwell::builder::Builder;
pub use inkwell::context::Context; pub use inkwell::context::Context;
use inkwell::module::Module; use inkwell::module::Module;
use inkwell::support::LLVMString; use inkwell::support::LLVMString;
use inkwell::types::FunctionType; use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
use inkwell::IntPredicate; use inkwell::IntPredicate;
use thiserror::Error; use thiserror::Error;
use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator}; use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
use crate::common::env::Env; use crate::common::env::Env;
#[derive(Debug, PartialEq, Eq, Error)] #[derive(Debug, PartialEq, Eq, Error)]
@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> {
context: &'ctx Context, context: &'ctx Context,
pub module: Module<'ctx>, pub module: Module<'ctx>,
builder: Builder<'ctx>, builder: Builder<'ctx>,
env: Env<'ast, AnyValueEnum<'ctx>>, env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
function_stack: Vec<FunctionValue<'ctx>>, function_stack: Vec<FunctionValue<'ctx>>,
identifier_counter: u32, identifier_counter: u32,
} }
@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
.append_basic_block(*self.function_stack.last().unwrap(), name) .append_basic_block(*self.function_stack.last().unwrap(), name)
} }
pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> { pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
match expr { match expr {
Expr::Ident(id) => self Expr::Ident(id, _) => self
.env .env
.resolve(id) .resolve(id)
.cloned() .cloned()
.ok_or_else(|| Error::UndefinedVariable(id.to_owned())), .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
Expr::Literal(Literal::Int(i)) => { Expr::Literal(lit, ty) => {
let ty = self.context.i64_type(); let ty = self.codegen_int_type(ty);
Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))) match lit {
Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
ty.const_int(if *b { 1 } else { 0 }, false),
)),
} }
Expr::UnaryOp { op, rhs } => { }
Expr::UnaryOp { op, rhs, .. } => {
let rhs = self.codegen_expr(rhs)?; let rhs = self.codegen_expr(rhs)?;
match op { match op {
UnaryOperator::Not => unimplemented!(), UnaryOperator::Not => unimplemented!(),
@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
)), )),
} }
} }
Expr::BinaryOp { lhs, op, rhs } => { Expr::BinaryOp { lhs, op, rhs, .. } => {
let lhs = self.codegen_expr(lhs)?; let lhs = self.codegen_expr(lhs)?;
let rhs = self.codegen_expr(rhs)?; let rhs = self.codegen_expr(rhs)?;
match op { match op {
@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
BinaryOperator::Neq => todo!(), BinaryOperator::Neq => todo!(),
} }
} }
Expr::Let { bindings, body } => { Expr::Let { bindings, body, .. } => {
self.env.push(); self.env.push();
for Binding { ident, body, .. } in bindings { for Binding { ident, body, .. } in bindings {
let val = self.codegen_expr(body)?; let val = self.codegen_expr(body)?;
@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
condition, condition,
then, then,
else_, else_,
type_,
} => { } => {
let then_block = self.append_basic_block("then"); let then_block = self.append_basic_block("then");
let else_block = self.append_basic_block("else"); let else_block = self.append_basic_block("else");
@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.builder.build_unconditional_branch(join_block); self.builder.build_unconditional_branch(join_block);
self.builder.position_at_end(join_block); self.builder.position_at_end(join_block);
let phi = self.builder.build_phi(self.context.i64_type(), "join"); let phi = self.builder.build_phi(self.codegen_type(type_), "join");
phi.add_incoming(&[ phi.add_incoming(&[
(&BasicValueEnum::try_from(then_res).unwrap(), then_block), (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
(&BasicValueEnum::try_from(else_res).unwrap(), else_block), (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
]); ]);
Ok(phi.as_basic_value().into()) Ok(phi.as_basic_value().into())
} }
Expr::Call { fun, args } => { Expr::Call { fun, args, .. } => {
if let Expr::Ident(id) = &**fun { if let Expr::Ident(id, _) = &**fun {
let function = self let function = self
.module .module
.get_function(id.into()) .get_function(id.into())
@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
todo!() todo!()
} }
} }
Expr::Fun(fun) => { Expr::Fun { args, body, .. } => {
let Fun { args, body } = &**fun;
let fname = self.fresh_ident("f"); let fname = self.fresh_ident("f");
let cur_block = self.builder.get_insert_block().unwrap(); let cur_block = self.builder.get_insert_block().unwrap();
let env = self.env.save(); // TODO: closures let env = self.env.save(); // TODO: closures
@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.env.restore(env); self.env.restore(env);
Ok(function.into()) Ok(function.into())
} }
Expr::Ascription { expr, .. } => self.codegen_expr(expr),
} }
} }
pub fn codegen_function( pub fn codegen_function(
&mut self, &mut self,
name: &str, name: &str,
args: &'ast [Ident<'ast>], args: &'ast [(Ident<'ast>, Type)],
body: &'ast Expr<'ast>, body: &'ast Expr<'ast, Type>,
) -> Result<FunctionValue<'ctx>> { ) -> Result<FunctionValue<'ctx>> {
let i64_type = self.context.i64_type();
self.new_function( self.new_function(
name, name,
i64_type.fn_type( self.codegen_type(body.type_()).fn_type(
args.iter() args.iter()
.map(|_| i64_type.into()) .map(|(_, at)| self.codegen_type(at))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.as_slice(), .as_slice(),
false, false,
), ),
); );
self.env.push(); self.env.push();
for (i, arg) in args.iter().enumerate() { for (i, (arg, _)) in args.iter().enumerate() {
self.env.set( self.env.set(
arg, arg,
self.cur_function().get_nth_param(i as u32).unwrap().into(), self.cur_function().get_nth_param(i as u32).unwrap().into(),
@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
Ok(self.finish_function(&res)) Ok(self.finish_function(&res))
} }
pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
match decl { match decl {
Decl::Fun { Decl::Fun {
name, name, args, body, ..
body: Fun { args, body },
} => { } => {
self.codegen_function(name.into(), args, body)?; self.codegen_function(name.into(), args, body)?;
Ok(()) Ok(())
@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
} }
} }
pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> { pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
self.new_function("main", self.context.i64_type().fn_type(&[], false)); self.new_function("main", self.context.i64_type().fn_type(&[], false));
let res = self.codegen_expr(expr)?.try_into().unwrap(); let res = self.codegen_expr(expr)?.try_into().unwrap();
if *expr.type_() != Type::Int {
self.builder
.build_return(Some(&self.context.i64_type().const_int(0, false)));
} else {
self.finish_function(&res); self.finish_function(&res);
}
Ok(()) Ok(())
} }
fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
// TODO
self.context.i64_type().into()
}
fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
// TODO
self.context.i64_type()
}
pub fn print_to_file<P>(&self, path: P) -> Result<()> pub fn print_to_file<P>(&self, path: P) -> Result<()>
where where
P: AsRef<Path>, P: AsRef<Path>,
@ -299,6 +317,8 @@ mod tests {
fn jit_eval<T>(expr: &str) -> anyhow::Result<T> { fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
let expr = crate::parser::expr(expr).unwrap().1; let expr = crate::parser::expr(expr).unwrap().1;
let expr = crate::tc::typecheck_expr(expr).unwrap();
let context = Context::create(); let context = Context::create();
let mut codegen = Codegen::new(&context, "test"); let mut codegen = Codegen::new(&context, "test");
let execution_engine = codegen let execution_engine = codegen

View file

@ -4,10 +4,11 @@ use inkwell::execution_engine::JitFunction;
use inkwell::OptimizationLevel; use inkwell::OptimizationLevel;
pub use llvm::*; pub use llvm::*;
use crate::ast::Expr; use crate::ast::hir::Expr;
use crate::ast::Type;
use crate::common::Result; use crate::common::Result;
pub fn jit_eval<T>(expr: &Expr) -> Result<T> { pub fn jit_eval<T>(expr: &Expr<Type>) -> Result<T> {
let context = Context::create(); let context = Context::create();
let mut codegen = Codegen::new(&context, "eval"); let mut codegen = Codegen::new(&context, "eval");
let execution_engine = codegen let execution_engine = codegen

39
src/commands/check.rs Normal file
View file

@ -0,0 +1,39 @@
use clap::Clap;
use std::path::PathBuf;
use crate::ast::Type;
use crate::{parser, tc, Result};
// NOTE(review): with `#[derive(Clap)]` the `///` comments on this struct and
// its fields are presumably surfaced as command-line help text, so rewording
// them changes user-visible output — confirm before editing.
//
// Exactly one of `path` and `expr` is expected; `Check::run` reports an error
// when neither or both are given.
/// Typecheck a file or expression
#[derive(Clap)]
pub struct Check {
/// File to check
path: Option<PathBuf>,
/// Expression to check
#[clap(long, short = 'e')]
expr: Option<String>,
}
/// Parses `expr`, typechecks it, and returns the type of the resulting
/// top-level expression. Parse and type errors are propagated.
fn run_expr(expr: String) -> Result<Type> {
    let (_remaining, ast) = parser::expr(&expr)?;
    let typed = tc::typecheck_expr(ast)?;
    let inferred = typed.type_();
    Ok(inferred.clone())
}
/// Typechecks the file at the given path.
///
/// # Errors
///
/// File checking is not implemented yet, so this always returns an error.
/// It reports an error rather than panicking because the path comes straight
/// from the command line.
fn run_path(_path: PathBuf) -> Result<Type> {
    // TODO: parse and typecheck the file once there is a result to report.
    Err("Checking a file is not yet supported; use --expr instead".into())
}
impl Check {
pub fn run(self) -> Result<()> {
let type_ = match (self.path, self.expr) {
(None, None) => Err("Must specify either a file or expression to check".into()),
(Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()),
(None, Some(expr)) => run_expr(expr),
(Some(path), None) => run_path(path),
}?;
println!("type: {}", type_);
Ok(())
}
}

View file

@ -3,6 +3,7 @@ use clap::Clap;
use crate::codegen; use crate::codegen;
use crate::interpreter; use crate::interpreter;
use crate::parser; use crate::parser;
use crate::tc;
use crate::Result; use crate::Result;
/// Evaluate an expression and print its result /// Evaluate an expression and print its result
@ -19,10 +20,11 @@ pub struct Eval {
impl Eval { impl Eval {
pub fn run(self) -> Result<()> { pub fn run(self) -> Result<()> {
let (_, parsed) = parser::expr(&self.expr)?; let (_, parsed) = parser::expr(&self.expr)?;
let hir = tc::typecheck_expr(parsed)?;
let result = if self.jit { let result = if self.jit {
codegen::jit_eval::<i64>(&parsed)?.into() codegen::jit_eval::<i64>(&hir)?.into()
} else { } else {
interpreter::eval(&parsed)? interpreter::eval(&hir)?
}; };
println!("{}", result); println!("{}", result);
Ok(()) Ok(())

View file

@ -1,5 +1,7 @@
pub mod check;
pub mod compile; pub mod compile;
pub mod eval; pub mod eval;
pub use check::Check;
pub use compile::Compile; pub use compile::Compile;
pub use eval::Eval; pub use eval::Eval;

View file

@ -1,19 +1,25 @@
use std::borrow::Borrow;
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::Hash;
use std::mem; use std::mem;
use crate::ast::Ident;
/// A lexical environment /// A lexical environment
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct Env<'ast, V>(Vec<HashMap<&'ast Ident<'ast>, V>>); pub struct Env<K: Eq + Hash, V>(Vec<HashMap<K, V>>);
impl<'ast, V> Default for Env<'ast, V> { impl<K, V> Default for Env<K, V>
where
K: Eq + Hash,
{
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
} }
impl<'ast, V> Env<'ast, V> { impl<K, V> Env<K, V>
where
K: Eq + Hash,
{
pub fn new() -> Self { pub fn new() -> Self {
Self(vec![Default::default()]) Self(vec![Default::default()])
} }
@ -34,11 +40,15 @@ impl<'ast, V> Env<'ast, V> {
*self = saved; *self = saved;
} }
pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) { pub fn set(&mut self, k: K, v: V) {
self.0.last_mut().unwrap().insert(k, v); self.0.last_mut().unwrap().insert(k, v);
} }
pub fn resolve<'a>(&'a self, k: &'ast Ident<'ast>) -> Option<&'a V> { pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
for ctx in self.0.iter().rev() { for ctx in self.0.iter().rev() {
if let Some(res) = ctx.get(k) { if let Some(res) = ctx.get(k) {
return Some(res); return Some(res);

View file

@ -2,7 +2,7 @@ use std::{io, result};
use thiserror::Error; use thiserror::Error;
use crate::{codegen, interpreter, parser}; use crate::{codegen, interpreter, parser, tc};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
@ -18,6 +18,9 @@ pub enum Error {
#[error("Compile error: {0}")] #[error("Compile error: {0}")]
CodegenError(#[from] codegen::Error), CodegenError(#[from] codegen::Error),
#[error("Type error: {0}")]
TypeError(#[from] tc::Error),
#[error("{0}")] #[error("{0}")]
Message(String), Message(String),
} }
@ -28,6 +31,12 @@ impl From<String> for Error {
} }
} }
impl<'a> From<&'a str> for Error {
fn from(s: &'a str) -> Self {
Self::Message(s.to_owned())
}
}
impl<'a> From<nom::Err<nom::error::Error<&'a str>>> for Error { impl<'a> From<nom::Err<nom::error::Error<&'a str>>> for Error {
fn from(e: nom::Err<nom::error::Error<&'a str>>) -> Self { fn from(e: nom::Err<nom::error::Error<&'a str>>) -> Self {
use nom::error::Error as NomError; use nom::error::Error as NomError;

View file

@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
use crate::codegen::{self, Codegen}; use crate::codegen::{self, Codegen};
use crate::common::Result; use crate::common::Result;
use crate::parser; use crate::{parser, tc};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
pub enum OutputFormat { pub enum OutputFormat {
@ -55,6 +55,8 @@ pub struct CompilerOptions {
pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
let src = fs::read_to_string(input)?; let src = fs::read_to_string(input)?;
let (_, decls) = parser::toplevel(&src)?; // TODO: statements let (_, decls) = parser::toplevel(&src)?; // TODO: statements
let decls = tc::typecheck_toplevel(decls)?;
let context = codegen::Context::create(); let context = codegen::Context::create();
let mut codegen = Codegen::new( let mut codegen = Codegen::new(
&context, &context,

View file

@ -3,14 +3,13 @@ mod value;
pub use self::error::{Error, Result}; pub use self::error::{Error, Result};
pub use self::value::{Function, Value}; pub use self::value::{Function, Value};
use crate::ast::{ use crate::ast::hir::{Binding, Expr};
BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator, use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
};
use crate::common::env::Env; use crate::common::env::Env;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Interpreter<'a> { pub struct Interpreter<'a> {
env: Env<'a, Value<'a>>, env: Env<&'a Ident<'a>, Value<'a>>,
} }
impl<'a> Interpreter<'a> { impl<'a> Interpreter<'a> {
@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
.ok_or_else(|| Error::UndefinedVariable(var.to_owned())) .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
} }
pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> { pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
match expr { let res = match expr {
Expr::Ident(id) => self.resolve(id), Expr::Ident(id, _) => self.resolve(id),
Expr::Literal(Literal::Int(i)) => Ok((*i).into()), Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
Expr::UnaryOp { op, rhs } => { Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
Expr::UnaryOp { op, rhs, .. } => {
let rhs = self.eval(rhs)?; let rhs = self.eval(rhs)?;
match op { match op {
UnaryOperator::Neg => -rhs, UnaryOperator::Neg => -rhs,
_ => unimplemented!(), _ => unimplemented!(),
} }
} }
Expr::BinaryOp { lhs, op, rhs } => { Expr::BinaryOp { lhs, op, rhs, .. } => {
let lhs = self.eval(lhs)?; let lhs = self.eval(lhs)?;
let rhs = self.eval(rhs)?; let rhs = self.eval(rhs)?;
match op { match op {
@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
BinaryOperator::Neq => todo!(), BinaryOperator::Neq => todo!(),
} }
} }
Expr::Let { bindings, body } => { Expr::Let { bindings, body, .. } => {
self.env.push(); self.env.push();
for Binding { ident, body, .. } in bindings { for Binding { ident, body, .. } in bindings {
let val = self.eval(body)?; let val = self.eval(body)?;
@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
condition, condition,
then, then,
else_, else_,
..
} => { } => {
let condition = self.eval(condition)?; let condition = self.eval(condition)?;
if *(condition.as_type::<bool>()?) { if *(condition.as_type::<bool>()?) {
@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
self.eval(else_) self.eval(else_)
} }
} }
Expr::Call { ref fun, args } => { Expr::Call { ref fun, args, .. } => {
let fun = self.eval(fun)?; let fun = self.eval(fun)?;
let expected_type = FunctionType { let expected_type = FunctionType {
args: args.iter().map(|_| Type::Int).collect(), args: args.iter().map(|_| Type::Int).collect(),
@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
} }
Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?)) Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
} }
Expr::Fun(fun) => Ok(Value::from(value::Function { Expr::Fun { args, body, type_ } => {
let type_ = match type_ {
Type::Function(ft) => ft.clone(),
_ => unreachable!("Function expression without function type"),
};
Ok(Value::from(value::Function {
// TODO // TODO
type_: FunctionType { type_,
args: fun.args.iter().map(|_| Type::Int).collect(), args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
ret: Box::new(Type::Int), body: (**body).to_owned(),
}, }))
args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
body: fun.body.to_owned(),
})),
Expr::Ascription { expr, .. } => self.eval(expr),
} }
}?;
debug_assert_eq!(&res.type_(), expr.type_());
Ok(res)
} }
} }
pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> { pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
let mut interpreter = Interpreter::new(); let mut interpreter = Interpreter::new();
interpreter.eval(expr) interpreter.eval(expr)
} }
@ -121,17 +127,18 @@ mod tests {
use super::*; use super::*;
use BinaryOperator::*; use BinaryOperator::*;
fn int_lit(i: u64) -> Box<Expr<'static>> { fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
Box::new(Expr::Literal(Literal::Int(i))) Box::new(Expr::Literal(Literal::Int(i), Type::Int))
} }
fn parse_eval<T>(src: &str) -> T fn do_eval<T>(src: &str) -> T
where where
for<'a> &'a T: TryFrom<&'a Val<'a>>, for<'a> &'a T: TryFrom<&'a Val<'a>>,
T: Clone + TypeOf, T: Clone + TypeOf,
{ {
let expr = crate::parser::expr(src).unwrap().1; let expr = crate::parser::expr(src).unwrap().1;
let res = eval(&expr).unwrap(); let hir = crate::tc::typecheck_expr(expr).unwrap();
let res = eval(&hir).unwrap();
res.as_type::<T>().unwrap().clone() res.as_type::<T>().unwrap().clone()
} }
@ -141,6 +148,7 @@ mod tests {
lhs: int_lit(1), lhs: int_lit(1),
op: Mul, op: Mul,
rhs: int_lit(2), rhs: int_lit(2),
type_: Type::Int,
}; };
let res = eval(&expr).unwrap(); let res = eval(&expr).unwrap();
assert_eq!(*res.as_type::<i64>().unwrap(), 2); assert_eq!(*res.as_type::<i64>().unwrap(), 2);
@ -148,19 +156,19 @@ mod tests {
#[test] #[test]
fn variable_shadowing() { fn variable_shadowing() {
let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x"); let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
assert_eq!(res, 3); assert_eq!(res, 3);
} }
#[test] #[test]
fn conditional_with_equals() { fn conditional_with_equals() {
let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4"); let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
assert_eq!(res, 2); assert_eq!(res, 2);
} }
#[test] #[test]
fn function_call() { fn function_call() {
let res = parse_eval::<i64>("let id = fn x = x in id 1"); let res = do_eval::<i64>("let id = fn x = x in id 1");
assert_eq!(res, 1); assert_eq!(res, 1);
} }
} }

View file

@ -6,13 +6,14 @@ use std::rc::Rc;
use derive_more::{Deref, From, TryInto}; use derive_more::{Deref, From, TryInto};
use super::{Error, Result}; use super::{Error, Result};
use crate::ast::{Expr, FunctionType, Ident, Type}; use crate::ast::hir::Expr;
use crate::ast::{FunctionType, Ident, Type};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Function<'a> { pub struct Function<'a> {
pub type_: FunctionType, pub type_: FunctionType,
pub args: Vec<Ident<'a>>, pub args: Vec<Ident<'a>>,
pub body: Expr<'a>, pub body: Expr<'a, Type>,
} }
#[derive(From, TryInto)] #[derive(From, TryInto)]

View file

@ -8,6 +8,7 @@ pub mod compiler;
pub mod interpreter; pub mod interpreter;
#[macro_use] #[macro_use]
pub mod parser; pub mod parser;
pub mod tc;
pub use common::{Error, Result}; pub use common::{Error, Result};
@ -21,6 +22,7 @@ struct Opts {
enum Command { enum Command {
Eval(commands::Eval), Eval(commands::Eval),
Compile(commands::Compile), Compile(commands::Compile),
Check(commands::Check),
} }
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
@ -28,5 +30,6 @@ fn main() -> anyhow::Result<()> {
match opts.subcommand { match opts.subcommand {
Command::Eval(eval) => Ok(eval.run()?), Command::Eval(eval) => Ok(eval.run()?),
Command::Compile(compile) => Ok(compile.run()?), Command::Compile(compile) => Ok(compile.run()?),
Command::Check(check) => Ok(check.run()?),
} }
} }

View file

@ -156,7 +156,14 @@ where
named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int)); named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int));
named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal)); named!(bool_(&str) -> Literal, alt!(
tag!("true") => { |_| Literal::Bool(true) } |
tag!("false") => { |_| Literal::Bool(false) }
));
named!(literal(&str) -> Literal, alt!(int | bool_));
named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
named!(binding(&str) -> Binding, do_parse!( named!(binding(&str) -> Binding, do_parse!(
multispace0 multispace0
@ -262,7 +269,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
named!(arg(&str) -> Expr, alt!( named!(arg(&str) -> Expr, alt!(
ident_expr | ident_expr |
literal | literal_expr |
paren_expr paren_expr
)); ));
@ -280,7 +287,7 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!(
let_ | let_ |
if_ | if_ |
fun_expr | fun_expr |
literal | literal_expr |
ident_expr ident_expr
)); ));
@ -399,6 +406,18 @@ pub(crate) mod tests {
} }
} }
#[test]
fn bools() {
assert_eq!(
test_parse!(expr, "true"),
Expr::Literal(Literal::Bool(true))
);
assert_eq!(
test_parse!(expr, "false"),
Expr::Literal(Literal::Bool(false))
);
}
#[test] #[test]
fn let_complex() { fn let_complex() {
let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4"); let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4");

View file

@ -1,3 +1,4 @@
#[cfg(test)]
#[macro_use] #[macro_use]
macro_rules! test_parse { macro_rules! test_parse {
($parser: ident, $src: expr) => {{ ($parser: ident, $src: expr) => {{

View file

@ -14,7 +14,10 @@ pub use type_::type_;
pub type Error = nom::Err<nom::error::Error<String>>; pub type Error = nom::Err<nom::error::Error<String>>;
pub(crate) fn is_reserved(s: &str) -> bool { pub(crate) fn is_reserved(s: &str) -> bool {
matches!(s, "if" | "then" | "else" | "let" | "in" | "fn") matches!(
s,
"if" | "then" | "else" | "let" | "in" | "fn" | "int" | "float" | "bool" | "true" | "false"
)
} }
pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E> pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E>

528
src/tc/mod.rs Normal file
View file

@ -0,0 +1,528 @@
use derive_more::From;
use itertools::Itertools;
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::fmt::{self, Display};
use std::result;
use thiserror::Error;
use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
use crate::common::env::Env;
/// Errors reported by the typechecker.
#[derive(Debug, Error)]
pub enum Error {
/// A variable was referenced that is not defined.
#[error("Undefined variable {0}")]
UndefinedVariable(Ident<'static>),
/// Two types that were required to agree did not.
#[error("Mismatched types: expected {expected}, but got {actual}")]
TypeMismatch { expected: Type, actual: Type },
/// A numeric type was required but the given type was found instead.
#[error("Mismatched types, expected numeric type, but got {0}")]
NonNumeric(Type),
/// The given type variable is ambiguous.
// NOTE(review): presumably raised when a type variable is still unresolved
// after inference — confirm against the rest of the typechecker.
#[error("Ambiguous type {0}")]
AmbiguousType(TyVar),
}
/// Result type whose error is the typechecker's [`Error`].
pub type Result<T> = result::Result<T, Error>;
/// A type variable, identified by a numeric index.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct TyVar(u64);

impl Display for TyVar {
    /// Renders the variable as `t` followed by its index, e.g. `t3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TyVar(index) = self;
        f.write_fmt(format_args!("t{}", index))
    }
}
/// A type identified purely by its name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct NullaryType(String);

impl Display for NullaryType {
    /// Writes the type's name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let NullaryType(name) = self;
        f.write_str(name.as_str())
    }
}
/// The primitive types known to the typechecker.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum PrimType {
/// Displayed as `int`.
Int,
/// Displayed as `float`.
Float,
/// Displayed as `bool`.
Bool,
}
/// Maps each primitive type onto the `ast::Type` variant of the same name.
impl From<PrimType> for ast::Type {
    fn from(pr: PrimType) -> Self {
        match pr {
            PrimType::Int => Self::Int,
            PrimType::Float => Self::Float,
            PrimType::Bool => Self::Bool,
        }
    }
}
/// Renders a primitive type by its lowercase keyword.
impl Display for PrimType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            PrimType::Int => "int",
            PrimType::Float => "float",
            PrimType::Bool => "bool",
        };
        f.write_str(keyword)
    }
}
/// The typechecker's own representation of a type, which (unlike `ast::Type`)
/// can contain type variables.
// NOTE(review): `#[from(ignore)]` presumably excludes `Univ` and `Exist` from
// the derived `From` impls, since both wrap a `TyVar` — confirm against the
// derive_more documentation.
#[derive(Debug, PartialEq, Eq, Clone, From)]
pub enum Type {
/// A type variable displayed as `∀n`.
#[from(ignore)]
Univ(TyVar),
/// A type variable displayed as `∃n`; it has no `ast::Type` equivalent.
#[from(ignore)]
Exist(TyVar),
/// A type identified by name only.
Nullary(NullaryType),
/// A primitive type.
Prim(PrimType),
/// A function type with argument types `args` and return type `ret`.
Fun {
args: Vec<Type>,
ret: Box<Type>,
},
}
/// Structural comparison between a typechecker [`Type`] and an `ast::Type`.
///
/// Primitive types compare by variant and function types compare
/// argument-by-argument and on their return type. Every other combination is
/// unequal.
impl PartialEq<ast::Type> for Type {
    fn eq(&self, other: &ast::Type) -> bool {
        match (self, other) {
            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
                *args == ft.args && **ret == *ft.ret
            }
            (Type::Fun { .. }, _) => false,
            (Type::Exist(_), _) => false,
            // `ast::Type` only has primitive and function variants, so a
            // universal variable or a named type can never equal one.
            // Answering `false` keeps `==` from panicking.
            // TODO: revisit if `ast::Type` gains generic or named types.
            (Type::Univ(_), _) | (Type::Nullary(_), _) => false,
        }
    }
}
impl TryFrom<Type> for ast::Type {
type Error = Type;
fn try_from(value: Type) -> result::Result<Self, Self::Error> {
match value {
Type::Univ(_) => todo!(),
Type::Exist(_) => Err(value),
Type::Nullary(_) => todo!(),
Type::Prim(p) => Ok(p.into()),
Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
args: args
.clone()
.into_iter()
.map(Self::try_from)
.try_collect()
.map_err(|_| value.clone())?,
ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
})),
}
}
}
/// Shorthand for the primitive `int` type.
const INT: Type = Type::Prim(PrimType::Int);
/// Shorthand for the primitive `float` type.
const FLOAT: Type = Type::Prim(PrimType::Float);
/// Shorthand for the primitive `bool` type.
const BOOL: Type = Type::Prim(PrimType::Bool);
impl Display for Type {
    /// Writes the type in a human-readable form.
    ///
    /// Type variables are shown as `∀n` / `∃n`, and function types as
    /// `fn <args separated by ", "> -> <ret>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Univ(tv) => write!(f, "∀{}", tv.0),
            Type::Exist(tv) => write!(f, "∃{}", tv.0),
            Type::Nullary(nullary) => Display::fmt(nullary, f),
            Type::Prim(prim) => Display::fmt(prim, f),
            Type::Fun { args, ret } => {
                let params = args.iter().join(", ");
                write!(f, "fn {} -> {}", params, ret)
            }
        }
    }
}
impl From<ast::Type> for Type {
fn from(type_: ast::Type) -> Self {
match type_ {
ast::Type::Int => INT,
ast::Type::Float => FLOAT,
ast::Type::Bool => BOOL,
ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
args: args.into_iter().map(Self::from).collect(),
ret: Box::new(Self::from(*ret)),
},
}
}
}
/// The state carried through type checking.
struct Typechecker<'ast> {
    /// Counter from which `fresh_tv` mints new type variables. It is incremented before each
    /// use, so the first variable handed out is `TyVar(1)`.
    ty_var_counter: u64,
    /// Bindings of type variables to the types they have been unified with.
    ctx: HashMap<TyVar, Type>,
    /// The types of the identifiers currently in scope.
    env: Env<Ident<'ast>, Type>,
}
impl<'ast> Typechecker<'ast> {
fn new() -> Self {
Self {
ty_var_counter: 0,
ctx: Default::default(),
env: Default::default(),
}
}
/// Type-checks `expr` and returns the corresponding `hir::Expr`, in which every node carries
/// its `Type`.
///
/// The types in the result may still be existential variables whose bindings live in
/// `self.ctx`; `finalize_expr` turns them into `ast::Type`s.
///
/// # Errors
///
/// Returns `Error::UndefinedVariable` for an identifier that is not in `self.env`, and
/// propagates any error from `unify`.
///
/// # Panics
///
/// Unary operators, `BinaryOperator::Div` and `BinaryOperator::Pow` are not implemented yet
/// and hit `todo!()`.
pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
    match expr {
        // Identifiers take whatever type the environment currently has for them.
        ast::Expr::Ident(ident) => {
            let type_ = self
                .env
                .resolve(&ident)
                .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
                .clone();
            Ok(hir::Expr::Ident(ident, type_))
        }
        // Literals have a fixed primitive type.
        ast::Expr::Literal(lit) => {
            let type_ = match lit {
                Literal::Int(_) => Type::Prim(PrimType::Int),
                Literal::Bool(_) => Type::Prim(PrimType::Bool),
            };
            Ok(hir::Expr::Literal(lit, type_))
        }
        ast::Expr::UnaryOp { op, rhs } => todo!(),
        ast::Expr::BinaryOp { lhs, op, rhs } => {
            // Operands are checked left to right before the operator is considered.
            let lhs = self.tc_expr(*lhs)?;
            let rhs = self.tc_expr(*rhs)?;
            let type_ = match op {
                // Comparisons need both sides to have the same type and produce a bool.
                BinaryOperator::Equ | BinaryOperator::Neq => {
                    self.unify(lhs.type_(), rhs.type_())?;
                    Type::Prim(PrimType::Bool)
                }
                // Arithmetic needs both sides to have the same type, which is also the
                // type of the result.
                BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
                    let ty = self.unify(lhs.type_(), rhs.type_())?;
                    // TODO: reject non-numeric operands. The check below is disabled (and
                    // refers to `Type::Int` / `Type::Float`, which are not variants of
                    // `Type`), so any two operands of equal type are currently accepted:
                    // if !matches!(ty, Type::Int | Type::Float) {
                    // return Err(Error::NonNumeric(ty));
                    // }
                    ty
                }
                BinaryOperator::Div => todo!(),
                BinaryOperator::Pow => todo!(),
            };
            Ok(hir::Expr::BinaryOp {
                lhs: Box::new(lhs),
                op,
                rhs: Box::new(rhs),
                type_,
            })
        }
        ast::Expr::Let { bindings, body } => {
            // The bindings live in their own scope, which is popped after the body.
            // NOTE(review): if an error is propagated with `?` before `pop`, the scope
            // stays pushed.
            self.env.push();
            let bindings = bindings
                .into_iter()
                .map(
                    |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
                        let body = self.tc_expr(body)?;
                        // An optional annotation must agree with the inferred type.
                        if let Some(type_) = type_ {
                            self.unify(body.type_(), &type_.into())?;
                        }
                        // The binding is added only after its own body was checked, so
                        // it is visible to later bindings and to the `let` body.
                        self.env.set(ident.clone(), body.type_().clone());
                        Ok(hir::Binding {
                            ident,
                            type_: body.type_().clone(),
                            body,
                        })
                    },
                )
                .collect::<Result<Vec<hir::Binding<Type>>>>()?;
            let body = self.tc_expr(*body)?;
            self.env.pop();
            // The whole `let` has the type of its body.
            Ok(hir::Expr::Let {
                bindings,
                type_: body.type_().clone(),
                body: Box::new(body),
            })
        }
        ast::Expr::If {
            condition,
            then,
            else_,
        } => {
            // The condition must be a bool, and both branches must agree on a type.
            let condition = self.tc_expr(*condition)?;
            self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
            let then = self.tc_expr(*then)?;
            let else_ = self.tc_expr(*else_)?;
            let type_ = self.unify(then.type_(), else_.type_())?;
            Ok(hir::Expr::If {
                condition: Box::new(condition),
                then: Box::new(then),
                else_: Box::new(else_),
                type_,
            })
        }
        ast::Expr::Fun(f) => {
            let ast::Fun { args, body } = *f;
            // Each parameter gets a fresh existential variable, visible only while the
            // body is checked.
            // NOTE(review): as with `let`, an error in the body skips the `pop`.
            self.env.push();
            let args: Vec<_> = args
                .into_iter()
                .map(|id| {
                    let ty = self.fresh_ex();
                    self.env.set(id.clone(), ty.clone());
                    (id, ty)
                })
                .collect();
            let body = self.tc_expr(body)?;
            self.env.pop();
            Ok(hir::Expr::Fun {
                type_: Type::Fun {
                    args: args.iter().map(|(_, ty)| ty.clone()).collect(),
                    ret: Box::new(body.type_().clone()),
                },
                args,
                body: Box::new(body),
            })
        }
        ast::Expr::Call { fun, args } => {
            // Build the function type this call site expects out of fresh variables (one
            // for the result, one per argument) and unify it with the callee's type.
            let ret_ty = self.fresh_ex();
            let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
            let ft = Type::Fun {
                args: arg_tys.clone(),
                ret: Box::new(ret_ty.clone()),
            };
            let fun = self.tc_expr(*fun)?;
            self.unify(&ft, fun.type_())?;
            // Then check every argument against the variable standing for its position.
            let args = args
                .into_iter()
                .zip(arg_tys)
                .map(|(arg, ty)| {
                    let arg = self.tc_expr(arg)?;
                    self.unify(&ty, arg.type_())?;
                    Ok(arg)
                })
                .try_collect()?;
            Ok(hir::Expr::Call {
                fun: Box::new(fun),
                args,
                type_: ret_ty,
            })
        }
        // The ascription itself is not kept in the HIR: the inner expression is returned
        // once its type has been unified with the annotation.
        ast::Expr::Ascription { expr, type_ } => {
            let expr = self.tc_expr(*expr)?;
            self.unify(expr.type_(), &type_.into())?;
            Ok(expr)
        }
    }
}
/// Type-checks a top-level declaration.
///
/// A function declaration is checked as a function expression; afterwards its name is bound
/// to the resulting type in the environment.
pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
    match decl {
        ast::Decl::Fun { name, body: fun } => {
            let checked = self.tc_expr(ast::Expr::Fun(Box::new(fun)))?;
            let fun_type = checked.type_().clone();
            self.env.set(name.clone(), fun_type);
            // Checking an `ast::Expr::Fun` always yields a `hir::Expr::Fun`.
            if let hir::Expr::Fun { args, body, type_ } = checked {
                Ok(hir::Decl::Fun {
                    name,
                    args,
                    body,
                    type_,
                })
            } else {
                unreachable!()
            }
        }
    }
}
/// Allocates a type variable that has not been handed out before by advancing the counter.
fn fresh_tv(&mut self) -> TyVar {
    let next = self.ty_var_counter + 1;
    self.ty_var_counter = next;
    TyVar(next)
}
/// Creates a new existential type backed by a fresh type variable.
fn fresh_ex(&mut self) -> Type {
    let tv = self.fresh_tv();
    Type::Exist(tv)
}
/// Creates a new universally quantified type backed by a fresh type variable.
fn fresh_univ(&mut self) -> Type {
    // This used to build `Type::Exist`, which made it indistinguishable from `fresh_ex`.
    Type::Univ(self.fresh_tv())
}
/// Placeholder: currently returns `expr` unchanged.
///
/// NOTE(review): presumably meant to generalize the existential variables left in `expr`
/// into universal ones — confirm the intended behavior before implementing.
fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
    // TODO: not implemented yet.
    expr
}
fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
match (ty1, ty2) {
(Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
Some(existing_ty) => Err(Error::TypeMismatch {
expected: ty.clone(),
actual: existing_ty.into(),
}),
None => match self.ctx.insert(*tv, ty.clone()) {
Some(existing) => self.unify(&existing, ty),
None => Ok(ty.clone()),
},
},
(Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
(
Type::Fun {
args: args1,
ret: ret1,
},
Type::Fun {
args: args2,
ret: ret2,
},
) => {
let args = args1
.iter()
.zip(args2)
.map(|(t1, t2)| self.unify(t1, t2))
.try_collect()?;
let ret = self.unify(ret1, ret2)?;
Ok(Type::Fun {
args,
ret: Box::new(ret),
})
}
(Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
_ => Err(Error::TypeMismatch {
expected: ty1.clone(),
actual: ty2.clone(),
}),
}
}
/// Replaces every inference type in `expr` with its final `ast::Type`, failing with the
/// first error `finalize_type` reports.
fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
    let finalize = |inferred: Type| self.finalize_type(inferred);
    expr.traverse_type(finalize)
}
/// Replaces every inference type in `decl` with its final `ast::Type`, failing with the
/// first error `finalize_type` reports.
fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
    let finalize = |inferred: Type| self.finalize_type(inferred);
    decl.traverse_type(finalize)
}
fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
match ty {
Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
Type::Univ(tv) => todo!(),
Type::Nullary(_) => todo!(),
Type::Prim(pr) => Ok(pr.into()),
Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
args: args
.into_iter()
.map(|ty| self.finalize_type(ty))
.try_collect()?,
ret: Box::new(self.finalize_type(*ret)?),
})),
}
}
/// Follows the bindings of `tv` in `self.ctx` until a concrete type is reached.
///
/// Returns `None` when the chain ends at a variable that has no binding.
///
/// # Panics
///
/// Only primitive types can be resolved so far; a variable bound to a universal variable, a
/// nullary type or a function type hits `todo!()`.
fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
    let mut current = tv;
    loop {
        match self.ctx.get(&current)? {
            Type::Exist(next) => current = *next,
            Type::Univ(_) => todo!(),
            Type::Nullary(_) => todo!(),
            Type::Prim(prim) => return Some((*prim).into()),
            Type::Fun { .. } => todo!(),
        }
    }
}
}
/// Type-checks a single expression with a fresh typechecker and returns it with every node
/// annotated with its final `ast::Type`.
///
/// # Errors
///
/// Returns the first error encountered while checking the expression or finalizing its types.
pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
    let mut checker = Typechecker::new();
    checker
        .tc_expr(expr)
        .and_then(|inferred| checker.finalize_expr(inferred))
}
/// Type-checks a list of top-level declarations in order, sharing one typechecker, and
/// returns them annotated with their final `ast::Type`s.
///
/// Each declaration is finalized right after it is checked.
///
/// # Errors
///
/// Stops at, and returns, the first error encountered.
pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
    let mut checker = Typechecker::new();
    let mut finalized = Vec::with_capacity(decls.len());
    for decl in decls {
        let inferred = checker.tc_decl(decl)?;
        finalized.push(checker.finalize_decl(inferred)?);
    }
    Ok(finalized)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Parses `$expr` as an expression and `$type` as a type, type-checks the expression and
    // asserts that its type equals the parsed type. Panics with the error's message if type
    // checking fails.
    macro_rules! assert_type {
        ($expr: expr, $type: expr) => {
            use crate::parser::{expr, type_};
            let parsed_expr = test_parse!(expr, $expr);
            let parsed_type = test_parse!(type_, $type);
            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
            assert_eq!(res.type_(), &parsed_type);
        };
    }
    // Parses `$expr` as an expression and asserts that type-checking it fails; on failure of
    // the assertion, the message shows the type that was inferred instead.
    macro_rules! assert_type_error {
        ($expr: expr) => {
            use crate::parser::expr;
            let parsed_expr = test_parse!(expr, $expr);
            let res = typecheck_expr(parsed_expr);
            assert!(
                res.is_err(),
                "Expected type error, but got type: {}",
                res.unwrap().type_()
            );
        };
    }
    #[test]
    fn literal_int() {
        assert_type!("1", "int");
    }
    #[test]
    fn conditional() {
        assert_type!("if 1 == 2 then 3 else 4", "int");
    }
    // Ignored: the numeric-operand check in `tc_expr` is commented out, so adding two bools
    // is not rejected yet.
    #[test]
    #[ignore]
    fn add_bools() {
        assert_type_error!("true + false");
    }
    #[test]
    fn call_generic_function() {
        assert_type!("(fn x = x) 1", "int");
    }
    // Ignored. NOTE(review): presumably waiting on generalization (`universalize` is still a
    // stub). The expected type lists two arguments for a one-argument function — confirm
    // that this is what was intended.
    #[test]
    #[ignore]
    fn generic_function() {
        assert_type!("fn x = x", "fn x, y -> x");
    }
    // Ignored. NOTE(review): presumably waiting on let-generalization, as `id` is used at
    // both `bool` and `int` here — confirm.
    #[test]
    #[ignore]
    fn let_generalization() {
        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
    }
    #[test]
    fn concrete_function() {
        assert_type!("fn x = x + 1", "fn int -> int");
    }
    #[test]
    fn call_concrete_function() {
        assert_type!("(fn x = x + 1) 2", "int");
    }
    #[test]
    fn conditional_non_bool() {
        assert_type_error!("if 3 then true else false");
    }
    #[test]
    fn let_int() {
        assert_type!("let x = 1 in x", "int");
    }
}