Skip to content

[WIP] TypeTree support in autodiff #143490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

KMJ-007
Copy link
Contributor

@KMJ-007 KMJ-007 commented Jul 5, 2025

This PR starts the migration to full type tree support for autodiff in Rust

@rustbot rustbot added F-autodiff `#![feature(autodiff)]` T-compiler Relevant to the compiler team, which will review and decide on the PR/issue. labels Jul 5, 2025
@rust-log-analyzer

This comment has been minimized.

@rustbot rustbot added the A-LLVM Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues. label Jul 5, 2025
@rust-log-analyzer

This comment has been minimized.

@rust-cloud-vms rust-cloud-vms bot force-pushed the type-trees-enzyme branch from ba7cd1e to a165957 Compare July 5, 2025 13:51
@rust-log-analyzer
Copy link
Collaborator

The job tidy failed! Check out the build log: (web) (plain enhanced) (plain)

Click to see the possible cause of the failure (guessed by this bot)
 use crate::typetree::to_enzyme_typetree;
 use crate::value::Value;
-use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm, DiffTypeTree};
-use rustc_data_structures::fx::FxHashMap;
+use crate::{CodegenContext, DiffTypeTree, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
 
 fn get_params(fnc: &Value) -> Vec<&Value> {
     let param_num = llvm::LLVMCountParams(fnc) as usize;
Diff in /checkout/compiler/rustc_codegen_llvm/src/builder/autodiff.rs:521:
 
         // Use type trees from the typetrees map if available, otherwise construct from item
         let fnc_tree = if let Some(diff_tt) = typetrees.get(&item.source) {
-            Some(FncTree {
-                inputs: diff_tt.input_tt.clone(),
-                output: diff_tt.ret_tt.clone(),
-            })
+            Some(FncTree { inputs: diff_tt.input_tt.clone(), output: diff_tt.ret_tt.clone() })
         } else if !item.inputs.is_empty() || !item.output.0.is_empty() {
-            Some(FncTree {
-                inputs: item.inputs.clone(),
-                output: item.output.clone(),
-            })
+            Some(FncTree { inputs: item.inputs.clone(), output: item.output.clone() })
         } else {
             None
         };
Diff in /checkout/compiler/rustc_codegen_llvm/src/typetree.rs:1:
-use crate::llvm;
 use rustc_ast::expand::typetree::{Kind, TypeTree};
 
+use crate::llvm;
+
 pub fn to_enzyme_typetree(
     tree: TypeTree,
     llvm_data_layout: &str,
Diff in /checkout/compiler/rustc_codegen_llvm/src/typetree.rs:30:
             obj.merge(tt)
         }
     })
-} 
+}
+
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:1:
 use rustc_ast as ast;
 use rustc_ast::FnRetTy;
-use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree};
-use rustc_middle::ty::{Ty, TyCtxt, ParamEnv, ParamEnvAnd, Adt};
-use rustc_middle::ty::layout::{FieldsShape, LayoutOf};
+use rustc_ast::expand::autodiff_attrs::DiffActivity;
+use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
 use rustc_middle::hir;
+use rustc_middle::ty::layout::{FieldsShape, LayoutOf};
+use rustc_middle::ty::{Adt, ParamEnv, ParamEnvAnd, Ty, TyCtxt};
 use rustc_span::Span;
-use rustc_ast::expand::autodiff_attrs::DiffActivity;
 
 #[cfg(llvm_enzyme)]
 pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:27:
 //    yet. We should add such analysis to relibably either issue an error or accept without warning.
 //    If there only were some reasearch to do that...
 #[cfg(llvm_enzyme)]
-pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree {
+pub fn fnc_typetrees<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    fn_ty: Ty<'tcx>,
+    da: &mut Vec<DiffActivity>,
+    span: Option<Span>,
+) -> FncTree {
     if !fn_ty.is_fn() {
         return FncTree { args: vec![], ret: TypeTree::new() };
     }
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:70:
                 let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
                 let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
                 args.push(TypeTree(vec![tt]));
-                let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
+                let i64_tt =
+                    Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
                 args.push(TypeTree(vec![i64_tt]));
                 if !da.is_empty() {
                     // We are looking at a slice. The length of that slice will become an
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:78:
                     // However, if the slice get's duplicated, we want to know to later check the
                     // size. So we mark the new size argument as FakeActivitySize.
                     let activity = match da[i] {
-                        DiffActivity::DualOnly | DiffActivity::Dual |
-                            DiffActivity::DuplicatedOnly | DiffActivity::Duplicated
-                            => DiffActivity::FakeActivitySize,
+                        DiffActivity::DualOnly
+                        | DiffActivity::Dual
+                        | DiffActivity::DuplicatedOnly
+                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
                         DiffActivity::Const => DiffActivity::Const,
                         _ => panic!("unexpected activity for ptr/ref"),
                     };
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:109:
     FncTree { args, ret }
 }
 
-
 // Error type for warnings
 #[derive(Debug)]
 pub struct AutodiffUnsafeInnerConstRef {
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:118:
 }
 
 #[cfg(llvm_enzyme)]
-fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree {
+fn typetree_from_ty<'a>(
+    ty: Ty<'a>,
+    tcx: TyCtxt<'a>,
+    depth: usize,
+    safety: bool,
+    visited: &mut Vec<Ty<'a>>,
+    span: Option<Span>,
+) -> TypeTree {
     if depth > 20 {
         trace!("depth > 20 for ty: {}", &ty);
     }
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:147:
             } else {
                 assert!(ty.is_box());
                 "box"
-            }.to_string();
+            }
+            .to_string();
 
             // If we have mutability, we also have a span
             assert!(span.is_some());
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:154:
             let span = span.unwrap();
 
