From 77bf4a7633651d09282d2dc26c3c0cf9a5002f93 Mon Sep 17 00:00:00 2001
From: Rigidity <me@rigidnetwork.com>
Date: Wed, 17 Apr 2024 15:08:45 -0400
Subject: [PATCH] Split warnings and errors

---
 crates/rue-cli/src/main.rs         | 12 +++-
 crates/rue-compiler/src/error.rs   | 44 ++++++++------
 crates/rue-compiler/src/lib.rs     |  9 +--
 crates/rue-compiler/src/lowerer.rs | 93 ++++++++++++++----------------
 crates/rue-lsp/src/main.rs         | 15 ++---
 crates/rue-tests/src/main.rs       |  7 ++-
 6 files changed, 92 insertions(+), 88 deletions(-)

diff --git a/crates/rue-cli/src/main.rs b/crates/rue-cli/src/main.rs
index 458a4a8..d25d4b6 100644
--- a/crates/rue-cli/src/main.rs
+++ b/crates/rue-cli/src/main.rs
@@ -2,7 +2,7 @@ use std::fs;
 
 use clap::Parser;
 use clvmr::{run_program, serde::node_to_bytes, Allocator, ChiaDialect, NodePtr};
-use rue_compiler::compile;
+use rue_compiler::{compile, DiagnosticKind};
 use rue_parser::{line_col, parse, LineCol};
 
 /// The Rue language compiler and toolchain.
@@ -35,7 +35,15 @@ fn main() {
             let LineCol { line, col } = line_col(&source, error.span().start);
             let line = line + 1;
             let col = col + 1;
-            eprintln!("{} at {line}:{col}", error.info());
+
+            match error.kind() {
+                DiagnosticKind::Error(kind) => {
+                    eprintln!("Error: {} at {line}:{col}", kind)
+                }
+                DiagnosticKind::Warning(kind) => {
+                    eprintln!("Warning: {} at {line}:{col}", kind)
+                }
+            }
         }
         return;
     }
diff --git a/crates/rue-compiler/src/error.rs b/crates/rue-compiler/src/error.rs
index f68f81b..102646d 100644
--- a/crates/rue-compiler/src/error.rs
+++ b/crates/rue-compiler/src/error.rs
@@ -5,36 +5,48 @@ use thiserror::Error;
 #[derive(Debug)]
 pub struct Diagnostic {
     kind: DiagnosticKind,
-    info: DiagnosticInfo,
     span: Range<usize>,
 }
 
 impl Diagnostic {
-    pub fn new(kind: DiagnosticKind, info: DiagnosticInfo, span: Range<usize>) -> Self {
-        Self { kind, info, span }
+    pub fn new(kind: DiagnosticKind, span: Range<usize>) -> Self {
+        Self { kind, span }
     }
 
-    pub fn kind(&self) -> DiagnosticKind {
-        self.kind
-    }
-
-    pub fn info(&self) -> &DiagnosticInfo {
-        &self.info
+    pub fn kind(&self) -> &DiagnosticKind {
+        &self.kind
     }
 
     pub fn span(&self) -> &Range<usize> {
         &self.span
     }
+
+    pub fn is_error(&self) -> bool {
+        matches!(self.kind, DiagnosticKind::Error(_))
+    }
+
+    pub fn is_warning(&self) -> bool {
+        matches!(self.kind, DiagnosticKind::Warning(_))
+    }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum DiagnosticKind {
-    Warning,
-    Error,
+    Warning(WarningKind),
+    Error(ErrorKind),
+}
+
+#[derive(Debug, Error, Clone, PartialEq, Eq, Hash)]
+pub enum WarningKind {
+    #[error("redundant optional type")]
+    RedundantOptional,
+
+    #[error("redundant check against same type `{0}`")]
+    RedundantTypeGuard(String),
 }
 
 #[derive(Debug, Error, Clone, PartialEq, Eq, Hash)]
-pub enum DiagnosticInfo {
+pub enum ErrorKind {
     #[error("missing `main` function")]
     MissingMain,
 
@@ -116,9 +128,6 @@ pub enum DiagnosticInfo {
     #[error("cannot check list against pair with types other than the list item type and the list itself")]
     NonListPairTypeGuard,
 
-    #[error("redundant check against same type `{0}`")]
-    RedundantTypeGuard(String),
-
     #[error("implicit return is not allowed in if statements, use an explicit return statement")]
     ImplicitReturnInIf,
 
@@ -130,9 +139,6 @@ pub enum DiagnosticInfo {
 
     #[error("cannot check equality on non-atom type `{0}`")]
     NonAtomEquality(String),
-
-    #[error("redundant optional type")]
-    RedundantOptional,
 }
 
 /// Join a list of names into a string, wrapped in backticks.
diff --git a/crates/rue-compiler/src/lib.rs b/crates/rue-compiler/src/lib.rs
index be8225c..cd1d350 100644
--- a/crates/rue-compiler/src/lib.rs
+++ b/crates/rue-compiler/src/lib.rs
@@ -53,8 +53,7 @@ pub fn compile(allocator: &mut Allocator, root: Root, parsing_succeeded: bool) -
 
     let Some(main_id) = db.scope_mut(scope_id).symbol("main") else {
         diagnostics.push(Diagnostic::new(
-            DiagnosticKind::Error,
-            DiagnosticInfo::MissingMain,
+            DiagnosticKind::Error(ErrorKind::MissingMain),
             0..0,
         ));
 
@@ -64,11 +63,7 @@ pub fn compile(allocator: &mut Allocator, root: Root, parsing_succeeded: bool) -
         };
     };
 
-    let node_ptr = if !diagnostics
-        .iter()
-        .any(|diagnostic| diagnostic.kind() == DiagnosticKind::Error)
-        && parsing_succeeded
-    {
+    let node_ptr = if !diagnostics.iter().any(Diagnostic::is_error) && parsing_succeeded {
         let mut optimizer = Optimizer::new(&mut db);
         let lir_id = optimizer.opt_main(main_id);
 
diff --git a/crates/rue-compiler/src/lowerer.rs b/crates/rue-compiler/src/lowerer.rs
index 23c382b..20c30e4 100644
--- a/crates/rue-compiler/src/lowerer.rs
+++ b/crates/rue-compiler/src/lowerer.rs
@@ -21,7 +21,7 @@ use crate::{
     scope::Scope,
     symbol::Symbol,
     ty::{EnumType, EnumVariant, FunctionType, Guard, StructType, Type, Value},
-    Diagnostic, DiagnosticInfo, DiagnosticKind,
+    Diagnostic, DiagnosticKind, ErrorKind, WarningKind,
 };
 
 /// Responsible for lowering the AST into the HIR.
@@ -226,7 +226,7 @@ impl<'a> Lowerer<'a> {
                 if i + 1 == len {
                     varargs = true;
                 } else {
-                    self.error(DiagnosticInfo::NonFinalSpread, param.syntax().text_range());
+                    self.error(ErrorKind::NonFinalSpread, param.syntax().text_range());
                 }
             }
         }
@@ -424,7 +424,7 @@ impl<'a> Lowerer<'a> {
             // If the variant is a duplicate, we don't want to overwrite the existing variant.
             if !visited_variants.insert(name.to_string()) {
                 self.error(
-                    DiagnosticInfo::DuplicateEnumVariant(name.to_string()),
+                    ErrorKind::DuplicateEnumVariant(name.to_string()),
                     name.text_range(),
                 );
                 continue;
@@ -523,7 +523,7 @@ impl<'a> Lowerer<'a> {
             // This could technically work but makes the intent of the code unclear.
             if !explicit_return {
                 self.error(
-                    DiagnosticInfo::ImplicitReturnInIf,
+                    ErrorKind::ImplicitReturnInIf,
                     then_block.syntax().text_range(),
                 );
             }
@@ -650,7 +650,7 @@ impl<'a> Lowerer<'a> {
 
         // Ensure that the block terminates.
         if !is_terminated {
-            self.error(DiagnosticInfo::EmptyBlock, block.syntax().text_range());
+            self.error(ErrorKind::EmptyBlock, block.syntax().text_range());
         }
 
         // Pop each statement in reverse order and mutate the body.
@@ -696,10 +696,7 @@ impl<'a> Lowerer<'a> {
         self.scope_stack.pop().unwrap();
 
         if explicit_return {
-            self.error(
-                DiagnosticInfo::ExplicitReturnInExpr,
-                block.syntax().text_range(),
-            );
+            self.error(ErrorKind::ExplicitReturnInExpr, block.syntax().text_range());
         }
 
         value
@@ -762,7 +759,7 @@ impl<'a> Lowerer<'a> {
             }
             Some(_) => {
                 self.error(
-                    DiagnosticInfo::UninitializableType(self.type_name(ty.unwrap())),
+                    ErrorKind::UninitializableType(self.type_name(ty.unwrap())),
                     initializer.path().unwrap().syntax().text_range(),
                 );
                 self.unknown()
@@ -804,12 +801,12 @@ impl<'a> Lowerer<'a> {
             // Insert the field if it exists and hasn't already been assigned.
             if specified_fields.contains_key(name.text()) {
                 self.error(
-                    DiagnosticInfo::DuplicateField(name.to_string()),
+                    ErrorKind::DuplicateField(name.to_string()),
                     name.text_range(),
                 );
             } else if !struct_fields.contains_key(name.text()) {
                 self.error(
-                    DiagnosticInfo::UndefinedField(name.to_string()),
+                    ErrorKind::UndefinedField(name.to_string()),
                     name.text_range(),
                 );
             } else {
@@ -825,7 +822,7 @@ impl<'a> Lowerer<'a> {
             .collect();
 
         if !missing_fields.is_empty() {
-            self.error(DiagnosticInfo::MissingFields(missing_fields), text_range);
+            self.error(ErrorKind::MissingFields(missing_fields), text_range);
         }
 
         let mut hir_id = self.nil_hir;
@@ -863,7 +860,7 @@ impl<'a> Lowerer<'a> {
                     Value::typed(self.compile_index(value.hir(), index, false), *field_type)
                 } else {
                     self.error(
-                        DiagnosticInfo::UndefinedField(field_name.to_string()),
+                        ErrorKind::UndefinedField(field_name.to_string()),
                         field_name.text_range(),
                     );
                     self.unknown()
@@ -874,7 +871,7 @@ impl<'a> Lowerer<'a> {
                 "rest" => Value::typed(self.db.alloc_hir(Hir::Rest(value.hir())), right),
                 _ => {
                     self.error(
-                        DiagnosticInfo::PairFieldAccess(field_name.to_string()),
+                        ErrorKind::PairFieldAccess(field_name.to_string()),
                         field_name.text_range(),
                     );
                     self.unknown()
@@ -886,7 +883,7 @@ impl<'a> Lowerer<'a> {
                 }
                 _ => {
                     self.error(
-                        DiagnosticInfo::BytesFieldAccess(field_name.to_string()),
+                        ErrorKind::BytesFieldAccess(field_name.to_string()),
                         field_name.text_range(),
                     );
                     self.unknown()
@@ -894,7 +891,7 @@ impl<'a> Lowerer<'a> {
             },
             _ => {
                 self.error(
-                    DiagnosticInfo::NonStructFieldAccess(self.type_name(value.ty())),
+                    ErrorKind::NonStructFieldAccess(self.type_name(value.ty())),
                     field_name.text_range(),
                 );
                 self.unknown()
@@ -917,7 +914,7 @@ impl<'a> Lowerer<'a> {
 
         let Type::List(item_type) = self.db.ty(value.ty()).clone() else {
             self.error(
-                DiagnosticInfo::IndexAccess(self.type_name(value.ty())),
+                ErrorKind::IndexAccess(self.type_name(value.ty())),
                 index_access.expr().unwrap().syntax().text_range(),
             );
             return self.unknown();
@@ -1018,7 +1015,7 @@ impl<'a> Lowerer<'a> {
                     || !self.is_atomic(rhs.ty(), &mut IndexSet::new())
                 {
                     self.error(
-                        DiagnosticInfo::NonAtomEquality(self.type_name(lhs.ty())),
+                        ErrorKind::NonAtomEquality(self.type_name(lhs.ty())),
                         text_range,
                     );
                 } else if self.is_assignable_to(lhs.ty(), self.nil_type, false, &mut HashSet::new())
@@ -1048,7 +1045,7 @@ impl<'a> Lowerer<'a> {
                     || !self.is_atomic(rhs.ty(), &mut IndexSet::new())
                 {
                     self.error(
-                        DiagnosticInfo::NonAtomEquality(self.type_name(lhs.ty())),
+                        ErrorKind::NonAtomEquality(self.type_name(lhs.ty())),
                         text_range,
                     );
                 } else if self.is_assignable_to(lhs.ty(), self.nil_type, false, &mut HashSet::new())
@@ -1181,7 +1178,7 @@ impl<'a> Lowerer<'a> {
     ) -> Option<(Guard, HirId)> {
         if self.types_equal(from, to) {
             self.warning(
-                DiagnosticInfo::RedundantTypeGuard(self.type_name(from)),
+                WarningKind::RedundantTypeGuard(self.type_name(from)),
                 text_range,
             );
             return Some((Guard::new(to, self.bool_type), hir_id));
@@ -1190,11 +1187,11 @@ impl<'a> Lowerer<'a> {
         match (self.db.ty(from).clone(), self.db.ty(to).clone()) {
             (Type::Any, Type::Pair(first, rest)) => {
                 if !self.types_equal(first, self.any_type) {
-                    self.error(DiagnosticInfo::NonAnyPairTypeGuard, text_range);
+                    self.error(ErrorKind::NonAnyPairTypeGuard, text_range);
                 }
 
                 if !self.types_equal(rest, self.any_type) {
-                    self.error(DiagnosticInfo::NonAnyPairTypeGuard, text_range);
+                    self.error(ErrorKind::NonAnyPairTypeGuard, text_range);
                 }
 
                 let hir_id = self.db.alloc_hir(Hir::IsCons(hir_id));
@@ -1208,11 +1205,11 @@ impl<'a> Lowerer<'a> {
             }
             (Type::List(inner), Type::Pair(first, rest)) => {
                 if !self.types_equal(first, inner) {
-                    self.error(DiagnosticInfo::NonListPairTypeGuard, text_range);
+                    self.error(ErrorKind::NonListPairTypeGuard, text_range);
                 }
 
                 if !self.types_equal(rest, from) {
-                    self.error(DiagnosticInfo::NonListPairTypeGuard, text_range);
+                    self.error(ErrorKind::NonListPairTypeGuard, text_range);
                 }
 
                 let hir_id = self.db.alloc_hir(Hir::IsCons(hir_id));
@@ -1246,7 +1243,7 @@ impl<'a> Lowerer<'a> {
             }
             _ => {
                 self.error(
-                    DiagnosticInfo::UnsupportedTypeGuard {
+                    ErrorKind::UnsupportedTypeGuard {
                         from: self.type_name(from),
                         to: self.type_name(to),
                     },
@@ -1325,7 +1322,7 @@ impl<'a> Lowerer<'a> {
                 if i + 1 == len {
                     nil_terminated = false;
                 } else {
-                    self.error(DiagnosticInfo::NonFinalSpread, spread.text_range());
+                    self.error(ErrorKind::NonFinalSpread, spread.text_range());
                 }
             }
 
@@ -1425,7 +1422,7 @@ impl<'a> Lowerer<'a> {
                 if i + 1 == len {
                     varargs = true;
                 } else {
-                    self.error(DiagnosticInfo::NonFinalSpread, param.syntax().text_range());
+                    self.error(ErrorKind::NonFinalSpread, param.syntax().text_range());
                 }
             }
         }
@@ -1578,7 +1575,7 @@ impl<'a> Lowerer<'a> {
         let mut idents = path.idents();
 
         if idents.len() > 1 {
-            self.error(DiagnosticInfo::PathNotAllowed, path.syntax().text_range());
+            self.error(ErrorKind::PathNotAllowed, path.syntax().text_range());
             return self.unknown();
         }
 
@@ -1591,7 +1588,7 @@ impl<'a> Lowerer<'a> {
             .find_map(|&scope_id| self.db.scope(scope_id).symbol(name.text()))
         else {
             self.error(
-                DiagnosticInfo::UndefinedReference(name.to_string()),
+                ErrorKind::UndefinedReference(name.to_string()),
                 name.text_range(),
             );
             return self.unknown();
@@ -1650,7 +1647,7 @@ impl<'a> Lowerer<'a> {
             Type::Function(function) => Some(function.clone()),
             _ => {
                 self.error(
-                    DiagnosticInfo::UncallableType(self.type_name(callee.ty())),
+                    ErrorKind::UncallableType(self.type_name(callee.ty())),
                     call.callee().unwrap().syntax().text_range(),
                 );
                 None
@@ -1685,7 +1682,7 @@ impl<'a> Lowerer<'a> {
                     spread = true;
                     continue;
                 } else {
-                    self.error(DiagnosticInfo::NonFinalSpread, arg.syntax().text_range());
+                    self.error(ErrorKind::NonFinalSpread, arg.syntax().text_range());
                 }
             }
 
@@ -1702,7 +1699,7 @@ impl<'a> Lowerer<'a> {
 
             if too_few_args || too_many_args {
                 self.error(
-                    DiagnosticInfo::ArgumentMismatch {
+                    ErrorKind::ArgumentMismatch {
                         expected: param_len,
                         found: arg_types.len(),
                     },
@@ -1713,7 +1710,7 @@ impl<'a> Lowerer<'a> {
             for (i, arg) in arg_types.into_iter().enumerate() {
                 if i + 1 == arg_len && spread && !expected.varargs() {
                     self.error(
-                        DiagnosticInfo::NonVarargSpread,
+                        ErrorKind::NonVarargSpread,
                         call.args()[i].syntax().text_range(),
                     );
                     continue;
@@ -1729,7 +1726,7 @@ impl<'a> Lowerer<'a> {
                         }
                         _ => {
                             self.error(
-                                DiagnosticInfo::NonListVararg,
+                                ErrorKind::NonListVararg,
                                 call.args()[i].syntax().text_range(),
                             );
                         }
@@ -1795,7 +1792,7 @@ impl<'a> Lowerer<'a> {
 
         let Some(mut ty) = ty else {
             self.error(
-                DiagnosticInfo::UndefinedType(name.to_string()),
+                ErrorKind::UndefinedType(name.to_string()),
                 name.text_range(),
             );
             return self.unknown_type;
@@ -1814,11 +1811,11 @@ impl<'a> Lowerer<'a> {
                 if let Some(&variant_type) = enum_type.variants().get(name) {
                     return variant_type;
                 }
-                self.error(DiagnosticInfo::UnknownEnumVariant(name.to_string()), range);
+                self.error(ErrorKind::UnknownEnumVariant(name.to_string()), range);
                 self.unknown_type
             }
             _ => {
-                self.error(DiagnosticInfo::PathIntoNonEnum(self.type_name(ty)), range);
+                self.error(ErrorKind::PathIntoNonEnum(self.type_name(ty)), range);
                 self.unknown_type
             }
         }
@@ -1865,7 +1862,7 @@ impl<'a> Lowerer<'a> {
                 if i + 1 == len {
                     vararg = true;
                 } else {
-                    self.error(DiagnosticInfo::NonFinalSpread, param.syntax().text_range());
+                    self.error(ErrorKind::NonFinalSpread, param.syntax().text_range());
                 }
             }
         }
@@ -1890,7 +1887,7 @@ impl<'a> Lowerer<'a> {
 
         if let Type::Optional(inner) = self.db.ty_raw(ty).clone() {
             self.warning(
-                DiagnosticInfo::RedundantOptional,
+                WarningKind::RedundantOptional,
                 optional.syntax().text_range(),
             );
             return inner;
@@ -1924,7 +1921,7 @@ impl<'a> Lowerer<'a> {
             Type::Function { .. } => false,
             Type::Alias(alias) => {
                 if !visited_aliases.insert(alias) {
-                    self.error(DiagnosticInfo::RecursiveTypeAlias, text_range);
+                    self.error(ErrorKind::RecursiveTypeAlias, text_range);
                     return true;
                 }
                 self.detect_cycle(alias, text_range, visited_aliases)
@@ -2028,7 +2025,7 @@ impl<'a> Lowerer<'a> {
     fn type_check(&mut self, from: TypeId, to: TypeId, range: TextRange) {
         if !self.is_assignable_to(from, to, false, &mut HashSet::new()) {
             self.error(
-                DiagnosticInfo::TypeMismatch {
+                ErrorKind::TypeMismatch {
                     expected: self.type_name(to),
                     found: self.type_name(from),
                 },
@@ -2040,7 +2037,7 @@ impl<'a> Lowerer<'a> {
     fn cast_check(&mut self, from: TypeId, to: TypeId, range: TextRange) {
         if !self.is_assignable_to(from, to, true, &mut HashSet::new()) {
             self.error(
-                DiagnosticInfo::CastMismatch {
+                ErrorKind::CastMismatch {
                     expected: self.type_name(to),
                     found: self.type_name(from),
                 },
@@ -2248,18 +2245,16 @@ impl<'a> Lowerer<'a> {
             .scope_mut(self.scope_stack.last().copied().expect("no scope found"))
     }
 
-    fn error(&mut self, info: DiagnosticInfo, range: TextRange) {
+    fn error(&mut self, info: ErrorKind, range: TextRange) {
         self.diagnostics.push(Diagnostic::new(
-            DiagnosticKind::Error,
-            info,
+            DiagnosticKind::Error(info),
             range.start().into()..range.end().into(),
         ));
     }
 
-    fn warning(&mut self, info: DiagnosticInfo, range: TextRange) {
+    fn warning(&mut self, info: WarningKind, range: TextRange) {
         self.diagnostics.push(Diagnostic::new(
-            DiagnosticKind::Warning,
-            info,
+            DiagnosticKind::Warning(info),
             range.start().into()..range.end().into(),
         ));
     }
diff --git a/crates/rue-lsp/src/main.rs b/crates/rue-lsp/src/main.rs
index 0363025..4234c52 100644
--- a/crates/rue-lsp/src/main.rs
+++ b/crates/rue-lsp/src/main.rs
@@ -76,15 +76,12 @@ impl Backend {
             let start = line_col(&text, error.span().start);
             let end = line_col(&text, error.span().end);
 
-            diagnostics.push(diagnostic(
-                start,
-                end,
-                format!("{}", error.info()),
-                match error.kind() {
-                    DiagnosticKind::Error => DiagnosticSeverity::ERROR,
-                    DiagnosticKind::Warning => DiagnosticSeverity::WARNING,
-                },
-            ));
+            let (message, severity) = match error.kind() {
+                DiagnosticKind::Error(kind) => (format!("{}", kind), DiagnosticSeverity::ERROR),
+                DiagnosticKind::Warning(kind) => (format!("{}", kind), DiagnosticSeverity::WARNING),
+            };
+
+            diagnostics.push(diagnostic(start, end, message, severity));
         }
 
         self.client
diff --git a/crates/rue-tests/src/main.rs b/crates/rue-tests/src/main.rs
index a426b1e..cd1d4f9 100644
--- a/crates/rue-tests/src/main.rs
+++ b/crates/rue-tests/src/main.rs
@@ -10,7 +10,7 @@ use clvmr::{
     Allocator, ChiaDialect,
 };
 use indexmap::{IndexMap, IndexSet};
-use rue_compiler::compile;
+use rue_compiler::{compile, DiagnosticKind};
 use rue_parser::{line_col, LineCol};
 use serde::{Deserialize, Serialize};
 use walkdir::WalkDir;
@@ -89,7 +89,10 @@ fn run_test(source: &str, input: &str) -> Result<TestOutput, TestErrors> {
             let LineCol { line, col } = line_col(source, error.span().start);
             let line = line + 1;
             let col = col + 1;
-            format!("{} at {line}:{col}", error.info())
+            match error.kind() {
+                DiagnosticKind::Error(kind) => format!("{} at {line}:{col}", kind),
+                DiagnosticKind::Warning(kind) => format!("{} at {line}:{col}", kind),
+            }
         })
         .collect();