Skip to content

Commit

Permalink
add packed reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Sep 30, 2024
1 parent c79b35c commit eaf1adb
Show file tree
Hide file tree
Showing 13 changed files with 423 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write($"pad<{string.Join(",", pad.Paddings)}>({Visit(args[0]).Name}, {Visit(args[1]).Name}, {args[0].CheckedDataType.ToC()} {{ {pad.PadValue} }} );\n");
break;
case TIR.CPU.Reduce reduce:
IndentScope.Writer.Write($"reduce_{reduce.ReduceOp.ToC()}<fixed_shape<{string.Join(",", reduce.Axis)}>, fixed_shape<{string.Join(",", reduce.PackedAxes)}>, fixed_shape<{string.Join(",", reduce.PadedNums)}>>({Visit(args[0]).Name}, {Visit(args[1]).Name});\n");
IndentScope.Writer.Write($"reduce_{reduce.ReduceOp.ToC()}<fixed_shape<{string.Join(",", reduce.Axes)}>, fixed_shape<{string.Join(",", reduce.PackedAxes)}>, fixed_shape<{string.Join(",", reduce.PadedNums)}>>({Visit(args[0]).Name}, {Visit(args[1]).Name});\n");
break;
case TIR.CPU.ReduceArg reduceArg:
IndentScope.Writer.Write($"reduce_arg<ops::{reduceArg.ReduceArgOp.ToC()[4..]}, {reduceArg.Axis}, {reduceArg.SelectLastIndex.ToString().ToLower(System.Globalization.CultureInfo.CurrentCulture)}, {reduceArg.KeepDims.ToString().ToLower(System.Globalization.CultureInfo.CurrentCulture)}>({Visit(args[0]).Name}, {Visit(args[1]).Name}, fixed_shape<>{{}}, fixed_shape<>{{}});\n");
Expand Down
1 change: 1 addition & 0 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public void ConfigureServices(IRegistrator registrator)
registrator.RegisterManyInterface<LoadEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<StoreEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackedReduceEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackedSoftMaxEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackedLayerNormEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackedMatMulEvaluator>(reuse: Reuse.Singleton);
Expand Down
129 changes: 129 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.CPU;
using Nncase.Utilities;
using OrtKISharp;

namespace Nncase.Evaluator.IR.CPU;

/// <summary>
/// Evaluator, type inferencer and cost model for <see cref="PackedReduce"/>.
/// </summary>
public sealed class PackedReduceEvaluator : IEvaluator<PackedReduce>, ITypeInferencer<PackedReduce>, ICostEvaluator<PackedReduce>
{
    /// <inheritdoc/>
    public IValue Visit(IEvaluateContext context, PackedReduce target)
    {
        var input = context.GetOrtArgumentValue(target, PackedReduce.Input);

        // The trailing PackedAxes.Count dims of the ORT tensor are the pack lanes;
        // the leading dims are the packed logical shape.
        var inShape = input.Shape.SkipLast(target.PackedAxes.Count).Select(i => (int)i).ToArray();
        var inLanes = input.Shape.TakeLast(target.PackedAxes.Count).Select(i => (int)i).ToArray();

        // Unpack the input (dropping any pack padding) so the reduction runs on the plain tensor.
        // NOTE(review): a previous revision additionally re-unpacked `input` in a loop here;
        // that result was never used, so the dead loop has been removed.
        var unpackedInput = CPUEvaluatorUtility.UnpackTensor(input, target.PackedAxes, target.PadedNums, out _);
        var axes = target.Axes.Select(i => (long)i).ToArray();
        long keepDims = target.KeepDims ? 1 : 0;

        OrtKISharp.Tensor output;
        switch (target.ReduceOp)
        {
            case ReduceOp.Sum:
                output = OrtKI.ReduceSum(unpackedInput, axes, keepDims, 0);
                break;
            case ReduceOp.Mean:
                output = OrtKI.ReduceMean(unpackedInput, axes, keepDims);
                break;
            default:
                throw new NotSupportedException(target.ReduceOp.ToString());
        }

        // Re-pack the reduced result according to the op's computed output layout.
        var (outPackAxes, outPadNums, outLanes, outShape) = PackedReduce.ComputeOutputInfo(target, inShape, inLanes);
        output = CPUEvaluatorUtility.RepackTensor(output, outLanes.ToArray(), outPackAxes, outPadNums);

        return Value.FromTensor(Tensor.FromBytes(outLanes.Length == 0 ? DataTypes.Float32 : new VectorType(DataTypes.Float32, outLanes.ToArray()), output.BytesBuffer.ToArray(), outShape));
    }

    /// <inheritdoc/>
    public IRType Visit(ITypeInferenceContext context, PackedReduce target)
    {
        var input = context.CheckArgumentType<IRType>(target, PackedReduce.Input);

        return input switch
        {
            DistributedType d => Visit(context, target, d),
            TensorType t => Visit(context, target, t),
            AnyType a => a,
            _ => new InvalidType(input.GetType().ToString()),
        };
    }

    /// <inheritdoc/>
    public Cost Visit(ICostEvaluateContext context, PackedReduce target)
    {
        var input = context.GetArgumentType<IRType>(target, PackedReduce.Input);
        var ret = context.GetReturnType<IRType>();
        var inputShape = input switch
        {
            TensorType t => t.Shape,
            DistributedType d => d.TensorType.Shape,
            _ => throw new NotSupportedException(string.Empty),
        };
        var retShape = ret switch
        {
            TensorType t => t.Shape,
            DistributedType d => d.TensorType.Shape,
            _ => throw new NotSupportedException(string.Empty),
        };

        // Dynamic dims count as 1, so the estimate is a lower bound.
        uint input_elem = inputShape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U));
        uint ret_elem = retShape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U));

        // Each output element accumulates input_elem / ret_elem inputs.
        uint macPerElement = input_elem / ret_elem;
        return new()
        {
            [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(input),
            [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(ret),
            [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(ret, macPerElement),
        };
    }

    // Type inference for a plain (non-distributed) packed tensor input.
    private IRType Visit(ITypeInferenceContext context, PackedReduce target, TensorType t)
    {
        var inshape = t.Shape.ToValueArray();
        var inDtype = (VectorType)t.DType;
        var inlanes = inDtype.Lanes.ToArray();
        var (_, _, outLanes, outShape) = PackedReduce.ComputeOutputInfo(target, inshape, inlanes);

        // When every packed axis is reduced away the result degrades to the scalar element type.
        var outDType = outLanes.Length == 0 ? inDtype.ElemType : new VectorType(inDtype.ElemType, outLanes);
        return new TensorType(outDType, outShape);
    }

    // Type inference for a distributed input: splitting along a reduced axis is not supported.
    private IRType Visit(ITypeInferenceContext context, PackedReduce target, DistributedType input)
    {
        if (Visit(context, target, input.TensorType) is not TensorType tensorType)
        {
            throw new InvalidOperationException();
        }

        var axes = target.Axes.ToArray();
        var invalid = new InvalidType($"{input}, not support");

        var ndsbp = new SBP[input.Placement.Rank];

        for (int i = 0; i < input.Placement.Rank; i++)
        {
            switch (input.NdSBP[i])
            {
                case SBPSplit { Axis: int ix } when axes.Contains(ix):
                    // Cannot reduce across a split axis without cross-device communication.
                    return invalid;
                default:
                    ndsbp[i] = input.NdSBP[i];
                    break;
            }
        }

        return new DistributedType(tensorType, ndsbp, input.Placement);
    }
}
5 changes: 5 additions & 0 deletions modules/Nncase.Modules.CPU/IR/CPU/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public static Expr PackedLayerNorm(Expr input, Expr scale, Expr bias, int axis,
return new Call(new PackedLayerNorm(axis, epsilon, usemean, packedAxes, padedNums), input, scale, bias);
}

/// <summary>
/// Build a call to the CPU <see cref="PackedReduce"/> operator over a packed input.
/// </summary>
public static Call PackedReduce(Expr input, ReduceOp reduceOp, IRArray<int> axes, float initValue, bool keepDims, IRArray<int> packedAxes, IRArray<int> padedNums) =>
    new Call(new PackedReduce(reduceOp, axes, initValue, keepDims, packedAxes, padedNums), input);

public static Expr InstacneNorm(Expr input, Expr scale, Expr bias, float epsilon, IRArray<int> packedAxes, IRArray<int> padedNums)
{
return new Call(new InstacneNorm(epsilon, packedAxes, padedNums), input, scale, bias);
Expand Down
66 changes: 66 additions & 0 deletions modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.PatternMatch;

namespace Nncase.IR.CPU;

[PatternFunctionalGenerator]
[PatternFunctionalGenerator]
public sealed partial class PackedReduce : PackedOp
{
    /// <summary>
    /// Gets input.
    /// </summary>
    public static readonly ParameterInfo Input = new(typeof(PackedReduce), 0, "input", ParameterKind.Input);

    // Reduction kind (e.g. Sum, Mean).
    public ReduceOp ReduceOp { get; }

    // Axes (in packed-shape coordinates) to reduce over.
    public IRArray<int> Axes { get; }

    // Accumulator initial value.
    public float InitValue { get; }

    // Whether reduced dimensions are kept with size 1.
    public bool KeepDims { get; }

    // Axes of the input that are packed into vector lanes.
    public IRArray<int> PackedAxes { get; }

    // Per-packed-axis padding element counts.
    public IRArray<int> PadedNums { get; }

    /// <summary>
    /// Computes the output pack axes, pad counts, lanes and shape for a reduce
    /// over a packed input. <paramref name="inShape"/> is the PACKED input shape
    /// and <paramref name="inLanes"/> the lanes of each packed axis.
    /// NOTE(review): multi-axis reduces are handled positionally in ascending
    /// axis order; current callers only pass a single axis — confirm before
    /// relying on multi-axis behavior.
    /// </summary>
    public static (int[] OutPackAxes, int[] OutPadNums, int[] OutLanes, int[] OutShape) ComputeOutputInfo(PackedReduce target, int[] inShape, int[] inLanes)
    {
        var packedAxes = target.PackedAxes.ToList();
        var padedNums = target.PadedNums.ToList();
        var lanes = inLanes.ToList();
        var shape = inShape.ToList(); // note the inshape is packed.
        var offset = 0; // accumulated shift from previously removed dims.
        foreach (var axis in target.Axes)
        {
            // Position of this axis in the current (possibly shrunken) shape.
            var cur = offset + axis;
            if (target.KeepDims)
            {
                shape[cur] = 1;
            }
            else
            {
                shape.RemoveAt(cur);
                offset--;
            }

            // If the reduced axis was packed, its lane dimension collapses with it.
            var j = packedAxes.IndexOf(cur);
            if (j != -1)
            {
                packedAxes.RemoveAt(j);
                padedNums.RemoveAt(j);
                lanes.RemoveAt(j);
            }

            // A dimension was removed from the shape, so every surviving packed
            // axis beyond it shifts left by one. (Previously this shift only ran
            // when the reduced axis itself was packed, leaving stale indices when
            // reducing a non-packed axis, and wrongly shifting when KeepDims kept
            // the dimension in place.)
            if (!target.KeepDims)
            {
                for (int i = 0; i < packedAxes.Count; i++)
                {
                    if (packedAxes[i] > cur)
                    {
                        packedAxes[i]--;
                    }
                }
            }
        }

        return (packedAxes.ToArray(), padedNums.ToArray(), lanes.ToArray(), shape.ToArray());
    }

    public override string DisplayProperty() => $"ReduceOp.{ReduceOp}, Axes: {{{string.Join(",", Axes)}}}, InitValue: {InitValue}, KeepDims: {KeepDims}, PackedAxes: {{{string.Join(",", PackedAxes)}}}, PadedNums: {{{string.Join(",", PadedNums)}}}";
}
67 changes: 67 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,73 @@ void AddCandidate(int[] packedAxes, int[] lanes)
}
}

/// <summary>
/// Rewrites a float Reduce (Sum/Mean) into Pack → PackedReduce → Unpack candidates
/// so the tiler can pick a vectorized layout.
/// </summary>
public sealed class PackReduce : PackRule
{
    public PackReduce(int rank, int lane)
        : base(rank, lane)
    {
    }

    // Match a scalar-float reduce whose input is not already the result of an Unpack.
    public override Pattern Pattern { get; } = IsReduce(
        "target",
        r => r.ReduceOp is ReduceOp.Mean or ReduceOp.Sum,
        IsWildcard("input", e => e is not Call { Target: IR.CPU.Unpack }) with { TypePattern = IsFloat() & !IsVector() },
        IsTensorConst("axes") with { TypePattern = IsIntegral() },
        IsTensorConst("initValue") with { TypePattern = IsFloat() },
        IsTensorConst("keepDims") with { TypePattern = IsBool() });

    public override List<Expr> GetReplaceCandidates(IMatchResult result, RunPassContext context)
    {
        var rets = new List<Expr>();
        var op = (IR.Math.Reduce)result["target"];
        var input = (Expr)result["input"];
        var axes = ((TensorConst)result["axes"]).Value.ToArray<int>();

        // Only single-axis reduces are supported for packing.
        if (axes.Length > 1)
        {
            return rets;
        }

        var initValue = ((TensorConst)result["initValue"]).Value.ToScalar<float>();
        var keepDims = ((TensorConst)result["keepDims"]).Value.ToScalar<bool>();
        var inShape = input.CheckedShape.ToValueArray();

        void AddCandidate(int[] packedAxes, int[] lanes)
        {
            var packedInput = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var padsInput), lanes, packedAxes);

            // todo support padings.
            if (padsInput.Any(x => x > 0))
            {
                return;
            }

            Call reduce = IR.F.CPU.PackedReduce(packedInput, op.ReduceOp, axes, initValue, keepDims, packedAxes, padsInput);

            var (outPackAxes, outPadNums, outLanes, outShape) = IR.CPU.PackedReduce.ComputeOutputInfo((IR.CPU.PackedReduce)reduce.Target, inShape, lanes);
            var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(reduce, outLanes, outPackAxes), outShape, outPadNums);

            // Drop candidates whose types fail to infer (e.g. invalid pack layout).
            if (post.CheckedType is not InvalidType)
            {
                rets.Add(post);
            }
        }

        for (int i = 0; i < input.CheckedShape.Count; i++)
        {
            AddCandidate([i], [Lane]);

            // Rank is loop-invariant: only try two-axis packs when allowed.
            if (Rank > 1)
            {
                for (int j = i + 1; j < input.CheckedShape.Count; j++)
                {
                    AddCandidate([i, j], [Lane, Lane]);
                }
            }
        }

        return rets;
    }
}

public sealed class PackInstanceNorm : PackRule
{
public PackInstanceNorm(int rank, int lane)
Expand Down
3 changes: 3 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ protected override Unit VisitLeafCall(Call expr)
case IR.NN.Erf erf:
_mainBody.Add(TIR.F.CPU.Erf(arguments[0], ret));
break;
case IR.CPU.PackedReduce pr:
_mainBody.Add(TIR.F.CPU.Reduce(arguments[0], ret, pr.PackedAxes.ToArray(), pr.PadedNums.ToArray(), pr.Axes, pr.KeepDims, pr.ReduceOp));
break;
case IR.Tensors.GetItem:
break;
default:
Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/TIR/CPU/Reduce.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public sealed partial class Reduce : CPUKernelOp

public IRArray<int> PadedNums { get; }

public Nncase.IR.IRArray<int> Axis { get; }
public IRArray<int> Axes { get; }

public bool KeepDims { get; }

Expand Down
1 change: 1 addition & 0 deletions modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp
// todo config it in the target options.
var rank = 1;
var lane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;
p.Add<Passes.Rules.CPU.PackReduce>(rank, lane);
p.Add<Passes.Rules.CPU.PackSwish>(rank, lane);
p.Add<Passes.Rules.CPU.PackResizeImage>(rank, lane);
p.Add<Passes.Rules.CPU.PackMatMul>(rank, lane);
Expand Down
11 changes: 11 additions & 0 deletions src/Native/include/nncase/ntt/arch/aarch64/tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ template <> struct tload_scalar<ntt::vector<float, 4>> {
}
};

// Broadcast a scalar float into a 4x4 vector: each of the four rows
// receives the same float32x4 lane filled with v.
template <> struct tload_scalar<ntt::vector<float, 4, 4>> {
    ntt::vector<float, 4, 4> operator()(const float &v) const noexcept {
        const float32x4_t filled = vdupq_n_f32(v);
        ntt::vector<float, 4, 4> ret;
        for (int i = 0; i < 4; i++) {
            ret(i) = filled;
        }
        return ret;
    }
};

template <> struct tload_scalar<ntt::vector<float, 8>> {
ntt::vector<float, 8> operator()(const float &v) const noexcept {
return float32x4x2_t{vdupq_n_f32(v), vdupq_n_f32(v)};
Expand Down
Loading

0 comments on commit eaf1adb

Please sign in to comment.