-            tcx.sess
-            .dcx()
-            .emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty});
+            tcx.sess.dcx().emit_warning(AutodiffUnsafeInnerConstRef { span, ty: ptr_ty });
         }
 
         let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:201:
             let (offsets, _memory_index) = match fields {
                 // Manuel TODO:
                 FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m),
-                FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later
-                FieldsShape::Union(_) => {return TypeTree::new();},
-                FieldsShape::Primitive => {return TypeTree::new();},
+                FieldsShape::Array { .. } => {
+                    return TypeTree::new();
+                } //e.g. core::arch::x86_64::__m128i, TODO: later
+                FieldsShape::Union(_) => {
+                    return TypeTree::new();
+                }
+                FieldsShape::Primitive => {
+                    return TypeTree::new();
+                }
             };
 
             let substs = match ty.kind() {
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:224:
                         return None;
                     }
 
-                    let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;
+                    let mut child =
+                        typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;
 
                     for c in &mut child {
                         if c.offset == -1 {
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:317:
 pub fn construct_typetree_from_fnsig(sig: &ast::FnSig) -> (Vec<TypeTree>, TypeTree) {
     // For now, return empty type trees
     // This will be replaced with proper layout-based construction
-    let inputs: Vec<TypeTree> = sig.decl.inputs.iter()
-        .map(|_| TypeTree::new())
-        .collect();
-    
+    let inputs: Vec<TypeTree> = sig.decl.inputs.iter().map(|_| TypeTree::new()).collect();
+
     let output = match &sig.decl.output {
         FnRetTy::Default(_) => TypeTree::new(),
         FnRetTy::Ty(_) => TypeTree::new(),
Diff in /checkout/compiler/rustc_builtin_macros/src/typetree.rs:327:
     };
-    
+
     (inputs, output)
 }
 
Diff in /checkout/compiler/rustc_builtin_macros/src/autodiff.rs:11:
         AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
         valid_ty_for_activity,
     };
-    use rustc_ast::expand::typetree::{TypeTree, Type, Kind};
+    use rustc_ast::expand::typetree::{Kind, Type, TypeTree};
     use rustc_ast::ptr::P;
-    use crate::typetree::construct_typetree_from_fnsig;
     use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
     use rustc_ast::tokenstream::*;
     use rustc_ast::visit::AssocCtxt::*;
Diff in /checkout/compiler/rustc_builtin_macros/src/autodiff.rs:27:
     use tracing::{debug, trace};
 
     use crate::errors;
+    use crate::typetree::construct_typetree_from_fnsig;
 
     pub(crate) fn outer_normal_attr(
         kind: &P<rustc_ast::NormalAttr>,
Diff in /checkout/compiler/rustc_builtin_macros/src/autodiff.rs:328:
 
         // Construct real type trees from function signature
         let (inputs, output) = construct_typetree_from_fnsig(&sig);
-        
+
         // Use the new into_item method to construct the AutoDiffItem
         let autodiff_item = x.clone().into_item(
             primal.to_string(),
Diff in /checkout/compiler/rustc_ast/src/expand/autodiff_attrs.rs:6:
 use std::fmt::{self, Display, Formatter};
 use std::str::FromStr;
 
+use crate::expand::typetree::TypeTree;
 use crate::expand::{Decodable, Encodable, HashStable_Generic};
 use crate::ptr::P;
 use crate::{Ty, TyKind};
Diff in /checkout/compiler/rustc_ast/src/expand/autodiff_attrs.rs:12:
-use crate::expand::typetree::TypeTree;
 
 /// Forward and Reverse Mode are well known names for automatic differentiation implementations.
 /// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
Diff in /checkout/compiler/rustc_ast/src/expand/autodiff_attrs.rs:117:
         matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
     }
     /// New constructor for type tree support
-    pub fn into_item(self, source: String, target: String, inputs: Vec<TypeTree>, output: TypeTree) -> AutoDiffItem {
+    pub fn into_item(
+        self,
+        source: String,
+        target: String,
+        inputs: Vec<TypeTree>,
+        output: TypeTree,
+    ) -> AutoDiffItem {
         AutoDiffItem { source, target, attrs: self, inputs, output }
     }
 }
fmt: checked 6148 files
Build completed unsuccessfully in 0:00:46
  local time: Sat Jul  5 13:57:17 UTC 2025

@KMJ-007
Copy link
Contributor Author

KMJ-007 commented Jul 5, 2025

r? @ZuseZ4

@ZuseZ4
Copy link
Member

ZuseZ4 commented Jul 5, 2025

thanks! It looks like you haven't ported everything over yet, but let me know if you get stuck somewhere.
I would recommend to aim to get an MVP to work where you add typetrees in one location, like e.g. memcpy, to prove that the code works. Then we can talk with the others to see where to go from here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
A-LLVM Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues. F-autodiff `#![feature(autodiff)]` T-compiler Relevant to the compiler team, which will review and decide on the PR/issue.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants