feat(achilles): Implement a Unit type

Add support for a zero-sized Unit type. This requires some special
handling at the codegen level because LLVM (unsurprisingly) only allows
Void types in function return position. To make that a little easier to
handle, there's a new pass that strips any unit-only expressions and
pulls unit-only function arguments up to new `let` bindings, so we never
have to actually pass around unit values.

Change-Id: I0fc18a516821f2d69172c42a6a5d246b23471e38
Reviewed-on: https://cl.tvl.fyi/c/depot/+/2695
Reviewed-by: glittershark <grfn@gws.fyi>
Tested-by: BuildkiteCI
This commit is contained in:
Griffin Smith 2021-03-28 13:28:49 -04:00 committed by glittershark
parent db62866d82
commit 8e13b1303a
16 changed files with 447 additions and 88 deletions

View file

@ -16,6 +16,7 @@ dependencies = [
"nom", "nom",
"nom-trace", "nom-trace",
"pratt", "pratt",
"pretty_assertions",
"proptest", "proptest",
"test-strategy", "test-strategy",
"thiserror", "thiserror",
@ -31,6 +32,15 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "ansi_term"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.38" version = "1.0.38"
@ -155,6 +165,16 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59c6fe4622b269032d2c5140a592d67a9c409031d286174fcde172fbed86f0d3" checksum = "59c6fe4622b269032d2c5140a592d67a9c409031d286174fcde172fbed86f0d3"
[[package]]
name = "ctor"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19"
dependencies = [
"quote",
"syn",
]
[[package]] [[package]]
name = "derive_more" name = "derive_more"
version = "0.99.11" version = "0.99.11"
@ -166,6 +186,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "diff"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499"
[[package]] [[package]]
name = "either" name = "either"
version = "1.6.1" version = "1.6.1"
@ -366,6 +392,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85" checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85"
[[package]]
name = "output_vt100"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.11.1" version = "0.11.1"
@ -412,6 +447,18 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e31bbc12f7936a7b195790dd6d9b982b66c54f45ff6766decf25c44cac302dce" checksum = "e31bbc12f7936a7b195790dd6d9b982b66c54f45ff6766decf25c44cac302dce"
[[package]]
name = "pretty_assertions"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f297542c27a7df8d45de2b0e620308ab883ad232d06c14b76ac3e144bda50184"
dependencies = [
"ansi_term",
"ctor",
"diff",
"output_vt100",
]
[[package]] [[package]]
name = "proc-macro-error" name = "proc-macro-error"
version = "1.0.4" version = "1.0.4"

View file

@ -23,3 +23,4 @@ void = "1.0.2"
[dev-dependencies] [dev-dependencies]
crate-root = "0.1.3" crate-root = "0.1.3"
pretty_assertions = "0.7.1"

View file

@ -4,3 +4,4 @@
functions functions
simple simple
externs externs
units

View file

@ -0,0 +1,7 @@
extern puts : fn cstring -> int
ty print : fn cstring -> ()
fn print x = let _ = puts x in ()
ty main : fn -> int
fn main = let _ = print "hi" in 0

View file

@ -246,7 +246,7 @@ impl<'a, T> Expr<'a, T> {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decl<'a, T> { pub enum Decl<'a, T> {
Fun { Fun {
name: Ident<'a>, name: Ident<'a>,

View file

@ -30,6 +30,7 @@ impl<'a> Ident<'a> {
Ident(Cow::Owned(self.0.clone().into_owned())) Ident(Cow::Owned(self.0.clone().into_owned()))
} }
/// Construct an identifier from a &str without checking that it's a valid identifier
pub fn from_str_unchecked(s: &'a str) -> Self { pub fn from_str_unchecked(s: &'a str) -> Self {
debug_assert!(is_valid_identifier(s)); debug_assert!(is_valid_identifier(s));
Self(Cow::Borrowed(s)) Self(Cow::Borrowed(s))
@ -109,6 +110,7 @@ pub enum UnaryOperator {
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
pub enum Literal<'a> { pub enum Literal<'a> {
Unit,
Int(u64), Int(u64),
Bool(bool), Bool(bool),
String(Cow<'a, str>), String(Cow<'a, str>),
@ -120,6 +122,7 @@ impl<'a> Literal<'a> {
Literal::Int(i) => Literal::Int(*i), Literal::Int(i) => Literal::Int(*i),
Literal::Bool(b) => Literal::Bool(*b), Literal::Bool(b) => Literal::Bool(*b),
Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())), Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())),
Literal::Unit => Literal::Unit,
} }
} }
} }
@ -308,6 +311,7 @@ pub enum Type<'a> {
Float, Float,
Bool, Bool,
CString, CString,
Unit,
Var(Ident<'a>), Var(Ident<'a>),
Function(FunctionType<'a>), Function(FunctionType<'a>),
} }
@ -319,6 +323,7 @@ impl<'a> Type<'a> {
Type::Float => Type::Float, Type::Float => Type::Float,
Type::Bool => Type::Bool, Type::Bool => Type::Bool,
Type::CString => Type::CString, Type::CString => Type::CString,
Type::Unit => Type::Unit,
Type::Var(v) => Type::Var(v.to_owned()), Type::Var(v) => Type::Var(v.to_owned()),
Type::Function(f) => Type::Function(f.to_owned()), Type::Function(f) => Type::Function(f.to_owned()),
} }
@ -374,6 +379,7 @@ impl<'a> Type<'a> {
Type::Float => Type::Float, Type::Float => Type::Float,
Type::Bool => Type::Bool, Type::Bool => Type::Bool,
Type::CString => Type::CString, Type::CString => Type::CString,
Type::Unit => Type::Unit,
} }
} }
} }
@ -385,6 +391,7 @@ impl<'a> Display for Type<'a> {
Type::Float => f.write_str("float"), Type::Float => f.write_str("float"),
Type::Bool => f.write_str("bool"), Type::Bool => f.write_str("bool"),
Type::CString => f.write_str("cstring"), Type::CString => f.write_str("cstring"),
Type::Unit => f.write_str("()"),
Type::Var(v) => v.fmt(f), Type::Var(v) => v.fmt(f),
Type::Function(ft) => ft.fmt(f), Type::Function(ft) => ft.fmt(f),
} }

View file

@ -68,8 +68,12 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.function_stack.last().unwrap() self.function_stack.last().unwrap()
} }
pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> { pub fn finish_function(&mut self, res: Option<&BasicValueEnum<'ctx>>) -> FunctionValue<'ctx> {
self.builder.build_return(Some(res)); self.builder.build_return(match res {
// lol
Some(val) => Some(val),
None => None,
});
self.function_stack.pop().unwrap() self.function_stack.pop().unwrap()
} }
@ -78,79 +82,92 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
.append_basic_block(*self.function_stack.last().unwrap(), name) .append_basic_block(*self.function_stack.last().unwrap(), name)
} }
pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> { pub fn codegen_expr(
&mut self,
expr: &'ast Expr<'ast, Type>,
) -> Result<Option<AnyValueEnum<'ctx>>> {
match expr { match expr {
Expr::Ident(id, _) => self Expr::Ident(id, _) => self
.env .env
.resolve(id) .resolve(id)
.cloned() .cloned()
.ok_or_else(|| Error::UndefinedVariable(id.to_owned())), .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))
.map(Some),
Expr::Literal(lit, ty) => { Expr::Literal(lit, ty) => {
let ty = self.codegen_int_type(ty); let ty = self.codegen_int_type(ty);
match lit { match lit {
Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))), Literal::Int(i) => Ok(Some(AnyValueEnum::IntValue(ty.const_int(*i, false)))),
Literal::Bool(b) => Ok(AnyValueEnum::IntValue( Literal::Bool(b) => Ok(Some(AnyValueEnum::IntValue(
ty.const_int(if *b { 1 } else { 0 }, false), ty.const_int(if *b { 1 } else { 0 }, false),
))),
Literal::String(s) => Ok(Some(
self.builder
.build_global_string_ptr(s, "s")
.as_pointer_value()
.into(),
)), )),
Literal::String(s) => Ok(self Literal::Unit => Ok(None),
.builder
.build_global_string_ptr(s, "s")
.as_pointer_value()
.into()),
} }
} }
Expr::UnaryOp { op, rhs, .. } => { Expr::UnaryOp { op, rhs, .. } => {
let rhs = self.codegen_expr(rhs)?; let rhs = self.codegen_expr(rhs)?.unwrap();
match op { match op {
UnaryOperator::Not => unimplemented!(), UnaryOperator::Not => unimplemented!(),
UnaryOperator::Neg => Ok(AnyValueEnum::IntValue( UnaryOperator::Neg => Ok(Some(AnyValueEnum::IntValue(
self.builder.build_int_neg(rhs.into_int_value(), "neg"), self.builder.build_int_neg(rhs.into_int_value(), "neg"),
)), ))),
} }
} }
Expr::BinaryOp { lhs, op, rhs, .. } => { Expr::BinaryOp { lhs, op, rhs, .. } => {
let lhs = self.codegen_expr(lhs)?; let lhs = self.codegen_expr(lhs)?.unwrap();
let rhs = self.codegen_expr(rhs)?; let rhs = self.codegen_expr(rhs)?.unwrap();
match op { match op {
BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add( BinaryOperator::Add => {
lhs.into_int_value(), Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_add(
rhs.into_int_value(),
"add",
))),
BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
lhs.into_int_value(),
rhs.into_int_value(),
"add",
))),
BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
lhs.into_int_value(),
rhs.into_int_value(),
"add",
))),
BinaryOperator::Div => {
Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div(
lhs.into_int_value(), lhs.into_int_value(),
rhs.into_int_value(), rhs.into_int_value(),
"add", "add",
))) ))))
} }
BinaryOperator::Sub => {
Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
lhs.into_int_value(),
rhs.into_int_value(),
"add",
))))
}
BinaryOperator::Mul => {
Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
lhs.into_int_value(),
rhs.into_int_value(),
"add",
))))
}
BinaryOperator::Div => Ok(Some(AnyValueEnum::IntValue(
self.builder.build_int_signed_div(
lhs.into_int_value(),
rhs.into_int_value(),
"add",
),
))),
BinaryOperator::Pow => unimplemented!(), BinaryOperator::Pow => unimplemented!(),
BinaryOperator::Equ => { BinaryOperator::Equ => Ok(Some(AnyValueEnum::IntValue(
Ok(AnyValueEnum::IntValue(self.builder.build_int_compare( self.builder.build_int_compare(
IntPredicate::EQ, IntPredicate::EQ,
lhs.into_int_value(), lhs.into_int_value(),
rhs.into_int_value(), rhs.into_int_value(),
"eq", "eq",
))) ),
} ))),
BinaryOperator::Neq => todo!(), BinaryOperator::Neq => todo!(),
} }
} }
Expr::Let { bindings, body, .. } => { Expr::Let { bindings, body, .. } => {
self.env.push(); self.env.push();
for Binding { ident, body, .. } in bindings { for Binding { ident, body, .. } in bindings {
let val = self.codegen_expr(body)?; if let Some(val) = self.codegen_expr(body)? {
self.env.set(ident, val); self.env.set(ident, val);
}
} }
let res = self.codegen_expr(body); let res = self.codegen_expr(body);
self.env.pop(); self.env.pop();
@ -165,7 +182,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
let then_block = self.append_basic_block("then"); let then_block = self.append_basic_block("then");
let else_block = self.append_basic_block("else"); let else_block = self.append_basic_block("else");
let join_block = self.append_basic_block("join"); let join_block = self.append_basic_block("join");
let condition = self.codegen_expr(condition)?; let condition = self.codegen_expr(condition)?.unwrap();
self.builder.build_conditional_branch( self.builder.build_conditional_branch(
condition.into_int_value(), condition.into_int_value(),
then_block, then_block,
@ -180,12 +197,22 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.builder.build_unconditional_branch(join_block); self.builder.build_unconditional_branch(join_block);
self.builder.position_at_end(join_block); self.builder.position_at_end(join_block);
let phi = self.builder.build_phi(self.codegen_type(type_), "join"); if let Some(phi_type) = self.codegen_type(type_) {
phi.add_incoming(&[ let phi = self.builder.build_phi(phi_type, "join");
(&BasicValueEnum::try_from(then_res).unwrap(), then_block), phi.add_incoming(&[
(&BasicValueEnum::try_from(else_res).unwrap(), else_block), (
]); &BasicValueEnum::try_from(then_res.unwrap()).unwrap(),
Ok(phi.as_basic_value().into()) then_block,
),
(
&BasicValueEnum::try_from(else_res.unwrap()).unwrap(),
else_block,
),
]);
Ok(Some(phi.as_basic_value().into()))
} else {
Ok(None)
}
} }
Expr::Call { fun, args, .. } => { Expr::Call { fun, args, .. } => {
if let Expr::Ident(id, _) = &**fun { if let Expr::Ident(id, _) = &**fun {
@ -196,15 +223,14 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
.ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?;
let args = args let args = args
.iter() .iter()
.map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap())) .map(|arg| Ok(self.codegen_expr(arg)?.unwrap().try_into().unwrap()))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(self Ok(self
.builder .builder
.build_call(function, &args, "call") .build_call(function, &args, "call")
.try_as_basic_value() .try_as_basic_value()
.left() .left()
.unwrap() .map(|val| val.into()))
.into())
} else { } else {
todo!() todo!()
} }
@ -216,7 +242,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
let function = self.codegen_function(&fname, args, body)?; let function = self.codegen_function(&fname, args, body)?;
self.builder.position_at_end(cur_block); self.builder.position_at_end(cur_block);
self.env.restore(env); self.env.restore(env);
Ok(function.into()) Ok(Some(function.into()))
} }
} }
} }
@ -227,15 +253,17 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
args: &'ast [(Ident<'ast>, Type)], args: &'ast [(Ident<'ast>, Type)],
body: &'ast Expr<'ast, Type>, body: &'ast Expr<'ast, Type>,
) -> Result<FunctionValue<'ctx>> { ) -> Result<FunctionValue<'ctx>> {
let arg_types = args
.iter()
.filter_map(|(_, at)| self.codegen_type(at))
.collect::<Vec<_>>();
self.new_function( self.new_function(
name, name,
self.codegen_type(body.type_()).fn_type( match self.codegen_type(body.type_()) {
args.iter() Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
.map(|(_, at)| self.codegen_type(at)) None => self.context.void_type().fn_type(&arg_types, false),
.collect::<Vec<_>>() },
.as_slice(),
false,
),
); );
self.env.push(); self.env.push();
for (i, (arg, _)) in args.iter().enumerate() { for (i, (arg, _)) in args.iter().enumerate() {
@ -244,9 +272,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
self.cur_function().get_nth_param(i as u32).unwrap().into(), self.cur_function().get_nth_param(i as u32).unwrap().into(),
); );
} }
let res = self.codegen_expr(body)?.try_into().unwrap(); let res = self.codegen_expr(body)?;
self.env.pop(); self.env.pop();
Ok(self.finish_function(&res)) Ok(self.finish_function(res.map(|av| av.try_into().unwrap()).as_ref()))
} }
pub fn codegen_extern( pub fn codegen_extern(
@ -255,15 +283,16 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
args: &'ast [Type], args: &'ast [Type],
ret: &'ast Type, ret: &'ast Type,
) -> Result<()> { ) -> Result<()> {
let arg_types = args
.iter()
.map(|t| self.codegen_type(t).unwrap())
.collect::<Vec<_>>();
self.module.add_function( self.module.add_function(
name, name,
self.codegen_type(ret).fn_type( match self.codegen_type(ret) {
&args Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
.iter() None => self.context.void_type().fn_type(&arg_types, false),
.map(|t| self.codegen_type(t)) },
.collect::<Vec<_>>(),
false,
),
None, None,
); );
Ok(()) Ok(())
@ -287,29 +316,31 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> { pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
self.new_function("main", self.context.i64_type().fn_type(&[], false)); self.new_function("main", self.context.i64_type().fn_type(&[], false));
let res = self.codegen_expr(expr)?.try_into().unwrap(); let res = self.codegen_expr(expr)?;
if *expr.type_() != Type::Int { if *expr.type_() != Type::Int {
self.builder self.builder
.build_return(Some(&self.context.i64_type().const_int(0, false))); .build_return(Some(&self.context.i64_type().const_int(0, false)));
} else { } else {
self.finish_function(&res); self.finish_function(res.map(|r| r.try_into().unwrap()).as_ref());
} }
Ok(()) Ok(())
} }
fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> { fn codegen_type(&self, type_: &'ast Type) -> Option<BasicTypeEnum<'ctx>> {
// TODO // TODO
match type_ { match type_ {
Type::Int => self.context.i64_type().into(), Type::Int => Some(self.context.i64_type().into()),
Type::Float => self.context.f64_type().into(), Type::Float => Some(self.context.f64_type().into()),
Type::Bool => self.context.bool_type().into(), Type::Bool => Some(self.context.bool_type().into()),
Type::CString => self Type::CString => Some(
.context self.context
.i8_type() .i8_type()
.ptr_type(AddressSpace::Generic) .ptr_type(AddressSpace::Generic)
.into(), .into(),
),
Type::Function(_) => todo!(), Type::Function(_) => todo!(),
Type::Var(_) => unreachable!(), Type::Var(_) => unreachable!(),
Type::Unit => None,
} }
} }

View file

@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
use crate::codegen::{self, Codegen}; use crate::codegen::{self, Codegen};
use crate::common::Result; use crate::common::Result;
use crate::passes::hir::monomorphize; use crate::passes::hir::{monomorphize, strip_positive_units};
use crate::{parser, tc}; use crate::{parser, tc};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
@ -55,9 +55,10 @@ pub struct CompilerOptions {
pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
let src = fs::read_to_string(input)?; let src = fs::read_to_string(input)?;
let (_, decls) = parser::toplevel(&src)?; // TODO: statements let (_, decls) = parser::toplevel(&src)?;
let mut decls = tc::typecheck_toplevel(decls)?; let mut decls = tc::typecheck_toplevel(decls)?;
monomorphize::run_toplevel(&mut decls); monomorphize::run_toplevel(&mut decls);
strip_positive_units::run_toplevel(&mut decls);
let context = codegen::Context::create(); let context = codegen::Context::create();
let mut codegen = Codegen::new( let mut codegen = Codegen::new(

View file

@ -30,6 +30,7 @@ impl<'a> Interpreter<'a> {
Expr::Literal(Literal::Int(i), _) => Ok((*i).into()), Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()), Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()), Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()),
Expr::Literal(Literal::Unit, _) => unreachable!(),
Expr::UnaryOp { op, rhs, .. } => { Expr::UnaryOp { op, rhs, .. } => {
let rhs = self.eval(rhs)?; let rhs = self.eval(rhs)?;
match op { match op {

View file

@ -186,7 +186,9 @@ named!(string(&str) -> Literal, preceded!(
) )
)); ));
named!(literal(&str) -> Literal, alt!(int | bool_ | string)); named!(unit(&str) -> Literal, map!(complete!(tag!("()")), |_| Literal::Unit));
named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit));
named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
@ -270,7 +272,6 @@ named!(funcref(&str) -> Expr, alt!(
named!(no_arg_call(&str) -> Expr, do_parse!( named!(no_arg_call(&str) -> Expr, do_parse!(
fun: funcref fun: funcref
>> multispace0
>> complete!(tag!("()")) >> complete!(tag!("()"))
>> (Expr::Call { >> (Expr::Call {
fun: Box::new(fun), fun: Box::new(fun),
@ -431,6 +432,11 @@ pub(crate) mod tests {
} }
} }
#[test]
fn unit() {
assert_eq!(test_parse!(expr, "()"), Expr::Literal(Literal::Unit));
}
#[test] #[test]
fn bools() { fn bools() {
assert_eq!( assert_eq!(
@ -515,6 +521,18 @@ pub(crate) mod tests {
); );
} }
#[test]
fn unit_call() {
let res = test_parse!(expr, "f ()");
assert_eq!(
res,
Expr::Call {
fun: ident_expr("f"),
args: vec![Expr::Literal(Literal::Unit)]
}
)
}
#[test] #[test]
fn call_with_args() { fn call_with_args() {
let res = test_parse!(expr, "f x 1"); let res = test_parse!(expr, "f x 1");

View file

@ -1,9 +1,9 @@
use nom::character::complete::{multispace0, multispace1}; use nom::character::complete::{multispace0, multispace1};
use nom::error::{ErrorKind, ParseError}; use nom::error::{ErrorKind, ParseError};
use nom::{alt, char, complete, do_parse, many0, named, separated_list0, tag, terminated}; use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag, terminated};
#[macro_use] #[macro_use]
mod macros; pub(crate) mod macros;
mod expr; mod expr;
mod type_; mod type_;
@ -136,7 +136,11 @@ named!(pub decl(&str) -> Decl, alt!(
extern_decl extern_decl
)); ));
named!(pub toplevel(&str) -> Vec<Decl>, terminated!(many0!(decl), multispace0)); named!(pub toplevel(&str) -> Vec<Decl>, do_parse!(
decls: many0!(decl)
>> multispace0
>> eof!()
>> (decls)));
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -215,4 +219,21 @@ mod tests {
}] }]
) )
} }
#[test]
fn return_unit() {
assert_eq!(
test_parse!(decl, "fn g _ = ()"),
Decl::Fun {
name: "g".try_into().unwrap(),
body: Fun {
args: vec![Arg {
ident: "_".try_into().unwrap(),
type_: None,
}],
body: Expr::Literal(Literal::Unit),
},
}
)
}
} }

View file

@ -29,6 +29,7 @@ named!(pub type_(&str) -> Type, alt!(
tag!("float") => { |_| Type::Float } | tag!("float") => { |_| Type::Float } |
tag!("bool") => { |_| Type::Bool } | tag!("bool") => { |_| Type::Bool } |
tag!("cstring") => { |_| Type::CString } | tag!("cstring") => { |_| Type::CString } |
tag!("()") => { |_| Type::Unit } |
function_type => { |ft| Type::Function(ft) }| function_type => { |ft| Type::Function(ft) }|
ident => { |id| Type::Var(id) } | ident => { |id| Type::Var(id) } |
delimited!( delimited!(
@ -51,6 +52,7 @@ mod tests {
assert_eq!(test_parse!(type_, "float"), Type::Float); assert_eq!(test_parse!(type_, "float"), Type::Float);
assert_eq!(test_parse!(type_, "bool"), Type::Bool); assert_eq!(test_parse!(type_, "bool"), Type::Bool);
assert_eq!(test_parse!(type_, "cstring"), Type::CString); assert_eq!(test_parse!(type_, "cstring"), Type::CString);
assert_eq!(test_parse!(type_, "()"), Type::Unit);
} }
#[test] #[test]

View file

@ -4,6 +4,7 @@ use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
pub(crate) mod monomorphize; pub(crate) mod monomorphize;
pub(crate) mod strip_positive_units;
pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
type Error; type Error;
@ -53,7 +54,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
Ok(()) Ok(())
} }
fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
Ok(())
}
fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
self.pre_visit_expr(expr)?;
match expr { match expr {
Expr::Ident(id, t) => { Expr::Ident(id, t) => {
self.visit_ident(id)?; self.visit_ident(id)?;
@ -140,6 +146,17 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
Ok(()) Ok(())
} }
fn post_visit_fun_decl(
&mut self,
_name: &mut Ident<'ast>,
_type_args: &mut Vec<Ident>,
_args: &mut Vec<(Ident, T)>,
_body: &mut Box<Expr<T>>,
_type_: &mut T,
) -> Result<(), Self::Error> {
Ok(())
}
fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> { fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
match decl { match decl {
Decl::Fun { Decl::Fun {
@ -150,15 +167,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
type_, type_,
} => { } => {
self.visit_ident(name)?; self.visit_ident(name)?;
for type_arg in type_args { for type_arg in type_args.iter_mut() {
self.visit_ident(type_arg)?; self.visit_ident(type_arg)?;
} }
for (arg, t) in args { for (arg, t) in args.iter_mut() {
self.visit_ident(arg)?; self.visit_ident(arg)?;
self.visit_type(t)?; self.visit_type(t)?;
} }
self.visit_expr(body)?; self.visit_expr(body)?;
self.visit_type(type_)?; self.visit_type(type_)?;
self.post_visit_fun_decl(name, type_args, args, body, type_)?;
} }
Decl::Extern { Decl::Extern {
name, name,

View file

@ -0,0 +1,189 @@
use std::collections::HashMap;
use std::mem;
use ast::hir::Binding;
use ast::Literal;
use void::{ResultVoidExt, Void};
use crate::ast::hir::{Decl, Expr};
use crate::ast::{self, Ident};
use super::Visitor;
/// Strip all values with a unit type in positive (non-return) position
///
/// Concretely, the visitor implementation for this pass removes unit-typed
/// arguments from calls, unit arguments from function types, and unit-typed
/// parameters from function declarations. Return types are not modified.
pub(crate) struct StripPositiveUnits {}
impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
    type Error = Void;

    /// Removes unit-typed arguments from a call expression before it is visited.
    ///
    /// Unit literals are simply dropped. Any other unit-typed argument
    /// expression is preserved by rewriting the call into a `let` whose
    /// bindings (all named `___discarded`) hold the removed expressions in
    /// their original relative order, and whose body is the trimmed call.
    /// Expressions that are not calls are left untouched.
    fn pre_visit_expr(
        &mut self,
        expr: &mut Expr<'ast, ast::Type<'ast>>,
    ) -> Result<(), Self::Error> {
        let hoisted: Vec<_> = if let Expr::Call { args, .. } = expr {
            // Split the arguments into unit-typed and everything else,
            // keeping the relative order within each group.
            let (unit_args, kept_args): (Vec<_>, Vec<_>) = mem::take(args)
                .into_iter()
                .partition(|arg| is_unit(arg.type_()));
            *args = kept_args;

            // A literal `()` carries no computation, so only non-literal
            // unit expressions need to survive as bindings.
            unit_args
                .into_iter()
                .filter(|arg| !matches!(arg, Expr::Literal(Literal::Unit, _)))
                .collect()
        } else {
            Vec::new()
        };

        if hoisted.is_empty() {
            return Ok(());
        }

        // NOTE(review): the hoisted expressions are now bound ahead of the
        // call, i.e. before its remaining arguments - confirm this ordering
        // is acceptable for side-effecting arguments.
        let call = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit));
        let call_type = call.type_().clone();
        let bindings = hoisted
            .into_iter()
            .map(|unit_expr| Binding {
                ident: Ident::from_str_unchecked("___discarded"),
                type_: unit_expr.type_().clone(),
                body: unit_expr,
            })
            .collect();
        *expr = Expr::Let {
            bindings,
            type_: call_type,
            body: Box::new(call),
        };
        Ok(())
    }

    /// Drops every unit-typed argument still present in a call's argument list.
    fn post_visit_call(
        &mut self,
        _fun: &mut Expr<'ast, ast::Type<'ast>>,
        _type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
    ) -> Result<(), Self::Error> {
        args.retain(|arg| !is_unit(arg.type_()));
        Ok(())
    }

    /// Removes unit arguments from a function type.
    ///
    /// Only the top-level argument list of a `Function` type is filtered;
    /// every other type is left as-is.
    fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> {
        match type_ {
            ast::Type::Function(ft) => ft.args.retain(|arg_type| !is_unit(arg_type)),
            _ => {}
        }
        Ok(())
    }

    /// Removes unit-typed parameters from a function declaration.
    fn post_visit_fun_decl(
        &mut self,
        _name: &mut Ident<'ast>,
        _type_args: &mut Vec<Ident>,
        args: &mut Vec<(Ident, ast::Type<'ast>)>,
        _body: &mut Box<Expr<ast::Type<'ast>>>,
        _type_: &mut ast::Type<'ast>,
    ) -> Result<(), Self::Error> {
        args.retain(|(_, param_type)| !is_unit(param_type));
        Ok(())
    }
}

/// Returns `true` when `ty` is the unit type.
fn is_unit(ty: &ast::Type<'_>) -> bool {
    ty == &ast::Type::Unit
}
/// Runs the [`StripPositiveUnits`] pass over every top-level declaration,
/// rewriting each one in place.
///
/// The pass's error type is `Void`, so the per-declaration result is
/// unwrapped with `void_unwrap`.
pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
    let mut stripper = StripPositiveUnits {};
    for decl in toplevel {
        stripper.visit_decl(decl).void_unwrap();
    }
}
// Tests for the pass: each one parses and typechecks an input program and a
// hand-written expected program, runs the pass over the input only, and
// asserts that the two typechecked programs are equal.
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::toplevel;
use crate::tc::typecheck_toplevel;
use pretty_assertions::assert_eq;
// A function whose only argument is unit, called with a literal `()`:
// expected to match the same program written with no argument at all.
#[test]
fn unit_only_arg() {
let (_, program) = toplevel(
"ty f : fn () -> int
fn f _ = 1
ty main : fn -> int
fn main = f ()",
)
.unwrap();
let (_, expected) = toplevel(
"ty f : fn -> int
fn f = 1
ty main : fn -> int
fn main = f()",
)
.unwrap();
let expected = typecheck_toplevel(expected).unwrap();
let mut program = typecheck_toplevel(program).unwrap();
run_toplevel(&mut program);
assert_eq!(program, expected);
}
// A unit argument alongside an int argument: only the unit argument is
// expected to disappear from the type, the declaration, and the call.
#[test]
fn unit_and_other_arg() {
let (_, program) = toplevel(
"ty f : fn (), int -> int
fn f _ x = x
ty main : fn -> int
fn main = f () 1",
)
.unwrap();
let (_, expected) = toplevel(
"ty f : fn int -> int
fn f x = x
ty main : fn -> int
fn main = f 1",
)
.unwrap();
let expected = typecheck_toplevel(expected).unwrap();
let mut program = typecheck_toplevel(program).unwrap();
run_toplevel(&mut program);
assert_eq!(program, expected);
}
// A unit-typed argument that is a call (`g 2`) rather than a literal: it is
// expected to be hoisted into a `let ___discarded = ...` binding around the
// trimmed call.
#[test]
fn unit_expr_and_other_arg() {
let (_, program) = toplevel(
"ty f : fn (), int -> int
fn f _ x = x
ty g : fn int -> ()
fn g _ = ()
ty main : fn -> int
fn main = f (g 2) 1",
)
.unwrap();
let (_, expected) = toplevel(
"ty f : fn int -> int
fn f x = x
ty g : fn int -> ()
fn g _ = ()
ty main : fn -> int
fn main = let ___discarded = g 2 in f 1",
)
.unwrap();
// The expected source has six declarations (three `ty`, three `fn`);
// presumably this guards against the parser stopping early - confirm.
assert_eq!(expected.len(), 6);
let expected = typecheck_toplevel(expected).unwrap();
let mut program = typecheck_toplevel(program).unwrap();
run_toplevel(&mut program);
assert_eq!(program, expected);
}
}

View file

@ -85,6 +85,7 @@ pub enum Type {
Exist(TyVar), Exist(TyVar),
Nullary(NullaryType), Nullary(NullaryType),
Prim(PrimType), Prim(PrimType),
Unit,
Fun { Fun {
args: Vec<Type>, args: Vec<Type>,
ret: Box<Type>, ret: Box<Type>,
@ -96,6 +97,7 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
fn try_from(value: Type) -> result::Result<Self, Self::Error> { fn try_from(value: Type) -> result::Result<Self, Self::Error> {
match value { match value {
Type::Unit => Ok(ast::Type::Unit),
Type::Univ(_) => todo!(), Type::Univ(_) => todo!(),
Type::Exist(_) => Err(value), Type::Exist(_) => Err(value),
Type::Nullary(_) => todo!(), Type::Nullary(_) => todo!(),
@ -126,6 +128,7 @@ impl Display for Type {
Type::Univ(TyVar(n)) => write!(f, "∀{}", n), Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
Type::Exist(TyVar(n)) => write!(f, "∃{}", n), Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
Type::Unit => write!(f, "()"),
} }
} }
} }
@ -171,6 +174,7 @@ impl<'ast> Typechecker<'ast> {
Literal::Int(_) => Type::Prim(PrimType::Int), Literal::Int(_) => Type::Prim(PrimType::Int),
Literal::Bool(_) => Type::Prim(PrimType::Bool), Literal::Bool(_) => Type::Prim(PrimType::Bool),
Literal::String(_) => Type::Prim(PrimType::CString), Literal::String(_) => Type::Prim(PrimType::CString),
Literal::Unit => Type::Unit,
}; };
Ok(hir::Expr::Literal(lit.to_owned(), type_)) Ok(hir::Expr::Literal(lit.to_owned(), type_))
} }
@ -377,6 +381,7 @@ impl<'ast> Typechecker<'ast> {
fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
match (ty1, ty2) { match (ty1, ty2) {
(Type::Unit, Type::Unit) => Ok(Type::Unit),
(Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
Some(var @ ast::Type::Var(_)) => { Some(var @ ast::Type::Var(_)) => {
@ -466,6 +471,7 @@ impl<'ast> Typechecker<'ast> {
let ret = match ty { let ret = match ty {
Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
Type::Unit => Ok(ast::Type::Unit),
Type::Nullary(_) => todo!(), Type::Nullary(_) => todo!(),
Type::Prim(pr) => Ok(pr.into()), Type::Prim(pr) => Ok(pr.into()),
Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
@ -496,6 +502,7 @@ impl<'ast> Typechecker<'ast> {
} }
Type::Nullary(_) => todo!(), Type::Nullary(_) => todo!(),
Type::Prim(pr) => break Some((*pr).into()), Type::Prim(pr) => break Some((*pr).into()),
Type::Unit => break Some(ast::Type::Unit),
Type::Fun { args, ret } => todo!(), Type::Fun { args, ret } => todo!(),
} }
} }
@ -503,6 +510,7 @@ impl<'ast> Typechecker<'ast> {
fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
match ast_type { match ast_type {
ast::Type::Unit => Type::Unit,
ast::Type::Int => INT, ast::Type::Int => INT,
ast::Type::Float => FLOAT, ast::Type::Float => FLOAT,
ast::Type::Bool => BOOL, ast::Type::Bool => BOOL,
@ -570,6 +578,8 @@ impl<'ast> Typechecker<'ast> {
} }
(Type::Univ(_), _) => false, (Type::Univ(_), _) => false,
(Type::Exist(_), _) => false, (Type::Exist(_), _) => false,
(Type::Unit, ast::Type::Unit) => true,
(Type::Unit, _) => false,
(Type::Nullary(_), _) => todo!(), (Type::Nullary(_), _) => todo!(),
(Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
(Type::Fun { args, ret }, ast::Type::Function(ft)) => { (Type::Fun { args, ret }, ast::Type::Function(ft)) => {

View file

@ -24,6 +24,11 @@ const FIXTURES: &[Fixture] = &[
exit_code: 0, exit_code: 0,
expected_output: "foobar\n", expected_output: "foobar\n",
}, },
Fixture {
name: "units",
exit_code: 0,
expected_output: "hi\n",
},
]; ];
#[test] #[test]