Skip to content

Commit

Permalink
cleanup, first helper
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Dec 26, 2023
1 parent 12292ca commit 175d236
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 0 deletions.
156 changes: 156 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use super::typetree::TypeTree;
use std::str::FromStr;
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd};
use crate::HashStableContext;

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
pub enum DiffMode {
Inactive,
Source,
Forward,
Reverse,
}

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
pub enum DiffActivity {
None,
Active,
Const,
Duplicated,
DuplicatedNoNeed,
}
fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize {
match value {
DiffActivity::None => 0,
DiffActivity::Active => 1,
DiffActivity::Const => 2,
DiffActivity::Duplicated => 3,
DiffActivity::DuplicatedNoNeed => 4,
}
}
fn clause_diffmode_discriminant(value: &DiffMode) -> usize {
match value {
DiffMode::Inactive => 0,
DiffMode::Source => 1,
DiffMode::Forward => 2,
DiffMode::Reverse => 3,
}
}


impl<CTX: HashStableContext> HashStable<CTX> for DiffMode {
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
clause_diffmode_discriminant(self).hash_stable(hcx, hasher);
}
}

impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity {
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
clause_diffactivity_discriminant(self).hash_stable(hcx, hasher);
}
}


impl FromStr for DiffActivity {
type Err = ();

fn from_str(s: &str) -> Result<DiffActivity, ()> {
match s {
"None" => Ok(DiffActivity::None),
"Active" => Ok(DiffActivity::Active),
"Const" => Ok(DiffActivity::Const),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed),
_ => Err(()),
}
}
}

#[allow(dead_code)]
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
pub struct AutoDiffAttrs {
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
self.mode.hash_stable(hcx, hasher);
self.ret_activity.hash_stable(hcx, hasher);
self.input_activity.hash_stable(hcx, hasher);
}
}

impl AutoDiffAttrs {
pub fn inactive() -> Self {
AutoDiffAttrs {
mode: DiffMode::Inactive,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
match self.mode {
DiffMode::Inactive => false,
_ => {
dbg!(&self);
true
},
}
}

pub fn is_source(&self) -> bool {
dbg!(&self);
match self.mode {
DiffMode::Source => true,
_ => false,
}
}
pub fn apply_autodiff(&self) -> bool {
match self.mode {
DiffMode::Inactive => false,
DiffMode::Source => false,
_ => {
dbg!(&self);
true
},
}
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
dbg!(&self);
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
pub source: String,
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

//impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem {
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
// self.source.hash_stable(hcx, hasher);
// self.target.hash_stable(hcx, hasher);
// self.attrs.hash_stable(hcx, hasher);
// for tt in &self.inputs {
// tt.0.hash_stable(hcx, hasher);
// }
// //self.inputs.hash_stable(hcx, hasher);
// self.output.0.hash_stable(hcx, hasher);
// }
//}
68 changes: 68 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::fmt;
//use rustc_data_structures::stable_hasher::{HashStable};//, StableHasher};
//use crate::HashStableContext;


#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}
//impl<CTX: HashStableContext> HashStable<CTX> for Kind {
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
// clause_kind_discriminant(self).hash_stable(hcx, hasher);
// }
//}
//fn clause_kind_discriminant(value: &Kind) -> usize {
// match value {
// Kind::Anything => 0,
// Kind::Integer => 1,
// Kind::Pointer => 2,
// Kind::Half => 3,
// Kind::Float => 4,
// Kind::Double => 5,
// Kind::Unknown => 6,
// }
//}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

//impl<CTX: HashStableContext> HashStable<CTX> for Type {
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
// self.offset.hash_stable(hcx, hasher);
// self.size.hash_stable(hcx, hasher);
// self.kind.hash_stable(hcx, hasher);
// self.child.0.hash_stable(hcx, hasher);
// }
//}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
1 change: 1 addition & 0 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::errors;
//use crate::util::check_builtin_macro_attribute;
//use crate::util::check_autodiff;

use rustc_ast::ptr::P;
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind};
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_passes/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ passes_abi_ne =
passes_abi_of =
fn_abi_of({$fn_name}) = {$fn_abi}
passes_autodiff_attr =
`#[autodiff]` should be applied to a function
.label = not a function
passes_allow_incoherent_impl =
`rustc_allow_incoherent_impl` attribute should be applied to impl items.
.label = the only currently supported targets are inherent methods
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_passes/src/check_attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ impl CheckAttrVisitor<'_> {
self.check_generic_attr(hir_id, attr, target, Target::Fn);
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
}
sym::autodiff => self.check_autodiff(hir_id, attr, span, target),
_ => {}
}

Expand Down Expand Up @@ -2382,6 +2383,18 @@ impl CheckAttrVisitor<'_> {
self.abort.set(true);
}
}

/// Checks if `#[autodiff]` is applied to an item other than a function item.
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
dbg!("check_autodiff");
match target {
Target::Fn => {}
_ => {
self.tcx.sess.emit_err(errors::AutoDiffAttr { attr_span: span });
self.abort.set(true);
}
}
}
}

impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_passes/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ pub struct IncorrectDoNotRecommendLocation {
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(passes_autodiff_attr)]
pub struct AutoDiffAttr {
#[primary_span]
#[label]
pub attr_span: Span,
}

#[derive(LintDiagnostic)]
#[diag(passes_outer_crate_level_attr)]
pub struct OuterCrateLevelAttr;
Expand Down

0 comments on commit 175d236

Please sign in to comment.