Skip to content

Commit

Permalink
Support normalizing definitions from other modules (#370)
Browse files Browse the repository at this point in the history
* Support normalizing definitions from other modules

* Fix URI scheme change
  • Loading branch information
timsueberkrueb authored Nov 13, 2024
1 parent 44229e7 commit 2f6ce45
Show file tree
Hide file tree
Showing 23 changed files with 299 additions and 317 deletions.
26 changes: 11 additions & 15 deletions examples/boolrep.pol
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
data Bool { T, F }

def Bool.not: Bool {
T => F,
F => T
}
use "../std/data/bool.pol"

data BoolRep(x: Bool) {
TrueRep: BoolRep(T),
FalseRep: BoolRep(F)
FalseRep: BoolRep(F),
}

def BoolRep(x).extract(x: Bool): Bool {
TrueRep => T,
FalseRep => F
FalseRep => F,
}

data Top { Unit }

def Top.flipRep(x: Bool, rep: BoolRep(x)): BoolRep(x.not) {
Unit =>
rep.match {
TrueRep => FalseRep,
FalseRep => TrueRep
}
def Top.flipRep(x: Bool, rep: BoolRep(x)): BoolRep(x.neg) {
Unit => rep.match {
TrueRep => FalseRep,
FalseRep => TrueRep,
}
}

def Top.example: Bool { Unit => Unit.flipRep(T, TrueRep).extract(F) }
def Top.example: Bool {
Unit => Unit.flipRep(T, TrueRep).extract(F)
}
2 changes: 1 addition & 1 deletion lang/ast/src/ctx/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub trait Context: Sized {

/// Interface to bind variables to anything that has a `Context`
///
/// There are two use cases for this trait.
/// There are two ways to use this trait.
///
/// Case 1: You have a type that implements `Context`.
/// Then, a blanket impl ensures that this type also implements `BindContext`.
Expand Down
23 changes: 1 addition & 22 deletions lang/ast/src/decls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,31 +231,10 @@ impl Module {
out
}

pub fn lookup_decl(&self, name: &IdBind) -> Option<&Decl> {
pub fn lookup_decl(&self, name: &IdBound) -> Option<&Decl> {
self.decls.iter().find(|decl| decl.ident() == name)
}

pub fn lookup_let(&self, name: &IdBind) -> Option<&Let> {
self.decls.iter().find_map(|decl| match decl {
Decl::Let(tl_let) if tl_let.name == *name => Some(tl_let),
_ => None,
})
}

pub fn lookup_def(&self, name: &IdBind) -> Option<&Def> {
self.decls.iter().find_map(|decl| match decl {
Decl::Def(def) if def.name == *name => Some(def),
_ => None,
})
}

pub fn lookup_codef(&self, name: &IdBind) -> Option<&Codef> {
self.decls.iter().find_map(|decl| match decl {
Decl::Codef(codef) if codef.name == *name => Some(codef),
_ => None,
})
}

pub fn find_main(&self) -> Option<Box<Exp>> {
self.decls.iter().find_map(|decl| match decl {
Decl::Let(tl_let) if tl_let.is_main() => Some(tl_let.body.clone()),
Expand Down
39 changes: 24 additions & 15 deletions lang/driver/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct Database {
/// The typechecked AST of a module
pub ast: Cache<Result<Arc<ast::Module>, Error>>,
/// The type info table constructed during typechecking
pub type_info_table: Cache<elaborator::ModuleTypeInfoTable>,
pub module_type_info_table: Cache<elaborator::ModuleTypeInfoTable>,
/// Hover information for spans
pub info_by_id: Cache<Lapper<u32, Info>>,
/// Spans of top-level items
Expand Down Expand Up @@ -179,8 +179,23 @@ impl Database {
//
//

pub fn type_info_table(&mut self, uri: &Url) -> Result<ModuleTypeInfoTable, Error> {
match self.type_info_table.get_unless_stale(uri) {
pub fn type_info_table(&mut self, uri: &Url) -> Result<TypeInfoTable, Error> {
let deps = self.deps(uri)?;

// Compute the type info table
let mut info_table = TypeInfoTable::default();
let mod_info_table = self.module_type_info_table(uri)?;
info_table.insert(uri.clone(), mod_info_table);
for dep_url in deps {
let mod_info_table = self.module_type_info_table(&dep_url)?;
info_table.insert(dep_url.clone(), mod_info_table);
}

Ok(info_table)
}

pub fn module_type_info_table(&mut self, uri: &Url) -> Result<ModuleTypeInfoTable, Error> {
match self.module_type_info_table.get_unless_stale(uri) {
Some(table) => {
log::debug!("Found type info table in cache: {}", uri);
Ok(table.clone())
Expand All @@ -193,7 +208,7 @@ impl Database {
log::debug!("Recomputing type info table for: {}", uri);
let ust = self.ust(uri)?;
let info_table = build_type_info_table(&ust);
self.type_info_table.insert(uri.clone(), info_table.clone());
self.module_type_info_table.insert(uri.clone(), info_table.clone());
Ok(info_table)
}

Expand All @@ -213,16 +228,9 @@ impl Database {

pub fn recompute_ast(&mut self, uri: &Url) -> Result<Arc<ast::Module>, Error> {
log::debug!("Recomputing ast for: {}", uri);
let deps = self.deps(uri)?;

// Compute the type info table
let mut info_table = TypeInfoTable::default();
let mod_info_table = self.type_info_table(uri)?;
info_table.insert(uri.clone(), mod_info_table);
for dep_url in deps {
let mod_info_table = self.type_info_table(&dep_url)?;
info_table.insert(dep_url.clone(), mod_info_table);
}
let info_table = self.type_info_table(uri)?;

// Typecheck module
let ust = self.ust(uri).map(|x| (*x).clone())?;
Expand Down Expand Up @@ -315,7 +323,7 @@ impl Database {
symbol_table: Cache::default(),
ust: Cache::default(),
ast: Cache::default(),
type_info_table: Cache::default(),
module_type_info_table: Cache::default(),
info_by_id: Cache::default(),
item_by_id: Cache::default(),
}
Expand Down Expand Up @@ -358,7 +366,7 @@ impl Database {
self.symbol_table.invalidate(uri);
self.ust.invalidate(uri);
self.ast.invalidate(uri);
self.type_info_table.invalidate(uri);
self.module_type_info_table.invalidate(uri);
self.info_by_id.invalidate(uri);
self.item_by_id.invalidate(uri);
}
Expand All @@ -367,10 +375,11 @@ impl Database {
let ast = self.ast(uri)?;

let main = ast.find_main();
let info_table = self.type_info_table(uri)?;

match main {
Some(exp) => {
let nf = exp.normalize_in_empty_env(&ast)?;
let nf = exp.normalize_in_empty_env(&Rc::new(info_table))?;
Ok(Some(nf))
}
None => Ok(None),
Expand Down
22 changes: 13 additions & 9 deletions lang/driver/src/xfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use transformations::LiftResult;
use transformations::Rename;

use ast::*;
use lowering::DeclMeta;
use parser::cst;
use transformations::matrix;
use transformations::result::XfuncError;
Expand All @@ -20,14 +19,16 @@ pub struct Xfunc {
}

impl Database {
pub fn all_type_names(&mut self, uri: &Url) -> Result<Vec<cst::Ident>, crate::Error> {
let symbol_table = self.symbol_table(uri)?;
Ok(symbol_table
pub fn all_declared_type_names(&mut self, uri: &Url) -> Result<Vec<cst::Ident>, crate::Error> {
let ust = self.cst(uri)?;
Ok(ust
.decls
.iter()
.filter(|(_, decl_meta)| {
matches!(decl_meta, DeclMeta::Data { .. } | DeclMeta::Codata { .. })
.filter_map(|decl| match decl {
cst::decls::Decl::Data(data) => Some(data.name.clone()),
cst::decls::Decl::Codata(codata) => Some(codata.name.clone()),
_ => None,
})
.map(|(name, _)| name.clone())
.collect())
}

Expand Down Expand Up @@ -95,7 +96,8 @@ fn generate_edits(
// Here we rewrite the entire (co)data declaration and its associated (co)definitions
let new_items = Module {
uri: module.uri.clone(),
use_decls: module.use_decls.clone(),
// Use declarations don't change, and we are only printing an excerpt of the module
use_decls: vec![],
decls: new_decls,
meta_vars: module.meta_vars.clone(),
};
Expand All @@ -106,7 +108,9 @@ fn generate_edits(
// Edits for all other declarations that have been touched
// Here we surgically rewrite only the declarations that have been changed
for name in dirty_decls {
let decl = module.lookup_decl(&name).unwrap();
let decl = module
.lookup_decl(&IdBound { span: None, id: name.id.clone(), uri: module.uri.clone() })
.unwrap();
let mut decl = decl.clone();
decl.rename();
let span = original.decl_spans[&name];
Expand Down
43 changes: 20 additions & 23 deletions lang/elaborator/src/normalizer/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,16 @@ use crate::normalizer::val::*;
#[derive(Debug, Clone, Derivative)]
#[derivative(Eq, PartialEq, Hash)]
pub struct Env {
ctx: GenericCtx<Box<Val>>,
}

impl From<GenericCtx<Box<Val>>> for Env {
fn from(value: GenericCtx<Box<Val>>) -> Self {
Env { ctx: value }
}
/// Environment for locally bound variables
bound_vars: GenericCtx<Box<Val>>,
}

impl Context for Env {
type Elem = Box<Val>;

fn lookup<V: Into<Var>>(&self, idx: V) -> Self::Elem {
let lvl = self.ctx.var_to_lvl(idx.into());
self.ctx
let lvl = self.bound_vars.var_to_lvl(idx.into());
self.bound_vars
.bound
.get(lvl.fst)
.and_then(|ctx| ctx.get(lvl.snd))
Expand All @@ -38,15 +33,15 @@ impl Context for Env {
}

fn push_telescope(&mut self) {
self.ctx.bound.push(vec![]);
self.bound_vars.bound.push(vec![]);
}

fn pop_telescope(&mut self) {
self.ctx.bound.pop().unwrap();
self.bound_vars.bound.pop().unwrap();
}

fn push_binder(&mut self, elem: Self::Elem) {
self.ctx
self.bound_vars
.bound
.last_mut()
.expect("Cannot push without calling push_telescope first")
Expand All @@ -55,7 +50,7 @@ impl Context for Env {

fn pop_binder(&mut self, _elem: Self::Elem) {
let err = "Cannot pop from empty context";
self.ctx.bound.last_mut().expect(err).pop().expect(err);
self.bound_vars.bound.last_mut().expect(err).pop().expect(err);
}
}

Expand All @@ -66,24 +61,26 @@ impl ContextElem<Env> for &Box<Val> {
}

impl Env {
pub fn empty() -> Self {
Self { bound_vars: GenericCtx::empty() }
}

pub fn from_vec(bound: Vec<Vec<Box<Val>>>) -> Self {
Self { bound_vars: GenericCtx::from(bound) }
}

pub(super) fn for_each<F>(&mut self, f: F)
where
F: Fn(&mut Box<Val>),
{
for outer in self.ctx.bound.iter_mut() {
for outer in self.bound_vars.bound.iter_mut() {
for inner in outer {
f(inner)
}
}
}
}

impl From<Vec<Vec<Box<Val>>>> for Env {
fn from(bound: Vec<Vec<Box<Val>>>) -> Self {
Self { ctx: bound.into() }
}
}

impl Shift for Env {
fn shift_in_range<R: ShiftRange>(&mut self, range: &R, by: (isize, isize)) {
self.for_each(|val| val.shift_in_range(range, by))
Expand Down Expand Up @@ -115,7 +112,7 @@ impl ToEnv for LevelCtx {
})
.collect();

Env::from(bound)
Env::from_vec(bound)
}
}

Expand All @@ -133,7 +130,7 @@ impl ToEnv for TypeCtx {
})
.collect();

Env::from(bound)
Env::from_vec(bound)
}
}

Expand All @@ -143,7 +140,7 @@ impl Print for Env {
cfg: &printer::PrintCfg,
alloc: &'a printer::Alloc<'a>,
) -> printer::Builder<'a> {
let iter = self.ctx.iter().map(|ctx| {
let iter = self.bound_vars.iter().map(|ctx| {
alloc
.intersperse(ctx.iter().map(|typ| typ.print(cfg, alloc)), alloc.text(COMMA))
.brackets()
Expand Down
Loading

0 comments on commit 2f6ce45

Please sign in to comment.