Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tablegen: Add StaticSelect to select based on static condition #2206

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

Pangoraw
Copy link
Contributor

@Pangoraw Pangoraw commented Dec 24, 2024

cc @mofeing, we can collaborate to integrate in your fft pr.

I am open to other naming ideas.

string value = val;
}

class StaticIf<bit uses_primal, bit uses_shadow, string condition_> : Operation<uses_primal, uses_shadow> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think probably we should remove the explicit needs primal and shadow, and instead directly integrate this into the use analysis (lookup SelectIfActive in this case we wouldn’t do the check on if active but check on whatever the string is)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking into it, I could not find the custom use analysis for SelectIfActive, it seems that instead SelectIfActive has no analysis:

       SmallVector<bool> cachedArguments(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
         SmallVector<bool> toret(op->getNumOperands(), false);
         for (size_t idx=0; idx<op->getNumOperands(); idx++) {
            bool used = false;
            // Rule (Op ?:$x, ?:$y)
            // Arg 0 : (SelectIfActive ?:$x, (CMul (DiffeRet), ?:$y), ?:$x)
            // Arg 1 : (DiffeRet)
            toret[idx] = used;
         }
         return toret;
       }

And it currently works because either it is used in the Forward diff which seem to rely on the analysis or the primals are used outside of the SelectIfActive node.

It seems that the assert here does not trigger due to the way enzyme-tblgen might be built?

So really, it should base its use on its arguments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah we should add support there (actually for both static if and if active). The reason we never hit that branch is because use analysis is auto generated from reverse mode rules and I think we only use select if active in forward.

However here we really want to add the correct diff use analysis, since then we can for example use static if to define rules for reduce and static if it’s a known max min or add etc

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But also it should do so recursively for the operands of the select if active (depending on if active)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the same analysis need to be used for SelectIfComplex as well (or it should be replaced with an instance of StaticIf)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean ideally we can define selectifcomplex in tablegen with selectif

@Pangoraw Pangoraw changed the title tablegen: Add StaticIf to select based on static condition tablegen: Add StaticSelect to select based on static condition Jan 5, 2025
@Pangoraw
Copy link
Contributor Author

Pangoraw commented Jan 5, 2025

I am trying to implement SelectIfActive based on StaticSelect as well but it looks like it has special handling for vectorization as well:

if (!vector && intrinsic != MLIRDerivatives) {
os << curIndent << INDENT << INDENT
<< "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? imVal : "
"UndefValue::get(gutils->getShadowType(imVal"
<< "->getType()));\n";
os << curIndent << INDENT << INDENT
<< "if (gutils->getWidth() != 1)\n";
os << curIndent << INDENT << INDENT << INDENT
<< "for (size_t i=0; i<gutils->getWidth(); i++)\n";
os << curIndent << INDENT << INDENT << INDENT << INDENT
<< "vec_imVal = " << builder
<< ".CreateInsertValue(vec_imVal, imVal, "
"std::vector<unsigned>({(unsigned)i}));\n";
os << curIndent << INDENT << INDENT << "imVal = vec_imVal;\n";
}

I guess I should move this to StaticSelect as well

@wsmoses
Copy link
Member

wsmoses commented Jan 5, 2025

ah yeah that one might be hard. Yeah the return of handle is whether or not the value is batched already (e.g. for batched vector mode). The right thing to do is probably as follows.

Emit both the lhs and the rhs. If they have the same vector mode set (either both not vectored or both vectored), that's fine and return it.

If one has it set, upgrade the other one to vector mode (e.g. do that insertvalue related stuff), and return vector = true.

@wsmoses
Copy link
Member

wsmoses commented Jan 5, 2025

cc @jumerckx this will probably also need an MLIR version too at one point (which presumably calls the new broadcast op)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants