Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Param type to Function #559

Merged
merged 4 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions crates/mun_compiler/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,82 +29,82 @@ mod tests {

#[test]
fn test_syntax_error() {
insta::assert_display_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n"));
insta::assert_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n"));
}

#[test]
fn test_unresolved_value_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet b = a;\n\nlet d = c;\n}"
));
}

#[test]
fn test_unresolved_type_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a = Foo{};\n\nlet b = Bar{};\n}"
));
}

#[test]
fn test_leaked_private_type_error_function() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nstruct Foo;\n pub fn Bar() -> Foo { Foo } \n fn main() {}"
));
}

#[test]
fn test_expected_function_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a = Foo();\n\nlet b = Bar();\n}"
));
}

#[test]
fn test_mismatched_type_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a: f64 = false;\n\nlet b: bool = 22;\n}"
));
}

#[test]
fn test_duplicate_definition_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn foo(){}\n\nfn foo(){}\n\nstruct Bar;\n\nstruct Bar;\n\nfn BAZ(){}\n\nstruct BAZ;"
));
}

#[test]
fn test_possibly_uninitialized_variable_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a;\nif 5>6 {\na = 5\n}\nlet b = a;\n}"
));
}

#[test]
fn test_access_unknown_field_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nstruct Foo {\ni: bool\n}\n\nfn main() {\nlet a = Foo { i: false };\nlet b = a.t;\n}"
));
}

#[test]
fn test_free_type_alias_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo;"));
}

#[test]
fn test_type_alias_target_undeclared_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;"));
}

#[test]
fn test_cyclic_type_alias_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = Foo;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo = Foo;"));
}

#[test]
fn test_expected_function() {
insta::assert_display_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }"));
insta::assert_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }"));
}
}
56 changes: 53 additions & 3 deletions crates/mun_hir/src/code_model/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{iter::once, sync::Arc};

use mun_syntax::ast::TypeAscriptionOwner;
use mun_syntax::{ast, ast::TypeAscriptionOwner};

