forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
250 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
use super::typetree::TypeTree; | ||
use std::str::FromStr; | ||
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; | ||
use crate::HashStableContext; | ||
|
||
#[allow(dead_code)] | ||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] | ||
pub enum DiffMode { | ||
Inactive, | ||
Source, | ||
Forward, | ||
Reverse, | ||
} | ||
|
||
#[allow(dead_code)] | ||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] | ||
pub enum DiffActivity { | ||
None, | ||
Active, | ||
Const, | ||
Duplicated, | ||
DuplicatedNoNeed, | ||
} | ||
fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { | ||
match value { | ||
DiffActivity::None => 0, | ||
DiffActivity::Active => 1, | ||
DiffActivity::Const => 2, | ||
DiffActivity::Duplicated => 3, | ||
DiffActivity::DuplicatedNoNeed => 4, | ||
} | ||
} | ||
fn clause_diffmode_discriminant(value: &DiffMode) -> usize { | ||
match value { | ||
DiffMode::Inactive => 0, | ||
DiffMode::Source => 1, | ||
DiffMode::Forward => 2, | ||
DiffMode::Reverse => 3, | ||
} | ||
} | ||
|
||
|
||
impl<CTX: HashStableContext> HashStable<CTX> for DiffMode { | ||
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
clause_diffmode_discriminant(self).hash_stable(hcx, hasher); | ||
} | ||
} | ||
|
||
impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity { | ||
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); | ||
} | ||
} | ||
|
||
|
||
impl FromStr for DiffActivity { | ||
type Err = (); | ||
|
||
fn from_str(s: &str) -> Result<DiffActivity, ()> { | ||
match s { | ||
"None" => Ok(DiffActivity::None), | ||
"Active" => Ok(DiffActivity::Active), | ||
"Const" => Ok(DiffActivity::Const), | ||
"Duplicated" => Ok(DiffActivity::Duplicated), | ||
"DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), | ||
_ => Err(()), | ||
} | ||
} | ||
} | ||
|
||
#[allow(dead_code)] | ||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] | ||
pub struct AutoDiffAttrs { | ||
pub mode: DiffMode, | ||
pub ret_activity: DiffActivity, | ||
pub input_activity: Vec<DiffActivity>, | ||
} | ||
|
||
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs { | ||
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
self.mode.hash_stable(hcx, hasher); | ||
self.ret_activity.hash_stable(hcx, hasher); | ||
self.input_activity.hash_stable(hcx, hasher); | ||
} | ||
} | ||
|
||
impl AutoDiffAttrs { | ||
pub fn inactive() -> Self { | ||
AutoDiffAttrs { | ||
mode: DiffMode::Inactive, | ||
ret_activity: DiffActivity::None, | ||
input_activity: Vec::new(), | ||
} | ||
} | ||
|
||
pub fn is_active(&self) -> bool { | ||
match self.mode { | ||
DiffMode::Inactive => false, | ||
_ => { | ||
dbg!(&self); | ||
true | ||
}, | ||
} | ||
} | ||
|
||
pub fn is_source(&self) -> bool { | ||
dbg!(&self); | ||
match self.mode { | ||
DiffMode::Source => true, | ||
_ => false, | ||
} | ||
} | ||
pub fn apply_autodiff(&self) -> bool { | ||
match self.mode { | ||
DiffMode::Inactive => false, | ||
DiffMode::Source => false, | ||
_ => { | ||
dbg!(&self); | ||
true | ||
}, | ||
} | ||
} | ||
|
||
pub fn into_item( | ||
self, | ||
source: String, | ||
target: String, | ||
inputs: Vec<TypeTree>, | ||
output: TypeTree, | ||
) -> AutoDiffItem { | ||
dbg!(&self); | ||
AutoDiffItem { source, target, inputs, output, attrs: self } | ||
} | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct AutoDiffItem { | ||
pub source: String, | ||
pub target: String, | ||
pub attrs: AutoDiffAttrs, | ||
pub inputs: Vec<TypeTree>, | ||
pub output: TypeTree, | ||
} | ||
|
||
//impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem { | ||
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
// self.source.hash_stable(hcx, hasher); | ||
// self.target.hash_stable(hcx, hasher); | ||
// self.attrs.hash_stable(hcx, hasher); | ||
// for tt in &self.inputs { | ||
// tt.0.hash_stable(hcx, hasher); | ||
// } | ||
// //self.inputs.hash_stable(hcx, hasher); | ||
// self.output.0.hash_stable(hcx, hasher); | ||
// } | ||
//} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
use std::fmt; | ||
//use rustc_data_structures::stable_hasher::{HashStable};//, StableHasher}; | ||
//use crate::HashStableContext; | ||
|
||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum Kind { | ||
Anything, | ||
Integer, | ||
Pointer, | ||
Half, | ||
Float, | ||
Double, | ||
Unknown, | ||
} | ||
//impl<CTX: HashStableContext> HashStable<CTX> for Kind { | ||
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
// clause_kind_discriminant(self).hash_stable(hcx, hasher); | ||
// } | ||
//} | ||
//fn clause_kind_discriminant(value: &Kind) -> usize { | ||
// match value { | ||
// Kind::Anything => 0, | ||
// Kind::Integer => 1, | ||
// Kind::Pointer => 2, | ||
// Kind::Half => 3, | ||
// Kind::Float => 4, | ||
// Kind::Double => 5, | ||
// Kind::Unknown => 6, | ||
// } | ||
//} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct TypeTree(pub Vec<Type>); | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct Type { | ||
pub offset: isize, | ||
pub size: usize, | ||
pub kind: Kind, | ||
pub child: TypeTree, | ||
} | ||
|
||
//impl<CTX: HashStableContext> HashStable<CTX> for Type { | ||
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { | ||
// self.offset.hash_stable(hcx, hasher); | ||
// self.size.hash_stable(hcx, hasher); | ||
// self.kind.hash_stable(hcx, hasher); | ||
// self.child.0.hash_stable(hcx, hasher); | ||
// } | ||
//} | ||
|
||
impl Type { | ||
pub fn add_offset(self, add: isize) -> Self { | ||
let offset = match self.offset { | ||
-1 => add, | ||
x => add + x, | ||
}; | ||
|
||
Self { size: self.size, kind: self.kind, child: self.child, offset } | ||
} | ||
} | ||
|
||
impl fmt::Display for Type { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
<Self as fmt::Debug>::fmt(self, f) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters