Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeGen: Implement support for math.lerp lowering #1609

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CodeGen/include/Luau/AssemblyBuilderX64.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class AssemblyBuilderX64
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);

void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);

void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);
Expand Down
5 changes: 5 additions & 0 deletions CodeGen/include/Luau/IrData.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ enum class IrCmd : uint8_t
// A: double
SIGN_NUM,

// Select B if C == D, otherwise select A
// A, B: double (endpoints)
// C, D: double (condition arguments)
SELECT_NUM,
zeux marked this conversation as resolved.
Show resolved Hide resolved

// Add/Sub/Mul/Div/Idiv two vectors
// A, B: TValue
ADD_VEC,
Expand Down
1 change: 1 addition & 0 deletions CodeGen/include/Luau/IrUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
Expand Down
5 changes: 5 additions & 0 deletions CodeGen/src/AssemblyBuilderX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
}

void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2);
}

void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/src/IrDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd)
return "ABS_NUM";
case IrCmd::SIGN_NUM:
return "SIGN_NUM";
case IrCmd::SELECT_NUM:
return "SELECT_NUM";
case IrCmd::ADD_VEC:
return "ADD_VEC";
case IrCmd::SUB_VEC:
Expand Down
15 changes: 15 additions & 0 deletions CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

LUAU_FASTFLAG(LuauVectorLibNativeDot)
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
LUAU_FASTFLAG(LuauCodeGenLerp)

namespace Luau
{
Expand Down Expand Up @@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
break;
}
case IrCmd::SELECT_NUM:
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b, inst.c, inst.d});

RegisterA64 temp1 = tempDouble(inst.a);
RegisterA64 temp2 = tempDouble(inst.b);
RegisterA64 temp3 = tempDouble(inst.c);
RegisterA64 temp4 = tempDouble(inst.d);

build.fcmp(temp3, temp4);
build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal));
break;
}
case IrCmd::ADD_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
Expand Down
25 changes: 25 additions & 0 deletions CodeGen/src/IrLoweringX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

LUAU_FASTFLAG(LuauVectorLibNativeDot)
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
LUAU_FASTFLAG(LuauCodeGenLerp)

namespace Luau
{
Expand Down Expand Up @@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
break;
}
case IrCmd::SELECT_NUM:
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.c, inst.d}); // can't reuse b if a is a memory operand

ScopedRegX64 tmp{regs, SizeX64::xmmword};

if (inst.c.kind == IrOpKind::Inst)
build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d));
else
{
build.vmovsd(tmp.reg, memRegDoubleOp(inst.c));
build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d));
}

if (inst.a.kind == IrOpKind::Inst)
build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg);
else
{
build.vmovsd(inst.regX64, memRegDoubleOp(inst.a));
build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg);
}
break;
}
case IrCmd::ADD_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
Expand Down
39 changes: 39 additions & 0 deletions CodeGen/src/IrTranslateBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5;

LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp);

namespace Luau
{
Expand Down Expand Up @@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp(
return {BuiltinImplType::UsesFallback, 1};
}

static BuiltinImplResult translateBuiltinMathLerp(
IrBuilder& build,
int nparams,
int ra,
int arg,
IrOp args,
IrOp arg3,
int nresults,
IrOp fallback,
int pcpos
)
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);

if (nparams < 3 || nresults > 1)
return {BuiltinImplType::None, -1};

builtinCheckDouble(build, build.vmReg(arg), pcpos);
builtinCheckDouble(build, args, pcpos);
builtinCheckDouble(build, arg3, pcpos);

IrOp a = builtinLoadDouble(build, build.vmReg(arg));
IrOp b = builtinLoadDouble(build, args);
IrOp t = builtinLoadDouble(build, arg3);

IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t));
IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0

build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r);

if (ra != arg)
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));

return {BuiltinImplType::Full, 1};
}

static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
{
if (nparams < 1 || nresults > 1)
Expand Down Expand Up @@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin(
case LBF_VECTOR_MAX:
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
: noneResult;
case LBF_MATH_LERP:
return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult;
default:
return {BuiltinImplType::None, -1};
}
Expand Down
12 changes: 12 additions & 0 deletions CodeGen/src/IrUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <math.h>

LUAU_FASTFLAG(LuauVectorLibNativeDot);
LUAU_FASTFLAG(LuauCodeGenLerp);

namespace Luau
{
Expand Down Expand Up @@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
return IrValueKind::Double;
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
Expand Down Expand Up @@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
}
break;
case IrCmd::SELECT_NUM:
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant)
{
double c = function.doubleOp(inst.c);
double d = function.doubleOp(inst.d);

substitute(function, inst, c == d ? inst.b : inst.a);
}
break;
case IrCmd::NOT_ANY:
if (inst.a.kind == IrOpKind::Constant)
{
Expand Down
1 change: 1 addition & 0 deletions CodeGen/src/OptimizeConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
case IrCmd::NOT_ANY:
state.substituteOrRecord(inst, index);
break;
Expand Down
1 change: 1 addition & 0 deletions tests/AssemblyBuilderX64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6);
SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6);

SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00);
SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01);
}

Expand Down
1 change: 1 addition & 0 deletions tests/conformance/math.lua
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ assert(math.lerp(1, 5, 1) == 5)
assert(math.lerp(1, 5, 0.5) == 3)
assert(math.lerp(1, 5, 1.5) == 7)
assert(math.lerp(1, 5, -0.5) == -1)
assert(math.lerp(1, 5, noinline(0.5)) == 3)

-- lerp properties
local sq2, sq3 = math.sqrt(2), math.sqrt(3)
Expand Down
Loading