use super::Module;
use crate::{
Expand All @@ -11,8 +11,8 @@ use crate::{
resolve::HasResolver,
type_ref::{LocalTypeRefId, TypeRefMap, TypeRefSourceMap},
visibility::RawVisibility,
Body, DefDatabase, DiagnosticSink, FileId, HasVisibility, HirDatabase, InferenceResult, Name,
Ty, Visibility,
Body, DefDatabase, DiagnosticSink, FileId, HasSource, HasVisibility, HirDatabase, InFile,
InferenceResult, Name, Ty, Visibility,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
Expand Down Expand Up @@ -138,6 +138,20 @@ impl Function {
db.type_for_def(self.into(), Namespace::Values)
}

/// Returns the parameters of the function.
pub fn params(self, db: &dyn HirDatabase) -> Vec<Param> {
db.callable_sig(self.into())
.params()
.iter()
.enumerate()
.map(|(idx, ty)| Param {
func: self,
ty: ty.clone(),
idx,
})
.collect()
}

pub fn ret_type(self, db: &dyn HirDatabase) -> Ty {
let resolver = self.id.resolver(db.upcast());
let data = self.data(db.upcast());
Expand Down Expand Up @@ -166,6 +180,42 @@ impl Function {
}
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Param {
func: Function,
/// The index in parameter list, including self parameter.
idx: usize,
ty: Ty,
}

impl Param {
/// Returns the function to which this parameter belongs
pub fn parent_fn(&self) -> Function {
self.func
}

/// Returns the index of this parameter in the parameter list (including
/// self)
pub fn index(&self) -> usize {
self.idx
}

/// Returns the type of this parameter.
pub fn ty(&self) -> &Ty {
&self.ty
}

/// Returns the source of the parameter.
pub fn source(&self, db: &dyn HirDatabase) -> Option<InFile<ast::Param>> {
let InFile { file_id, value } = self.func.source(db.upcast());
let params = value.param_list()?;
params
.params()
.nth(self.idx)
.map(|value| InFile { file_id, value })
Wodann marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl HasVisibility for Function {
fn visibility(&self, db: &dyn HirDatabase) -> Visibility {
self.data(db.upcast())
Expand Down
27 changes: 22 additions & 5 deletions crates/mun_hir/src/item_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ use std::{
sync::Arc,
};

use mun_syntax::{ast, AstNode};
use mun_syntax::ast;

use crate::{
arena::{Arena, Idx},
path::ImportAlias,
source_id::FileAstId,
source_id::{AstIdNode, FileAstId},
type_ref::{LocalTypeRefId, TypeRefMap},
visibility::RawVisibility,
DefDatabase, FileId, InFile, Name, Path,
Expand Down Expand Up @@ -112,6 +112,7 @@ impl ItemVisibilities {
struct ItemTreeData {
imports: Arena<Import>,
functions: Arena<Function>,
params: Arena<Param>,
structs: Arena<Struct>,
fields: Arena<Field>,
type_aliases: Arena<TypeAlias>,
Expand All @@ -122,7 +123,7 @@ struct ItemTreeData {

/// Trait implemented by all item nodes in the item tree.
pub trait ItemTreeNode: Clone {
type Source: AstNode + Into<ast::ModuleItem>;
type Source: AstIdNode + Into<ast::ModuleItem>;

/// Returns the AST id for this instance
fn ast_id(&self) -> FileAstId<Self::Source>;
Expand Down Expand Up @@ -244,7 +245,7 @@ macro_rules! impl_index {
};
}

impl_index!(fields: Field);
impl_index!(fields: Field, params: Param);

static VIS_PUB: RawVisibility = RawVisibility::Public;
static VIS_PRIV: RawVisibility = RawVisibility::This;
Expand Down Expand Up @@ -302,11 +303,22 @@ pub struct Function {
pub visibility: RawVisibilityId,
pub is_extern: bool,
pub types: TypeRefMap,
pub params: Box<[LocalTypeRefId]>,
pub params: IdRange<Param>,
pub ret_type: LocalTypeRefId,
pub ast_id: FileAstId<ast::FunctionDef>,
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Param {
pub type_ref: LocalTypeRefId,
pub ast_id: ParamAstId,
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ParamAstId {
Wodann marked this conversation as resolved.
Show resolved Hide resolved
Param(FileAstId<ast::Param>),
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Struct {
pub name: Name,
Expand Down Expand Up @@ -390,6 +402,11 @@ impl<T> IdRange<T> {
_p: PhantomData,
}
}

/// Returns true if the index range is empty
pub fn is_empty(&self) -> bool {
self.range.is_empty()
}
}

impl<T> Iterator for IdRange<T> {
Expand Down
23 changes: 18 additions & 5 deletions crates/mun_hir/src/item_tree/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use smallvec::SmallVec;

use super::{
diagnostics, AssociatedItem, Field, Fields, Function, IdRange, Impl, ItemTree, ItemTreeData,
ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, RawVisibilityId, Struct, TypeAlias,
ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, Param, ParamAstId, RawVisibilityId,
Struct, TypeAlias,
};
use crate::{
arena::{Idx, RawId},
Expand Down Expand Up @@ -156,13 +157,19 @@ impl Context {
let mut types = TypeRefMap::builder();

// Lower all the params
let mut params = Vec::new();
let start_param_idx = self.next_param_idx();
if let Some(param_list) = func.param_list() {
for param in param_list.params() {
let ast_id = self.source_ast_id_map.ast_id(&param);
let type_ref = types.alloc_from_node_opt(param.ascribed_type().as_ref());
params.push(type_ref);
self.data.params.alloc(Param {
type_ref,
ast_id: ParamAstId::Param(ast_id),
});
}
}
let end_param_idx = self.next_param_idx();
let params = IdRange::new(start_param_idx..end_param_idx);

// Lowers the return type
let ret_type = match func.ret_type().and_then(|rt| rt.type_ref()) {
Expand All @@ -177,9 +184,9 @@ impl Context {
let res = Function {
name,
visibility,
types,
is_extern,
params: params.into_boxed_slice(),
types,
params,
ret_type,
ast_id,
};
Expand Down Expand Up @@ -313,6 +320,12 @@ impl Context {
let idx: u32 = self.data.fields.len().try_into().expect("too many fields");
Idx::from_raw(RawId::from(idx))
}

/// Returns the `Idx` of the next `Param`
fn next_param_idx(&self) -> Idx<Param> {
let idx: u32 = self.data.params.len().try_into().expect("too many params");
Idx::from_raw(RawId::from(idx))
}
}

/// Lowers a record field (e.g. `a:i32`)
Expand Down
10 changes: 7 additions & 3 deletions crates/mun_hir/src/item_tree/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fmt, fmt::Write};

use crate::{
item_tree::{
Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, RawVisibilityId,
Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, Param, RawVisibilityId,
Struct, TypeAlias,
},
path::ImportAlias,
Expand Down Expand Up @@ -181,8 +181,12 @@ impl Printer<'_> {
write!(self, "(")?;
if !params.is_empty() {
self.indented(|this| {
for param in params.iter().copied() {
this.print_type_ref(param, types)?;
for param in params.clone() {
let Param {
type_ref,
ast_id: _,
} = &this.tree[param];
this.print_type_ref(*type_ref, types)?;
writeln!(this, ",")?;
}
Ok(())
Expand Down
Loading
Loading