-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* With mosaic * Extend pipeline capabilities * fixups * continuing * continuing * Update WORKSPACE * fix * bump commits * working * format * fix * fix * cleanup print * format * fixup * hash * fix * bump enzyme commit * fixup * fixup * fixup * Add dot [fwd] * fix * add abi tests
- Loading branch information
Showing
26 changed files
with
1,536 additions
and
473 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
--- a/jaxlib/cpu/BUILD | ||
+++ a/jaxlib/cpu/BUILD | ||
@@ -79,7 +79,7 @@ cc_library( | ||
":ducc_fft_flatbuffers_cc", | ||
"@xla//xla/service:custom_call_status", | ||
"@com_github_google_flatbuffers//:flatbuffers", | ||
- "@ducc//:fft", | ||
+ "@ducc//:fft_wrapper", | ||
--- a/jaxlib/mosaic/BUILD | ||
+++ b/jaxlib/mosaic/BUILD | ||
@@ -20,7 +20,7 @@ licenses(["notice"]) | ||
package( | ||
default_applicable_licenses = [], | ||
default_visibility = [ | ||
- "//:__subpackages__", | ||
+ "//visibility:public", | ||
], | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
class InactiveOp<string dialect_, string opName_> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
} | ||
|
||
class AllocationOp<string dialect_, string opName_> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
} | ||
|
||
class ControlFlowOp<string dialect_, string opName_, string impl_> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
string impl = impl_; | ||
} | ||
|
||
class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list<int> storedargs_ = []> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
list<int> ptrargs = ptrargs_; | ||
list<int> storedargs = storedargs_; | ||
} | ||
|
||
class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_> : MemoryIdentityOp<dialect_, opName_, ptrargs_>; | ||
|
||
class BranchOp<string dialect_, string opName_> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
} | ||
|
||
class RegionTerminatorOp<string dialect_, string opName_> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
} | ||
|
||
class ForwardFromSummedReverseInternal<int unused_> { | ||
int unused = unused_; | ||
} | ||
def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; | ||
|
||
|
||
class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> { | ||
string dialect = dialect_; | ||
string opName = opName_; | ||
dag PatternToMatch = patternToMatch; | ||
list<dag> ArgDerivatives = resultOps; | ||
dag ArgDuals = forwardOps; | ||
} | ||
|
||
class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> { | ||
bit usesPrimal = usesPrimal_; | ||
bit usesShadow = usesShadow_; | ||
bit usesCustom = usesCustom_; | ||
} | ||
|
||
class DiffeRetIndex<list<int> indices_> { | ||
list<int> indices = indices_; | ||
} | ||
def DiffeRet : DiffeRetIndex<[-1]>; | ||
|
||
def Shadow : Operation</*primal*/0, /*shadow*/1> { | ||
} | ||
|
||
class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{ | ||
string value = val; | ||
} | ||
|
||
class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> { | ||
string name = mnemonic; | ||
string dialect = dialect_; | ||
} | ||
|
||
class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> { | ||
string value = val; | ||
string dialect = dialect_; | ||
string opName = op_; | ||
string type = type_; | ||
} | ||
|
||
def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> { | ||
|
||
} | ||
|
||
def Op { | ||
} | ||
|
||
def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">; | ||
|
||
|
Oops, something went wrong.