Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and mofeing committed Jan 10, 2025
1 parent 08a05ee commit e6dd644
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def FftLength : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{

def FftMultiplier : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto op_type = op->getResult(0).getType().cast<RankedTensorType>();
auto lengths = op.getFftLength();
auto lengths = op.getFftLengthAttr().getValues<int64_t>();
auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue();

double value = N;
Expand All @@ -919,15 +919,15 @@ def FftMultiplier : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto zero_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), 0)));
auto end_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1)));
RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1)));

auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64));

Value start[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0)))
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0)))
};
Value end[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1)))
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1)))
};
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, zero_constant, start);
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, end_constant, end);
Expand Down

0 comments on commit e6dd644

Please sign in to comment.