Add the start of a hindley-milner typechecker
The beginning of a parse-don't-validate-based hindley-milner typechecker, which returns on success an IR where every AST node trivially knows its own type, and using those types to determine LLVM types in codegen.
This commit is contained in:
parent
f8beda81fb
commit
32a5c0ff0f
20 changed files with 980 additions and 78 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -9,6 +9,7 @@ dependencies = [
|
|||
"derive_more",
|
||||
"inkwell",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"llvm-sys",
|
||||
"nom",
|
||||
"nom-trace",
|
||||
|
|
|
@ -10,6 +10,7 @@ clap = "3.0.0-beta.2"
|
|||
derive_more = "0.99.11"
|
||||
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
|
||||
itertools = "0.10.0"
|
||||
lazy_static = "1.4.0"
|
||||
llvm-sys = "110.0.1"
|
||||
nom = "6.1.2"
|
||||
nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" }
|
||||
|
|
3
ach/.gitignore
vendored
3
ach/.gitignore
vendored
|
@ -1,2 +1,5 @@
|
|||
*.ll
|
||||
*.o
|
||||
|
||||
functions
|
||||
simple
|
||||
|
|
246
src/ast/hir.rs
Normal file
246
src/ast/hir.rs
Normal file
|
@ -0,0 +1,246 @@
|
|||
use itertools::Itertools;
|
||||
|
||||
use super::{BinaryOperator, Ident, Literal, UnaryOperator};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct Binding<'a, T> {
|
||||
pub ident: Ident<'a>,
|
||||
pub type_: T,
|
||||
pub body: Expr<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Binding<'a, T> {
|
||||
fn to_owned(&self) -> Binding<'static, T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
Binding {
|
||||
ident: self.ident.to_owned(),
|
||||
type_: self.type_.clone(),
|
||||
body: self.body.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum Expr<'a, T> {
|
||||
Ident(Ident<'a>, T),
|
||||
|
||||
Literal(Literal, T),
|
||||
|
||||
UnaryOp {
|
||||
op: UnaryOperator,
|
||||
rhs: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
|
||||
BinaryOp {
|
||||
lhs: Box<Expr<'a, T>>,
|
||||
op: BinaryOperator,
|
||||
rhs: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
|
||||
Let {
|
||||
bindings: Vec<Binding<'a, T>>,
|
||||
body: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
|
||||
If {
|
||||
condition: Box<Expr<'a, T>>,
|
||||
then: Box<Expr<'a, T>>,
|
||||
else_: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
|
||||
Fun {
|
||||
args: Vec<(Ident<'a>, T)>,
|
||||
body: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
|
||||
Call {
|
||||
fun: Box<Expr<'a, T>>,
|
||||
args: Vec<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a, T> Expr<'a, T> {
|
||||
pub fn type_(&self) -> &T {
|
||||
match self {
|
||||
Expr::Ident(_, t) => t,
|
||||
Expr::Literal(_, t) => t,
|
||||
Expr::UnaryOp { type_, .. } => type_,
|
||||
Expr::BinaryOp { type_, .. } => type_,
|
||||
Expr::Let { type_, .. } => type_,
|
||||
Expr::If { type_, .. } => type_,
|
||||
Expr::Fun { type_, .. } => type_,
|
||||
Expr::Call { type_, .. } => type_,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn traverse_type<F, U, E>(self, f: F) -> Result<Expr<'a, U>, E>
|
||||
where
|
||||
F: Fn(T) -> Result<U, E> + Clone,
|
||||
{
|
||||
match self {
|
||||
Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)),
|
||||
Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)),
|
||||
Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp {
|
||||
op,
|
||||
rhs: Box::new(rhs.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
Expr::BinaryOp {
|
||||
lhs,
|
||||
op,
|
||||
rhs,
|
||||
type_,
|
||||
} => Ok(Expr::BinaryOp {
|
||||
lhs: Box::new(lhs.traverse_type(f.clone())?),
|
||||
op,
|
||||
rhs: Box::new(rhs.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
Expr::Let {
|
||||
bindings,
|
||||
body,
|
||||
type_,
|
||||
} => Ok(Expr::Let {
|
||||
bindings: bindings
|
||||
.into_iter()
|
||||
.map(|Binding { ident, type_, body }| {
|
||||
Ok(Binding {
|
||||
ident,
|
||||
type_: f(type_)?,
|
||||
body: body.traverse_type(f.clone())?,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, E>>()?,
|
||||
body: Box::new(body.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
Expr::If {
|
||||
condition,
|
||||
then,
|
||||
else_,
|
||||
type_,
|
||||
} => Ok(Expr::If {
|
||||
condition: Box::new(condition.traverse_type(f.clone())?),
|
||||
then: Box::new(then.traverse_type(f.clone())?),
|
||||
else_: Box::new(else_.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
|
||||
args: args
|
||||
.into_iter()
|
||||
.map(|(id, t)| Ok((id, f.clone()(t)?)))
|
||||
.collect::<Result<Vec<_>, E>>()?,
|
||||
body: Box::new(body.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
Expr::Call { fun, args, type_ } => Ok(Expr::Call {
|
||||
fun: Box::new(fun.traverse_type(f.clone())?),
|
||||
args: args
|
||||
.into_iter()
|
||||
.map(|e| e.traverse_type(f.clone()))
|
||||
.collect::<Result<Vec<_>, E>>()?,
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_owned(&self) -> Expr<'static, T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
match self {
|
||||
Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()),
|
||||
Expr::Literal(lit, t) => Expr::Literal(lit.clone(), t.clone()),
|
||||
Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp {
|
||||
op: *op,
|
||||
rhs: Box::new((**rhs).to_owned()),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
Expr::BinaryOp {
|
||||
lhs,
|
||||
op,
|
||||
rhs,
|
||||
type_,
|
||||
} => Expr::BinaryOp {
|
||||
lhs: Box::new((**lhs).to_owned()),
|
||||
op: *op,
|
||||
rhs: Box::new((**rhs).to_owned()),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
Expr::Let {
|
||||
bindings,
|
||||
body,
|
||||
type_,
|
||||
} => Expr::Let {
|
||||
bindings: bindings.into_iter().map(|b| b.to_owned()).collect(),
|
||||
body: Box::new((**body).to_owned()),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
Expr::If {
|
||||
condition,
|
||||
then,
|
||||
else_,
|
||||
type_,
|
||||
} => Expr::If {
|
||||
condition: Box::new((**condition).to_owned()),
|
||||
then: Box::new((**then).to_owned()),
|
||||
else_: Box::new((**else_).to_owned()),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
Expr::Fun { args, body, type_ } => Expr::Fun {
|
||||
args: args
|
||||
.into_iter()
|
||||
.map(|(id, t)| (id.to_owned(), t.clone()))
|
||||
.collect(),
|
||||
body: Box::new((**body).to_owned()),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
Expr::Call { fun, args, type_ } => Expr::Call {
|
||||
fun: Box::new((**fun).to_owned()),
|
||||
args: args.into_iter().map(|e| e.to_owned()).collect(),
|
||||
type_: type_.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Decl<'a, T> {
|
||||
Fun {
|
||||
name: Ident<'a>,
|
||||
args: Vec<(Ident<'a>, T)>,
|
||||
body: Box<Expr<'a, T>>,
|
||||
type_: T,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a, T> Decl<'a, T> {
|
||||
pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
|
||||
where
|
||||
F: Fn(T) -> Result<U, E> + Clone,
|
||||
{
|
||||
match self {
|
||||
Decl::Fun {
|
||||
name,
|
||||
args,
|
||||
body,
|
||||
type_,
|
||||
} => Ok(Decl::Fun {
|
||||
name,
|
||||
args: args
|
||||
.into_iter()
|
||||
.map(|(id, t)| Ok((id, f(t)?)))
|
||||
.try_collect()?,
|
||||
body: Box::new(body.traverse_type(f.clone())?),
|
||||
type_: f(type_)?,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,5 @@
|
|||
pub(crate) mod hir;
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
@ -107,6 +109,7 @@ pub enum UnaryOperator {
|
|||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum Literal {
|
||||
Int(u64),
|
||||
Bool(bool),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
|
|
|
@ -7,12 +7,13 @@ use inkwell::builder::Builder;
|
|||
pub use inkwell::context::Context;
|
||||
use inkwell::module::Module;
|
||||
use inkwell::support::LLVMString;
|
||||
use inkwell::types::FunctionType;
|
||||
use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
|
||||
use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
|
||||
use inkwell::IntPredicate;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
|
||||
use crate::ast::hir::{Binding, Decl, Expr};
|
||||
use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
|
||||
use crate::common::env::Env;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Error)]
|
||||
|
@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> {
|
|||
context: &'ctx Context,
|
||||
pub module: Module<'ctx>,
|
||||
builder: Builder<'ctx>,
|
||||
env: Env<'ast, AnyValueEnum<'ctx>>,
|
||||
env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
|
||||
function_stack: Vec<FunctionValue<'ctx>>,
|
||||
identifier_counter: u32,
|
||||
}
|
||||
|
@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
.append_basic_block(*self.function_stack.last().unwrap(), name)
|
||||
}
|
||||
|
||||
pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> {
|
||||
pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
|
||||
match expr {
|
||||
Expr::Ident(id) => self
|
||||
Expr::Ident(id, _) => self
|
||||
.env
|
||||
.resolve(id)
|
||||
.cloned()
|
||||
.ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
|
||||
Expr::Literal(Literal::Int(i)) => {
|
||||
let ty = self.context.i64_type();
|
||||
Ok(AnyValueEnum::IntValue(ty.const_int(*i, false)))
|
||||
Expr::Literal(lit, ty) => {
|
||||
let ty = self.codegen_int_type(ty);
|
||||
match lit {
|
||||
Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
|
||||
Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
|
||||
ty.const_int(if *b { 1 } else { 0 }, false),
|
||||
)),
|
||||
}
|
||||
}
|
||||
Expr::UnaryOp { op, rhs } => {
|
||||
Expr::UnaryOp { op, rhs, .. } => {
|
||||
let rhs = self.codegen_expr(rhs)?;
|
||||
match op {
|
||||
UnaryOperator::Not => unimplemented!(),
|
||||
|
@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
)),
|
||||
}
|
||||
}
|
||||
Expr::BinaryOp { lhs, op, rhs } => {
|
||||
Expr::BinaryOp { lhs, op, rhs, .. } => {
|
||||
let lhs = self.codegen_expr(lhs)?;
|
||||
let rhs = self.codegen_expr(rhs)?;
|
||||
match op {
|
||||
|
@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
BinaryOperator::Neq => todo!(),
|
||||
}
|
||||
}
|
||||
Expr::Let { bindings, body } => {
|
||||
Expr::Let { bindings, body, .. } => {
|
||||
self.env.push();
|
||||
for Binding { ident, body, .. } in bindings {
|
||||
let val = self.codegen_expr(body)?;
|
||||
|
@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
condition,
|
||||
then,
|
||||
else_,
|
||||
type_,
|
||||
} => {
|
||||
let then_block = self.append_basic_block("then");
|
||||
let else_block = self.append_basic_block("else");
|
||||
|
@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
self.builder.build_unconditional_branch(join_block);
|
||||
|
||||
self.builder.position_at_end(join_block);
|
||||
let phi = self.builder.build_phi(self.context.i64_type(), "join");
|
||||
let phi = self.builder.build_phi(self.codegen_type(type_), "join");
|
||||
phi.add_incoming(&[
|
||||
(&BasicValueEnum::try_from(then_res).unwrap(), then_block),
|
||||
(&BasicValueEnum::try_from(else_res).unwrap(), else_block),
|
||||
]);
|
||||
Ok(phi.as_basic_value().into())
|
||||
}
|
||||
Expr::Call { fun, args } => {
|
||||
if let Expr::Ident(id) = &**fun {
|
||||
Expr::Call { fun, args, .. } => {
|
||||
if let Expr::Ident(id, _) = &**fun {
|
||||
let function = self
|
||||
.module
|
||||
.get_function(id.into())
|
||||
|
@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
todo!()
|
||||
}
|
||||
}
|
||||
Expr::Fun(fun) => {
|
||||
let Fun { args, body } = &**fun;
|
||||
Expr::Fun { args, body, .. } => {
|
||||
let fname = self.fresh_ident("f");
|
||||
let cur_block = self.builder.get_insert_block().unwrap();
|
||||
let env = self.env.save(); // TODO: closures
|
||||
|
@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
self.env.restore(env);
|
||||
Ok(function.into())
|
||||
}
|
||||
Expr::Ascription { expr, .. } => self.codegen_expr(expr),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn codegen_function(
|
||||
&mut self,
|
||||
name: &str,
|
||||
args: &'ast [Ident<'ast>],
|
||||
body: &'ast Expr<'ast>,
|
||||
args: &'ast [(Ident<'ast>, Type)],
|
||||
body: &'ast Expr<'ast, Type>,
|
||||
) -> Result<FunctionValue<'ctx>> {
|
||||
let i64_type = self.context.i64_type();
|
||||
self.new_function(
|
||||
name,
|
||||
i64_type.fn_type(
|
||||
self.codegen_type(body.type_()).fn_type(
|
||||
args.iter()
|
||||
.map(|_| i64_type.into())
|
||||
.map(|(_, at)| self.codegen_type(at))
|
||||
.collect::<Vec<_>>()
|
||||
.as_slice(),
|
||||
false,
|
||||
),
|
||||
);
|
||||
self.env.push();
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
for (i, (arg, _)) in args.iter().enumerate() {
|
||||
self.env.set(
|
||||
arg,
|
||||
self.cur_function().get_nth_param(i as u32).unwrap().into(),
|
||||
|
@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
Ok(self.finish_function(&res))
|
||||
}
|
||||
|
||||
pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> {
|
||||
pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
|
||||
match decl {
|
||||
Decl::Fun {
|
||||
name,
|
||||
body: Fun { args, body },
|
||||
name, args, body, ..
|
||||
} => {
|
||||
self.codegen_function(name.into(), args, body)?;
|
||||
Ok(())
|
||||
|
@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> {
|
||||
pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
|
||||
self.new_function("main", self.context.i64_type().fn_type(&[], false));
|
||||
let res = self.codegen_expr(expr)?.try_into().unwrap();
|
||||
self.finish_function(&res);
|
||||
if *expr.type_() != Type::Int {
|
||||
self.builder
|
||||
.build_return(Some(&self.context.i64_type().const_int(0, false)));
|
||||
} else {
|
||||
self.finish_function(&res);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
|
||||
// TODO
|
||||
self.context.i64_type().into()
|
||||
}
|
||||
|
||||
fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
|
||||
// TODO
|
||||
self.context.i64_type()
|
||||
}
|
||||
|
||||
pub fn print_to_file<P>(&self, path: P) -> Result<()>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
|
@ -299,6 +317,8 @@ mod tests {
|
|||
fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
|
||||
let expr = crate::parser::expr(expr).unwrap().1;
|
||||
|
||||
let expr = crate::tc::typecheck_expr(expr).unwrap();
|
||||
|
||||
let context = Context::create();
|
||||
let mut codegen = Codegen::new(&context, "test");
|
||||
let execution_engine = codegen
|
||||
|
|
|
@ -4,10 +4,11 @@ use inkwell::execution_engine::JitFunction;
|
|||
use inkwell::OptimizationLevel;
|
||||
pub use llvm::*;
|
||||
|
||||
use crate::ast::Expr;
|
||||
use crate::ast::hir::Expr;
|
||||
use crate::ast::Type;
|
||||
use crate::common::Result;
|
||||
|
||||
pub fn jit_eval<T>(expr: &Expr) -> Result<T> {
|
||||
pub fn jit_eval<T>(expr: &Expr<Type>) -> Result<T> {
|
||||
let context = Context::create();
|
||||
let mut codegen = Codegen::new(&context, "eval");
|
||||
let execution_engine = codegen
|
||||
|
|
39
src/commands/check.rs
Normal file
39
src/commands/check.rs
Normal file
|
@ -0,0 +1,39 @@
|
|||
use clap::Clap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::ast::Type;
|
||||
use crate::{parser, tc, Result};
|
||||
|
||||
/// Typecheck a file or expression
|
||||
#[derive(Clap)]
|
||||
pub struct Check {
|
||||
/// File to check
|
||||
path: Option<PathBuf>,
|
||||
|
||||
/// Expression to check
|
||||
#[clap(long, short = 'e')]
|
||||
expr: Option<String>,
|
||||
}
|
||||
|
||||
fn run_expr(expr: String) -> Result<Type> {
|
||||
let (_, parsed) = parser::expr(&expr)?;
|
||||
let hir_expr = tc::typecheck_expr(parsed)?;
|
||||
Ok(hir_expr.type_().clone())
|
||||
}
|
||||
|
||||
fn run_path(path: PathBuf) -> Result<Type> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
impl Check {
|
||||
pub fn run(self) -> Result<()> {
|
||||
let type_ = match (self.path, self.expr) {
|
||||
(None, None) => Err("Must specify either a file or expression to check".into()),
|
||||
(Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()),
|
||||
(None, Some(expr)) => run_expr(expr),
|
||||
(Some(path), None) => run_path(path),
|
||||
}?;
|
||||
println!("type: {}", type_);
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ use clap::Clap;
|
|||
use crate::codegen;
|
||||
use crate::interpreter;
|
||||
use crate::parser;
|
||||
use crate::tc;
|
||||
use crate::Result;
|
||||
|
||||
/// Evaluate an expression and print its result
|
||||
|
@ -19,10 +20,11 @@ pub struct Eval {
|
|||
impl Eval {
|
||||
pub fn run(self) -> Result<()> {
|
||||
let (_, parsed) = parser::expr(&self.expr)?;
|
||||
let hir = tc::typecheck_expr(parsed)?;
|
||||
let result = if self.jit {
|
||||
codegen::jit_eval::<i64>(&parsed)?.into()
|
||||
codegen::jit_eval::<i64>(&hir)?.into()
|
||||
} else {
|
||||
interpreter::eval(&parsed)?
|
||||
interpreter::eval(&hir)?
|
||||
};
|
||||
println!("{}", result);
|
||||
Ok(())
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
pub mod check;
|
||||
pub mod compile;
|
||||
pub mod eval;
|
||||
|
||||
pub use check::Check;
|
||||
pub use compile::Compile;
|
||||
pub use eval::Eval;
|
||||
|
|
|
@ -1,19 +1,25 @@
|
|||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::mem;
|
||||
|
||||
use crate::ast::Ident;
|
||||
|
||||
/// A lexical environment
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct Env<'ast, V>(Vec<HashMap<&'ast Ident<'ast>, V>>);
|
||||
pub struct Env<K: Eq + Hash, V>(Vec<HashMap<K, V>>);
|
||||
|
||||
impl<'ast, V> Default for Env<'ast, V> {
|
||||
impl<K, V> Default for Env<K, V>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, V> Env<'ast, V> {
|
||||
impl<K, V> Env<K, V>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self(vec![Default::default()])
|
||||
}
|
||||
|
@ -34,11 +40,15 @@ impl<'ast, V> Env<'ast, V> {
|
|||
*self = saved;
|
||||
}
|
||||
|
||||
pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) {
|
||||
pub fn set(&mut self, k: K, v: V) {
|
||||
self.0.last_mut().unwrap().insert(k, v);
|
||||
}
|
||||
|
||||
pub fn resolve<'a>(&'a self, k: &'ast Ident<'ast>) -> Option<&'a V> {
|
||||
pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
for ctx in self.0.iter().rev() {
|
||||
if let Some(res) = ctx.get(k) {
|
||||
return Some(res);
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{io, result};
|
|||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{codegen, interpreter, parser};
|
||||
use crate::{codegen, interpreter, parser, tc};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
|
@ -18,6 +18,9 @@ pub enum Error {
|
|||
#[error("Compile error: {0}")]
|
||||
CodegenError(#[from] codegen::Error),
|
||||
|
||||
#[error("Type error: {0}")]
|
||||
TypeError(#[from] tc::Error),
|
||||
|
||||
#[error("{0}")]
|
||||
Message(String),
|
||||
}
|
||||
|
@ -28,6 +31,12 @@ impl From<String> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Error {
|
||||
fn from(s: &'a str) -> Self {
|
||||
Self::Message(s.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<nom::Err<nom::error::Error<&'a str>>> for Error {
|
||||
fn from(e: nom::Err<nom::error::Error<&'a str>>) -> Self {
|
||||
use nom::error::Error as NomError;
|
||||
|
|
|
@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
|
|||
|
||||
use crate::codegen::{self, Codegen};
|
||||
use crate::common::Result;
|
||||
use crate::parser;
|
||||
use crate::{parser, tc};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
|
||||
pub enum OutputFormat {
|
||||
|
@ -55,6 +55,8 @@ pub struct CompilerOptions {
|
|||
pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
|
||||
let src = fs::read_to_string(input)?;
|
||||
let (_, decls) = parser::toplevel(&src)?; // TODO: statements
|
||||
let decls = tc::typecheck_toplevel(decls)?;
|
||||
|
||||
let context = codegen::Context::create();
|
||||
let mut codegen = Codegen::new(
|
||||
&context,
|
||||
|
|
|
@ -3,14 +3,13 @@ mod value;
|
|||
|
||||
pub use self::error::{Error, Result};
|
||||
pub use self::value::{Function, Value};
|
||||
use crate::ast::{
|
||||
BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
|
||||
};
|
||||
use crate::ast::hir::{Binding, Expr};
|
||||
use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
|
||||
use crate::common::env::Env;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Interpreter<'a> {
|
||||
env: Env<'a, Value<'a>>,
|
||||
env: Env<&'a Ident<'a>, Value<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Interpreter<'a> {
|
||||
|
@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
|
|||
.ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
|
||||
}
|
||||
|
||||
pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
|
||||
match expr {
|
||||
Expr::Ident(id) => self.resolve(id),
|
||||
Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
|
||||
Expr::UnaryOp { op, rhs } => {
|
||||
pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
|
||||
let res = match expr {
|
||||
Expr::Ident(id, _) => self.resolve(id),
|
||||
Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
|
||||
Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
|
||||
Expr::UnaryOp { op, rhs, .. } => {
|
||||
let rhs = self.eval(rhs)?;
|
||||
match op {
|
||||
UnaryOperator::Neg => -rhs,
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
Expr::BinaryOp { lhs, op, rhs } => {
|
||||
Expr::BinaryOp { lhs, op, rhs, .. } => {
|
||||
let lhs = self.eval(lhs)?;
|
||||
let rhs = self.eval(rhs)?;
|
||||
match op {
|
||||
|
@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
|
|||
BinaryOperator::Neq => todo!(),
|
||||
}
|
||||
}
|
||||
Expr::Let { bindings, body } => {
|
||||
Expr::Let { bindings, body, .. } => {
|
||||
self.env.push();
|
||||
for Binding { ident, body, .. } in bindings {
|
||||
let val = self.eval(body)?;
|
||||
|
@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
|
|||
condition,
|
||||
then,
|
||||
else_,
|
||||
..
|
||||
} => {
|
||||
let condition = self.eval(condition)?;
|
||||
if *(condition.as_type::<bool>()?) {
|
||||
|
@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
|
|||
self.eval(else_)
|
||||
}
|
||||
}
|
||||
Expr::Call { ref fun, args } => {
|
||||
Expr::Call { ref fun, args, .. } => {
|
||||
let fun = self.eval(fun)?;
|
||||
let expected_type = FunctionType {
|
||||
args: args.iter().map(|_| Type::Int).collect(),
|
||||
|
@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
|
|||
}
|
||||
Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
|
||||
}
|
||||
Expr::Fun(fun) => Ok(Value::from(value::Function {
|
||||
// TODO
|
||||
type_: FunctionType {
|
||||
args: fun.args.iter().map(|_| Type::Int).collect(),
|
||||
ret: Box::new(Type::Int),
|
||||
},
|
||||
args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
|
||||
body: fun.body.to_owned(),
|
||||
})),
|
||||
Expr::Ascription { expr, .. } => self.eval(expr),
|
||||
}
|
||||
Expr::Fun { args, body, type_ } => {
|
||||
let type_ = match type_ {
|
||||
Type::Function(ft) => ft.clone(),
|
||||
_ => unreachable!("Function expression without function type"),
|
||||
};
|
||||
|
||||
Ok(Value::from(value::Function {
|
||||
// TODO
|
||||
type_,
|
||||
args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
|
||||
body: (**body).to_owned(),
|
||||
}))
|
||||
}
|
||||
}?;
|
||||
debug_assert_eq!(&res.type_(), expr.type_());
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
|
||||
pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
|
||||
let mut interpreter = Interpreter::new();
|
||||
interpreter.eval(expr)
|
||||
}
|
||||
|
@ -121,17 +127,18 @@ mod tests {
|
|||
use super::*;
|
||||
use BinaryOperator::*;
|
||||
|
||||
fn int_lit(i: u64) -> Box<Expr<'static>> {
|
||||
Box::new(Expr::Literal(Literal::Int(i)))
|
||||
fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
|
||||
Box::new(Expr::Literal(Literal::Int(i), Type::Int))
|
||||
}
|
||||
|
||||
fn parse_eval<T>(src: &str) -> T
|
||||
fn do_eval<T>(src: &str) -> T
|
||||
where
|
||||
for<'a> &'a T: TryFrom<&'a Val<'a>>,
|
||||
T: Clone + TypeOf,
|
||||
{
|
||||
let expr = crate::parser::expr(src).unwrap().1;
|
||||
let res = eval(&expr).unwrap();
|
||||
let hir = crate::tc::typecheck_expr(expr).unwrap();
|
||||
let res = eval(&hir).unwrap();
|
||||
res.as_type::<T>().unwrap().clone()
|
||||
}
|
||||
|
||||
|
@ -141,6 +148,7 @@ mod tests {
|
|||
lhs: int_lit(1),
|
||||
op: Mul,
|
||||
rhs: int_lit(2),
|
||||
type_: Type::Int,
|
||||
};
|
||||
let res = eval(&expr).unwrap();
|
||||
assert_eq!(*res.as_type::<i64>().unwrap(), 2);
|
||||
|
@ -148,19 +156,19 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn variable_shadowing() {
|
||||
let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
|
||||
let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
|
||||
assert_eq!(res, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional_with_equals() {
|
||||
let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
|
||||
let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
|
||||
assert_eq!(res, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_call() {
|
||||
let res = parse_eval::<i64>("let id = fn x = x in id 1");
|
||||
let res = do_eval::<i64>("let id = fn x = x in id 1");
|
||||
assert_eq!(res, 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,13 +6,14 @@ use std::rc::Rc;
|
|||
use derive_more::{Deref, From, TryInto};
|
||||
|
||||
use super::{Error, Result};
|
||||
use crate::ast::{Expr, FunctionType, Ident, Type};
|
||||
use crate::ast::hir::Expr;
|
||||
use crate::ast::{FunctionType, Ident, Type};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Function<'a> {
|
||||
pub type_: FunctionType,
|
||||
pub args: Vec<Ident<'a>>,
|
||||
pub body: Expr<'a>,
|
||||
pub body: Expr<'a, Type>,
|
||||
}
|
||||
|
||||
#[derive(From, TryInto)]
|
||||
|
|
|
@ -8,6 +8,7 @@ pub mod compiler;
|
|||
pub mod interpreter;
|
||||
#[macro_use]
|
||||
pub mod parser;
|
||||
pub mod tc;
|
||||
|
||||
pub use common::{Error, Result};
|
||||
|
||||
|
@ -21,6 +22,7 @@ struct Opts {
|
|||
enum Command {
|
||||
Eval(commands::Eval),
|
||||
Compile(commands::Compile),
|
||||
Check(commands::Check),
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
|
@ -28,5 +30,6 @@ fn main() -> anyhow::Result<()> {
|
|||
match opts.subcommand {
|
||||
Command::Eval(eval) => Ok(eval.run()?),
|
||||
Command::Compile(compile) => Ok(compile.run()?),
|
||||
Command::Check(check) => Ok(check.run()?),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -156,7 +156,14 @@ where
|
|||
|
||||
named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int));
|
||||
|
||||
named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal));
|
||||
named!(bool_(&str) -> Literal, alt!(
|
||||
tag!("true") => { |_| Literal::Bool(true) } |
|
||||
tag!("false") => { |_| Literal::Bool(false) }
|
||||
));
|
||||
|
||||
named!(literal(&str) -> Literal, alt!(int | bool_));
|
||||
|
||||
named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
|
||||
|
||||
named!(binding(&str) -> Binding, do_parse!(
|
||||
multispace0
|
||||
|
@ -262,7 +269,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
|
|||
|
||||
named!(arg(&str) -> Expr, alt!(
|
||||
ident_expr |
|
||||
literal |
|
||||
literal_expr |
|
||||
paren_expr
|
||||
));
|
||||
|
||||
|
@ -280,7 +287,7 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!(
|
|||
let_ |
|
||||
if_ |
|
||||
fun_expr |
|
||||
literal |
|
||||
literal_expr |
|
||||
ident_expr
|
||||
));
|
||||
|
||||
|
@ -399,6 +406,18 @@ pub(crate) mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bools() {
|
||||
assert_eq!(
|
||||
test_parse!(expr, "true"),
|
||||
Expr::Literal(Literal::Bool(true))
|
||||
);
|
||||
assert_eq!(
|
||||
test_parse!(expr, "false"),
|
||||
Expr::Literal(Literal::Bool(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn let_complex() {
|
||||
let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4");
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
macro_rules! test_parse {
|
||||
($parser: ident, $src: expr) => {{
|
||||
|
|
|
@ -14,7 +14,10 @@ pub use type_::type_;
|
|||
pub type Error = nom::Err<nom::error::Error<String>>;
|
||||
|
||||
pub(crate) fn is_reserved(s: &str) -> bool {
|
||||
matches!(s, "if" | "then" | "else" | "let" | "in" | "fn")
|
||||
matches!(
|
||||
s,
|
||||
"if" | "then" | "else" | "let" | "in" | "fn" | "int" | "float" | "bool" | "true" | "false"
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E>
|
||||
|
|
528
src/tc/mod.rs
Normal file
528
src/tc/mod.rs
Normal file
|
@ -0,0 +1,528 @@
|
|||
use derive_more::From;
|
||||
use itertools::Itertools;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fmt::{self, Display};
|
||||
use std::result;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
|
||||
use crate::common::env::Env;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error("Undefined variable {0}")]
|
||||
UndefinedVariable(Ident<'static>),
|
||||
|
||||
#[error("Mismatched types: expected {expected}, but got {actual}")]
|
||||
TypeMismatch { expected: Type, actual: Type },
|
||||
|
||||
#[error("Mismatched types, expected numeric type, but got {0}")]
|
||||
NonNumeric(Type),
|
||||
|
||||
#[error("Ambiguous type {0}")]
|
||||
AmbiguousType(TyVar),
|
||||
}
|
||||
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub struct TyVar(u64);
|
||||
|
||||
impl Display for TyVar {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "t{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub struct NullaryType(String);
|
||||
|
||||
impl Display for NullaryType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum PrimType {
|
||||
Int,
|
||||
Float,
|
||||
Bool,
|
||||
}
|
||||
|
||||
impl From<PrimType> for ast::Type {
|
||||
fn from(pr: PrimType) -> Self {
|
||||
match pr {
|
||||
PrimType::Int => ast::Type::Int,
|
||||
PrimType::Float => ast::Type::Float,
|
||||
PrimType::Bool => ast::Type::Bool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PrimType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
PrimType::Int => f.write_str("int"),
|
||||
PrimType::Float => f.write_str("float"),
|
||||
PrimType::Bool => f.write_str("bool"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, From)]
|
||||
pub enum Type {
|
||||
#[from(ignore)]
|
||||
Univ(TyVar),
|
||||
#[from(ignore)]
|
||||
Exist(TyVar),
|
||||
Nullary(NullaryType),
|
||||
Prim(PrimType),
|
||||
Fun {
|
||||
args: Vec<Type>,
|
||||
ret: Box<Type>,
|
||||
},
|
||||
}
|
||||
|
||||
impl PartialEq<ast::Type> for Type {
|
||||
fn eq(&self, other: &ast::Type) -> bool {
|
||||
match (self, other) {
|
||||
(Type::Univ(_), _) => todo!(),
|
||||
(Type::Exist(_), _) => false,
|
||||
(Type::Nullary(_), _) => todo!(),
|
||||
(Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
|
||||
(Type::Fun { args, ret }, ast::Type::Function(ft)) => {
|
||||
*args == ft.args && (**ret).eq(&*ft.ret)
|
||||
}
|
||||
(Type::Fun { .. }, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Type> for ast::Type {
|
||||
type Error = Type;
|
||||
|
||||
fn try_from(value: Type) -> result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
Type::Univ(_) => todo!(),
|
||||
Type::Exist(_) => Err(value),
|
||||
Type::Nullary(_) => todo!(),
|
||||
Type::Prim(p) => Ok(p.into()),
|
||||
Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
|
||||
args: args
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(Self::try_from)
|
||||
.try_collect()
|
||||
.map_err(|_| value.clone())?,
|
||||
ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const INT: Type = Type::Prim(PrimType::Int);
|
||||
const FLOAT: Type = Type::Prim(PrimType::Float);
|
||||
const BOOL: Type = Type::Prim(PrimType::Bool);
|
||||
|
||||
impl Display for Type {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Type::Nullary(nt) => nt.fmt(f),
|
||||
Type::Prim(p) => p.fmt(f),
|
||||
Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
|
||||
Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
|
||||
Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ast::Type> for Type {
|
||||
fn from(type_: ast::Type) -> Self {
|
||||
match type_ {
|
||||
ast::Type::Int => INT,
|
||||
ast::Type::Float => FLOAT,
|
||||
ast::Type::Bool => BOOL,
|
||||
ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
|
||||
args: args.into_iter().map(Self::from).collect(),
|
||||
ret: Box::new(Self::from(*ret)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Typechecker<'ast> {
|
||||
ty_var_counter: u64,
|
||||
ctx: HashMap<TyVar, Type>,
|
||||
env: Env<Ident<'ast>, Type>,
|
||||
}
|
||||
|
||||
impl<'ast> Typechecker<'ast> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
ty_var_counter: 0,
|
||||
ctx: Default::default(),
|
||||
env: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
|
||||
match expr {
|
||||
ast::Expr::Ident(ident) => {
|
||||
let type_ = self
|
||||
.env
|
||||
.resolve(&ident)
|
||||
.ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
|
||||
.clone();
|
||||
Ok(hir::Expr::Ident(ident, type_))
|
||||
}
|
||||
ast::Expr::Literal(lit) => {
|
||||
let type_ = match lit {
|
||||
Literal::Int(_) => Type::Prim(PrimType::Int),
|
||||
Literal::Bool(_) => Type::Prim(PrimType::Bool),
|
||||
};
|
||||
Ok(hir::Expr::Literal(lit, type_))
|
||||
}
|
||||
ast::Expr::UnaryOp { op, rhs } => todo!(),
|
||||
ast::Expr::BinaryOp { lhs, op, rhs } => {
|
||||
let lhs = self.tc_expr(*lhs)?;
|
||||
let rhs = self.tc_expr(*rhs)?;
|
||||
let type_ = match op {
|
||||
BinaryOperator::Equ | BinaryOperator::Neq => {
|
||||
self.unify(lhs.type_(), rhs.type_())?;
|
||||
Type::Prim(PrimType::Bool)
|
||||
}
|
||||
BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
|
||||
let ty = self.unify(lhs.type_(), rhs.type_())?;
|
||||
// if !matches!(ty, Type::Int | Type::Float) {
|
||||
// return Err(Error::NonNumeric(ty));
|
||||
// }
|
||||
ty
|
||||
}
|
||||
BinaryOperator::Div => todo!(),
|
||||
BinaryOperator::Pow => todo!(),
|
||||
};
|
||||
Ok(hir::Expr::BinaryOp {
|
||||
lhs: Box::new(lhs),
|
||||
op,
|
||||
rhs: Box::new(rhs),
|
||||
type_,
|
||||
})
|
||||
}
|
||||
ast::Expr::Let { bindings, body } => {
|
||||
self.env.push();
|
||||
let bindings = bindings
|
||||
.into_iter()
|
||||
.map(
|
||||
|ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
|
||||
let body = self.tc_expr(body)?;
|
||||
if let Some(type_) = type_ {
|
||||
self.unify(body.type_(), &type_.into())?;
|
||||
}
|
||||
self.env.set(ident.clone(), body.type_().clone());
|
||||
Ok(hir::Binding {
|
||||
ident,
|
||||
type_: body.type_().clone(),
|
||||
body,
|
||||
})
|
||||
},
|
||||
)
|
||||
.collect::<Result<Vec<hir::Binding<Type>>>>()?;
|
||||
let body = self.tc_expr(*body)?;
|
||||
self.env.pop();
|
||||
Ok(hir::Expr::Let {
|
||||
bindings,
|
||||
type_: body.type_().clone(),
|
||||
body: Box::new(body),
|
||||
})
|
||||
}
|
||||
ast::Expr::If {
|
||||
condition,
|
||||
then,
|
||||
else_,
|
||||
} => {
|
||||
let condition = self.tc_expr(*condition)?;
|
||||
self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
|
||||
let then = self.tc_expr(*then)?;
|
||||
let else_ = self.tc_expr(*else_)?;
|
||||
let type_ = self.unify(then.type_(), else_.type_())?;
|
||||
Ok(hir::Expr::If {
|
||||
condition: Box::new(condition),
|
||||
then: Box::new(then),
|
||||
else_: Box::new(else_),
|
||||
type_,
|
||||
})
|
||||
}
|
||||
ast::Expr::Fun(f) => {
|
||||
let ast::Fun { args, body } = *f;
|
||||
self.env.push();
|
||||
let args: Vec<_> = args
|
||||
.into_iter()
|
||||
.map(|id| {
|
||||
let ty = self.fresh_ex();
|
||||
self.env.set(id.clone(), ty.clone());
|
||||
(id, ty)
|
||||
})
|
||||
.collect();
|
||||
let body = self.tc_expr(body)?;
|
||||
self.env.pop();
|
||||
Ok(hir::Expr::Fun {
|
||||
type_: Type::Fun {
|
||||
args: args.iter().map(|(_, ty)| ty.clone()).collect(),
|
||||
ret: Box::new(body.type_().clone()),
|
||||
},
|
||||
args,
|
||||
body: Box::new(body),
|
||||
})
|
||||
}
|
||||
ast::Expr::Call { fun, args } => {
|
||||
let ret_ty = self.fresh_ex();
|
||||
let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
|
||||
let ft = Type::Fun {
|
||||
args: arg_tys.clone(),
|
||||
ret: Box::new(ret_ty.clone()),
|
||||
};
|
||||
let fun = self.tc_expr(*fun)?;
|
||||
self.unify(&ft, fun.type_())?;
|
||||
let args = args
|
||||
.into_iter()
|
||||
.zip(arg_tys)
|
||||
.map(|(arg, ty)| {
|
||||
let arg = self.tc_expr(arg)?;
|
||||
self.unify(&ty, arg.type_())?;
|
||||
Ok(arg)
|
||||
})
|
||||
.try_collect()?;
|
||||
Ok(hir::Expr::Call {
|
||||
fun: Box::new(fun),
|
||||
args,
|
||||
type_: ret_ty,
|
||||
})
|
||||
}
|
||||
ast::Expr::Ascription { expr, type_ } => {
|
||||
let expr = self.tc_expr(*expr)?;
|
||||
self.unify(expr.type_(), &type_.into())?;
|
||||
Ok(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
|
||||
match decl {
|
||||
ast::Decl::Fun { name, body } => {
|
||||
let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
|
||||
let type_ = body.type_().clone();
|
||||
self.env.set(name.clone(), type_);
|
||||
match body {
|
||||
hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
|
||||
name,
|
||||
args,
|
||||
body,
|
||||
type_,
|
||||
}),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fresh_tv(&mut self) -> TyVar {
|
||||
self.ty_var_counter += 1;
|
||||
TyVar(self.ty_var_counter)
|
||||
}
|
||||
|
||||
fn fresh_ex(&mut self) -> Type {
|
||||
Type::Exist(self.fresh_tv())
|
||||
}
|
||||
|
||||
fn fresh_univ(&mut self) -> Type {
|
||||
Type::Exist(self.fresh_tv())
|
||||
}
|
||||
|
||||
fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
|
||||
// TODO
|
||||
expr
|
||||
}
|
||||
|
||||
fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
|
||||
match (ty1, ty2) {
|
||||
(Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
|
||||
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
|
||||
Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
|
||||
Some(existing_ty) => Err(Error::TypeMismatch {
|
||||
expected: ty.clone(),
|
||||
actual: existing_ty.into(),
|
||||
}),
|
||||
None => match self.ctx.insert(*tv, ty.clone()) {
|
||||
Some(existing) => self.unify(&existing, ty),
|
||||
None => Ok(ty.clone()),
|
||||
},
|
||||
},
|
||||
(Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
|
||||
(
|
||||
Type::Fun {
|
||||
args: args1,
|
||||
ret: ret1,
|
||||
},
|
||||
Type::Fun {
|
||||
args: args2,
|
||||
ret: ret2,
|
||||
},
|
||||
) => {
|
||||
let args = args1
|
||||
.iter()
|
||||
.zip(args2)
|
||||
.map(|(t1, t2)| self.unify(t1, t2))
|
||||
.try_collect()?;
|
||||
let ret = self.unify(ret1, ret2)?;
|
||||
Ok(Type::Fun {
|
||||
args,
|
||||
ret: Box::new(ret),
|
||||
})
|
||||
}
|
||||
(Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
|
||||
_ => Err(Error::TypeMismatch {
|
||||
expected: ty1.clone(),
|
||||
actual: ty2.clone(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
|
||||
expr.traverse_type(|ty| self.finalize_type(ty))
|
||||
}
|
||||
|
||||
fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
|
||||
decl.traverse_type(|ty| self.finalize_type(ty))
|
||||
}
|
||||
|
||||
fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
|
||||
match ty {
|
||||
Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
|
||||
Type::Univ(tv) => todo!(),
|
||||
Type::Nullary(_) => todo!(),
|
||||
Type::Prim(pr) => Ok(pr.into()),
|
||||
Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
|
||||
args: args
|
||||
.into_iter()
|
||||
.map(|ty| self.finalize_type(ty))
|
||||
.try_collect()?,
|
||||
ret: Box::new(self.finalize_type(*ret)?),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
|
||||
let mut res = &Type::Exist(tv);
|
||||
loop {
|
||||
match res {
|
||||
Type::Exist(tv) => {
|
||||
res = self.ctx.get(tv)?;
|
||||
}
|
||||
Type::Univ(_) => todo!(),
|
||||
Type::Nullary(_) => todo!(),
|
||||
Type::Prim(pr) => break Some((*pr).into()),
|
||||
Type::Fun { args, ret } => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
|
||||
let mut typechecker = Typechecker::new();
|
||||
let typechecked = typechecker.tc_expr(expr)?;
|
||||
typechecker.finalize_expr(typechecked)
|
||||
}
|
||||
|
||||
pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
|
||||
let mut typechecker = Typechecker::new();
|
||||
decls
|
||||
.into_iter()
|
||||
.map(|decl| {
|
||||
let decl = typechecker.tc_decl(decl)?;
|
||||
typechecker.finalize_decl(decl)
|
||||
})
|
||||
.try_collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
macro_rules! assert_type {
|
||||
($expr: expr, $type: expr) => {
|
||||
use crate::parser::{expr, type_};
|
||||
let parsed_expr = test_parse!(expr, $expr);
|
||||
let parsed_type = test_parse!(type_, $type);
|
||||
let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
|
||||
assert_eq!(res.type_(), &parsed_type);
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! assert_type_error {
|
||||
($expr: expr) => {
|
||||
use crate::parser::expr;
|
||||
let parsed_expr = test_parse!(expr, $expr);
|
||||
let res = typecheck_expr(parsed_expr);
|
||||
assert!(
|
||||
res.is_err(),
|
||||
"Expected type error, but got type: {}",
|
||||
res.unwrap().type_()
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn literal_int() {
|
||||
assert_type!("1", "int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional() {
|
||||
assert_type!("if 1 == 2 then 3 else 4", "int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn add_bools() {
|
||||
assert_type_error!("true + false");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_generic_function() {
|
||||
assert_type!("(fn x = x) 1", "int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn generic_function() {
|
||||
assert_type!("fn x = x", "fn x, y -> x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn let_generalization() {
|
||||
assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concrete_function() {
|
||||
assert_type!("fn x = x + 1", "fn int -> int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_concrete_function() {
|
||||
assert_type!("(fn x = x + 1) 2", "int");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional_non_bool() {
|
||||
assert_type_error!("if 3 then true else false");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn let_int() {
|
||||
assert_type!("let x = 1 in x", "int");
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue