feat(gs/achilles): Implement very basic monomorphization

Implement very basic monomorphization by recording type variable
instantiations when typechecking Call nodes, and then using those in a
new HIR Visitor trait to copy the body of each generic decl once for
each set of instantiations of its type variables.

Change-Id: Iab54030973e5d66e2b8bcd074b4cb6c001a90123
Reviewed-on: https://cl.tvl.fyi/c/depot/+/2617
Reviewed-by: glittershark <grfn@gws.fyi>
Tested-by: BuildkiteCI
This commit is contained in:
Griffin Smith 2021-03-20 18:14:23 -04:00 committed by glittershark
parent e7033bd8b0
commit 8d5f3029e5
12 changed files with 430 additions and 19 deletions

View file

@ -19,6 +19,7 @@ dependencies = [
"proptest",
"test-strategy",
"thiserror",
"void",
]
[[package]]
@ -761,6 +762,12 @@ version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed"
[[package]]
name = "void"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
[[package]]
name = "wait-timeout"
version = "0.2.0"

View file

@ -19,6 +19,7 @@ pratt = "0.3.0"
proptest = "1.0.0"
test-strategy = "0.1.1"
thiserror = "1.0.24"
void = "1.0.2"
[dev-dependencies]
crate-root = "0.1.3"

View file

@ -1,3 +1,5 @@
use std::collections::HashMap;
use itertools::Itertools;
use super::{BinaryOperator, Ident, Literal, UnaryOperator};
@ -55,6 +57,7 @@ pub enum Expr<'a, T> {
},
Fun {
type_args: Vec<Ident<'a>>,
args: Vec<(Ident<'a>, T)>,
body: Box<Expr<'a, T>>,
type_: T,
@ -62,6 +65,7 @@ pub enum Expr<'a, T> {
Call {
fun: Box<Expr<'a, T>>,
type_args: HashMap<Ident<'a>, T>,
args: Vec<Expr<'a, T>>,
type_: T,
},
@ -133,16 +137,31 @@ impl<'a, T> Expr<'a, T> {
else_: Box::new(else_.traverse_type(f.clone())?),
type_: f(type_)?,
}),
Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
Expr::Fun {
args,
type_args,
body,
type_,
} => Ok(Expr::Fun {
args: args
.into_iter()
.map(|(id, t)| Ok((id, f.clone()(t)?)))
.collect::<Result<Vec<_>, E>>()?,
type_args,
body: Box::new(body.traverse_type(f.clone())?),
type_: f(type_)?,
}),
Expr::Call { fun, args, type_ } => Ok(Expr::Call {
Expr::Call {
fun,
type_args,
args,
type_,
} => Ok(Expr::Call {
fun: Box::new(fun.traverse_type(f.clone())?),
type_args: type_args
.into_iter()
.map(|(id, ty)| Ok((id, f.clone()(ty)?)))
.collect::<Result<HashMap<_, _>, E>>()?,
args: args
.into_iter()
.map(|e| e.traverse_type(f.clone()))
@ -180,7 +199,7 @@ impl<'a, T> Expr<'a, T> {
body,
type_,
} => Expr::Let {
bindings: bindings.into_iter().map(|b| b.to_owned()).collect(),
bindings: bindings.iter().map(|b| b.to_owned()).collect(),
body: Box::new((**body).to_owned()),
type_: type_.clone(),
},
@ -195,26 +214,43 @@ impl<'a, T> Expr<'a, T> {
else_: Box::new((**else_).to_owned()),
type_: type_.clone(),
},
Expr::Fun { args, body, type_ } => Expr::Fun {
Expr::Fun {
args,
type_args,
body,
type_,
} => Expr::Fun {
args: args
.into_iter()
.iter()
.map(|(id, t)| (id.to_owned(), t.clone()))
.collect(),
type_args: type_args.iter().map(|arg| arg.to_owned()).collect(),
body: Box::new((**body).to_owned()),
type_: type_.clone(),
},
Expr::Call { fun, args, type_ } => Expr::Call {
Expr::Call {
fun,
type_args,
args,
type_,
} => Expr::Call {
fun: Box::new((**fun).to_owned()),
args: args.into_iter().map(|e| e.to_owned()).collect(),
type_args: type_args
.iter()
.map(|(id, t)| (id.to_owned(), t.clone()))
.collect(),
args: args.iter().map(|e| e.to_owned()).collect(),
type_: type_.clone(),
},
}
}
}
#[derive(Debug, Clone)]
pub enum Decl<'a, T> {
Fun {
name: Ident<'a>,
type_args: Vec<Ident<'a>>,
args: Vec<(Ident<'a>, T)>,
body: Box<Expr<'a, T>>,
type_: T,
@ -235,6 +271,13 @@ impl<'a, T> Decl<'a, T> {
}
}
/// Overwrites the name of this declaration with `new_name`.
///
/// Both `Decl::Fun` and `Decl::Extern` carry a `name` field, so the update
/// is the same regardless of which variant `self` is.
pub fn set_name(&mut self, new_name: Ident<'a>) {
    let slot = match self {
        Decl::Fun { name, .. } | Decl::Extern { name, .. } => name,
    };
    *slot = new_name;
}
pub fn type_(&self) -> Option<&T> {
match self {
Decl::Fun { type_, .. } => Some(type_),
@ -249,11 +292,13 @@ impl<'a, T> Decl<'a, T> {
match self {
Decl::Fun {
name,
type_args,
args,
body,
type_,
} => Ok(Decl::Fun {
name,
type_args,
args: args
.into_iter()
.map(|(id, t)| Ok((id, f(t)?)))

View file

@ -356,6 +356,26 @@ impl<'a> Type<'a> {
let mut substs = HashMap::new();
do_alpha_equiv(&mut substs, self, other)
}
/// Consumes this type and rebuilds it, replacing every type variable with
/// whatever `f` returns for that variable's identifier.
///
/// Primitive types are returned unchanged. Function types are rebuilt by
/// recursing into the argument types (in order) and then the return type.
/// Each argument receives its own clone of `f`; the return type receives
/// `f` itself.
pub fn traverse_type_vars<'b, F>(self, mut f: F) -> Type<'b>
where
    F: FnMut(Ident<'a>) -> Type<'b> + Clone,
{
    // Handle every non-recursive case up front; only function types fall
    // through to the rebuilding code below.
    let (arg_types, ret_type) = match self {
        Type::Int => return Type::Int,
        Type::Float => return Type::Float,
        Type::Bool => return Type::Bool,
        Type::CString => return Type::CString,
        Type::Var(var) => return f(var),
        Type::Function(FunctionType { args, ret }) => (args, ret),
    };

    let mut args = Vec::with_capacity(arg_types.len());
    for arg_type in arg_types {
        args.push(arg_type.traverse_type_vars(f.clone()));
    }
    let ret = Box::new(ret_type.traverse_type_vars(f));

    Type::Function(FunctionType { args, ret })
}
}
impl<'a> Display for Type<'a> {

View file

@ -8,6 +8,7 @@ use test_strategy::Arbitrary;
use crate::codegen::{self, Codegen};
use crate::common::Result;
use crate::passes::hir::monomorphize;
use crate::{parser, tc};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
@ -55,7 +56,8 @@ pub struct CompilerOptions {
pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
let src = fs::read_to_string(input)?;
let (_, decls) = parser::toplevel(&src)?; // TODO: statements
let decls = tc::typecheck_toplevel(decls)?;
let mut decls = tc::typecheck_toplevel(decls)?;
monomorphize::run_toplevel(&mut decls);
let context = codegen::Context::create();
let mut codegen = Codegen::new(

View file

@ -96,7 +96,12 @@ impl<'a> Interpreter<'a> {
}
Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
}
Expr::Fun { args, body, type_ } => {
Expr::Fun {
type_args: _,
args,
body,
type_,
} => {
let type_ = match type_ {
Type::Function(ft) => ft.clone(),
_ => unreachable!("Function expression without function type"),

View file

@ -6,6 +6,7 @@ pub(crate) mod commands;
pub(crate) mod common;
pub mod compiler;
pub mod interpreter;
pub(crate) mod passes;
#[macro_use]
pub mod parser;
pub mod tc;

View file

@ -0,0 +1,179 @@
use std::collections::HashMap;
use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
pub(crate) mod monomorphize;
pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
type Error;
fn visit_type(&mut self, _type: &mut T) -> Result<(), Self::Error> {
Ok(())
}
fn visit_ident(&mut self, _ident: &mut Ident<'ast>) -> Result<(), Self::Error> {
Ok(())
}
fn visit_literal(&mut self, _literal: &mut Literal<'ast>) -> Result<(), Self::Error> {
Ok(())
}
fn visit_unary_operator(&mut self, _op: &mut UnaryOperator) -> Result<(), Self::Error> {
Ok(())
}
fn visit_binary_operator(&mut self, _op: &mut BinaryOperator) -> Result<(), Self::Error> {
Ok(())
}
fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> {
self.visit_ident(&mut binding.ident)?;
self.visit_type(&mut binding.type_)?;
self.visit_expr(&mut binding.body)?;
Ok(())
}
fn post_visit_call(
&mut self,
_fun: &mut Expr<'ast, T>,
_type_args: &mut HashMap<Ident<'ast>, T>,
_args: &mut Vec<Expr<'ast, T>>,
) -> Result<(), Self::Error> {
Ok(())
}
fn pre_visit_call(
&mut self,
_fun: &mut Expr<'ast, T>,
_type_args: &mut HashMap<Ident<'ast>, T>,
_args: &mut Vec<Expr<'ast, T>>,
) -> Result<(), Self::Error> {
Ok(())
}
fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
match expr {
Expr::Ident(id, t) => {
self.visit_ident(id)?;
self.visit_type(t)?;
}
Expr::Literal(lit, t) => {
self.visit_literal(lit)?;
self.visit_type(t)?;
}
Expr::UnaryOp { op, rhs, type_ } => {
self.visit_unary_operator(op)?;
self.visit_expr(rhs)?;
self.visit_type(type_)?;
}
Expr::BinaryOp {
lhs,
op,
rhs,
type_,
} => {
self.visit_expr(lhs)?;
self.visit_binary_operator(op)?;
self.visit_expr(rhs)?;
self.visit_type(type_)?;
}
Expr::Let {
bindings,
body,
type_,
} => {
for binding in bindings.iter_mut() {
self.visit_binding(binding)?;
}
self.visit_expr(body)?;
self.visit_type(type_)?;
}
Expr::If {
condition,
then,
else_,
type_,
} => {
self.visit_expr(condition)?;
self.visit_expr(then)?;
self.visit_expr(else_)?;
self.visit_type(type_)?;
}
Expr::Fun {
args,
body,
type_args,
type_,
} => {
for (ident, t) in args {
self.visit_ident(ident)?;
self.visit_type(t)?;
}
for ta in type_args {
self.visit_ident(ta)?;
}
self.visit_expr(body)?;
self.visit_type(type_)?;
}
Expr::Call {
fun,
args,
type_args,
type_,
} => {
self.pre_visit_call(fun, type_args, args)?;
self.visit_expr(fun)?;
for arg in args.iter_mut() {
self.visit_expr(arg)?;
}
self.visit_type(type_)?;
self.post_visit_call(fun, type_args, args)?;
}
}
Ok(())
}
fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> {
Ok(())
}
fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
match decl {
Decl::Fun {
name,
type_args,
args,
body,
type_,
} => {
self.visit_ident(name)?;
for type_arg in type_args {
self.visit_ident(type_arg)?;
}
for (arg, t) in args {
self.visit_ident(arg)?;
self.visit_type(t)?;
}
self.visit_expr(body)?;
self.visit_type(type_)?;
}
Decl::Extern {
name,
arg_types,
ret_type,
} => {
self.visit_ident(name)?;
for arg_t in arg_types {
self.visit_type(arg_t)?;
}
self.visit_type(ret_type)?;
}
}
self.post_visit_decl(decl)?;
Ok(())
}
}

View file

@ -0,0 +1,139 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;
use std::mem;
use void::{ResultVoidExt, Void};
use crate::ast::hir::{Decl, Expr};
use crate::ast::{self, Ident};
use super::Visitor;
/// State for the monomorphization pass over top-level HIR declarations.
///
/// `'ast` is the lifetime of the data inside the tree; `'a` is the lifetime of
/// the borrowed declarations recorded while the pass runs.
#[derive(Default)]
pub(crate) struct Monomorphize<'a, 'ast> {
/// Declarations visited so far, keyed by name. Filled in by
/// `post_visit_decl` and consulted when a call's callee is an identifier.
decls: HashMap<&'a Ident<'ast>, &'a Decl<'ast, ast::Type<'ast>>>,
/// Specialized copies of declarations produced while rewriting calls;
/// `run_toplevel` places these ahead of the surviving original declarations.
extra_decls: Vec<Decl<'ast, ast::Type<'ast>>>,
/// Names of original declarations for which a specialized copy was
/// produced; `run_toplevel` drops declarations with these names.
remove_decls: HashSet<Ident<'ast>>,
}
impl<'a, 'ast> Monomorphize<'a, 'ast> {
pub(crate) fn new() -> Self {
Default::default()
}
}
impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for Monomorphize<'a, 'ast> {
    /// The pass itself cannot fail.
    type Error = Void;

    /// Redirects a call to a specialized copy of its callee.
    ///
    /// When the callee is an identifier naming an already-recorded
    /// declaration, the declaration is cloned and every type variable in it is
    /// replaced by the concrete type recorded in the call's `type_args`. The
    /// textual form of each substituted type is appended to the callee's name
    /// to form the name of the copy. If at least one substitution happened
    /// (the name changed), the copy is queued in `extra_decls`, the original
    /// is scheduled for removal, and the callee expression is rewritten to
    /// refer to the copy.
    ///
    /// A callee that is not (yet) a recorded declaration is left untouched.
    /// This covers recursive calls (a declaration is recorded only after its
    /// body has been visited), forward references, and calls through local
    /// bindings or parameters.
    ///
    /// # Panics
    ///
    /// Panics if a type variable in the callee has no entry in `type_args`,
    /// and (via `todo!()`) if the callee is not a plain identifier.
    fn post_visit_call(
        &mut self,
        fun: &mut Expr<'ast, ast::Type<'ast>>,
        type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
        _args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
    ) -> Result<(), Self::Error> {
        let new_fun = match fun {
            Expr::Ident(id, _) => {
                let decl: Decl<_> = match self.decls.get(id) {
                    Some(decl) => (**decl).clone(),
                    // Unknown callee: nothing to specialize, keep the call
                    // (and its type arguments) exactly as it is.
                    None => return Ok(()),
                };
                // `traverse_type_vars` needs a `Clone` closure, so the name is
                // accumulated through a `RefCell` captured by shared reference.
                let name = RefCell::new(id.to_string());
                let type_args = mem::take(type_args);
                let mut monomorphized = decl
                    .traverse_type(|ty| -> Result<_, Void> {
                        Ok(ty.clone().traverse_type_vars(|v| {
                            let concrete = type_args.get(&v).unwrap();
                            name.borrow_mut().push_str(&concrete.to_string());
                            concrete.clone()
                        }))
                    })
                    .void_unwrap();
                let name: Ident = name.into_inner().try_into().unwrap();
                if name != *id {
                    self.remove_decls.insert(id.clone());
                    monomorphized.set_name(name.clone());
                    let type_ = monomorphized.type_().unwrap().clone();
                    // Several call sites can request the same instantiation;
                    // emit each specialized declaration only once.
                    let already_emitted = self
                        .extra_decls
                        .iter()
                        .any(|existing| existing.name() == &name);
                    if !already_emitted {
                        self.extra_decls.push(monomorphized);
                    }
                    Some(Expr::Ident(name, type_))
                } else {
                    // No type variable was substituted, so the callee is not
                    // generic and the call can stay as it is.
                    None
                }
            }
            _ => todo!(),
        };
        if let Some(new_fun) = new_fun {
            *fun = new_fun;
        }
        Ok(())
    }

    /// Records a fully visited declaration under its name so that later calls
    /// to it can be specialized.
    fn post_visit_decl(
        &mut self,
        decl: &'a Decl<'ast, ast::Type<'ast>>,
    ) -> Result<(), Self::Error> {
        self.decls.insert(decl.name(), decl);
        Ok(())
    }
}
/// Runs the monomorphization pass over a whole program, in place.
///
/// Every declaration is visited in order. Afterwards `toplevel` holds the
/// specialized declarations generated by the pass, followed by the original
/// declarations that were not scheduled for removal, in their original order.
pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
    let mut monomorphizer = Monomorphize::new();
    for decl in toplevel.iter_mut() {
        monomorphizer.visit_decl(decl).void_unwrap();
    }

    // Pull the results out of the pass before touching `toplevel` again; the
    // pass is not used past this point.
    let replaced = mem::take(&mut monomorphizer.remove_decls);
    let mut rewritten = mem::take(&mut monomorphizer.extra_decls);

    rewritten.extend(
        toplevel
            .drain(..)
            .filter(|decl| !replaced.contains(decl.name())),
    );
    *toplevel = rewritten;
}
// Tests for the monomorphization pass, driven through the real parser and
// typechecker.
#[cfg(test)]
mod tests {
use std::convert::TryFrom;
use super::*;
use crate::parser::toplevel;
use crate::tc::typecheck_toplevel;
// A polymorphic `id` called with an int from `main`: after the pass, the
// function that `main` calls must be a declaration of type `fn int -> int`.
#[test]
fn call_id_decl() {
let (_, program) = toplevel(
"ty id : fn a -> a
fn id x = x
ty main : fn -> int
fn main = id 0",
)
.unwrap();
let mut program = typecheck_toplevel(program).unwrap();
run_toplevel(&mut program);
// Looks up a function declaration by name in the rewritten program,
// panicking if there is none.
let find_decl = |ident: &str| {
program.iter().find(|decl| {
matches!(decl, Decl::Fun {name, ..} if name == &Ident::try_from(ident).unwrap())
}).unwrap()
};
let main = find_decl("main");
let body = match main {
Decl::Fun { body, .. } => body,
_ => unreachable!(),
};
// The type `id` is expected to have once `a` is instantiated to `int`.
let expected_type = ast::Type::Function(ast::FunctionType {
args: vec![ast::Type::Int],
ret: Box::new(ast::Type::Int),
});
match &**body {
Expr::Call { fun, .. } => {
let fun = match &**fun {
Expr::Ident(fun, _) => fun,
_ => unreachable!(),
};
// Resolve whatever name the call now refers to and check that
// declaration's type.
let called_decl = find_decl(fun.into());
assert_eq!(called_decl.type_().unwrap(), &expected_type);
}
_ => unreachable!(),
}
}
}

View file

@ -0,0 +1 @@
pub(crate) mod hir;

View file

@ -266,6 +266,7 @@ impl<'ast> Typechecker<'ast> {
args: args.iter().map(|(_, ty)| ty.clone()).collect(),
ret: Box::new(body.type_().clone()),
},
type_args: vec![], // TODO fill in once we do let generalization
args,
body: Box::new(body),
})
@ -289,9 +290,10 @@ impl<'ast> Typechecker<'ast> {
Ok(arg)
})
.try_collect()?;
self.commit_instantiations();
let type_args = self.commit_instantiations();
Ok(hir::Expr::Call {
fun: Box::new(fun),
type_args,
args,
type_: ret_ty,
})
@ -325,8 +327,14 @@ impl<'ast> Typechecker<'ast> {
self.env.set(name.clone(), type_);
self.env.pop();
match body {
hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun {
hir::Expr::Fun {
type_args,
args,
body,
type_,
} => Ok(Some(hir::Decl::Fun {
name,
type_args,
args,
body,
type_,
@ -538,17 +546,21 @@ impl<'ast> Typechecker<'ast> {
})
}
fn commit_instantiations(&mut self) {
fn commit_instantiations(&mut self) -> HashMap<Ident<'ast>, Type> {
let mut res = HashMap::new();
let mut ctx = mem::take(&mut self.ctx);
for (_, v) in ctx.iter_mut() {
if let Type::Univ(tv) = v {
if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) {
let tv_name = self.name_univ(*tv);
if let Some(concrete) = self.instantiations.resolve(&tv_name) {
res.insert(tv_name, concrete.clone());
*v = concrete.clone();
}
}
}
self.ctx = ctx;
self.instantiations.pop();
res
}
fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool {

View file

@ -14,12 +14,11 @@ const FIXTURES: &[Fixture] = &[
exit_code: 5,
expected_output: "",
},
// TODO(grfn): needs monomorphization
// Fixture {
// name: "functions",
// exit_code: 9,
// expected_output: "",
// },
Fixture {
name: "functions",
exit_code: 9,
expected_output: "",
},
Fixture {
name: "externs",
exit_code: 0,