Optimize XMAD instruction sequence into a single 32-bit multiply when possible

See merge request [ryubing/ryujinx!24](https://git.ryujinx.app/ryubing/ryujinx/-/merge_requests/24)
This commit is contained in:
KeatonTheBot 2025-05-23 17:12:37 -05:00 committed by GreemDev
parent be7285f7fc
commit 1c411082db
2 changed files with 348 additions and 1 deletions

View File

@ -9,6 +9,11 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{ {
public static void RunPass(TransformContext context) public static void RunPass(TransformContext context)
{ {
for (int blkIndex = 0; blkIndex < context.Blocks.Length; blkIndex++)
{
XmadOptimizer.RunPass(context.Blocks[blkIndex]);
}
RunOptimizationPasses(context.Blocks, context.ResourceManager); RunOptimizationPasses(context.Blocks, context.ResourceManager);
// TODO: Some of those are not optimizations and shouldn't be here. // TODO: Some of those are not optimizations and shouldn't be here.
@ -355,7 +360,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
operation.TurnIntoCopy(attrMulLhs); operation.TurnIntoCopy(attrMulLhs);
} }
private static void RemoveNode(BasicBlock block, LinkedListNode<INode> llNode) public static void RemoveNode(BasicBlock block, LinkedListNode<INode> llNode)
{ {
// Remove a node from the nodes list, and also remove itself // Remove a node from the nodes list, and also remove itself
// from all the use lists on the operands that this node uses. // from all the use lists on the operands that this node uses.

View File

@ -0,0 +1,342 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{
static class XmadOptimizer
{
public static void RunPass(BasicBlock block)
{
for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next)
{
if (!(node.Value is Operation operation))
{
continue;
}
if (TryMatchXmadPattern(operation, out Operand x, out Operand y, out Operand addend))
{
LinkedListNode<INode> nextNode;
if (addend != null)
{
Operand temp = OperandHelper.Local();
nextNode = block.Operations.AddAfter(node, new Operation(Instruction.Multiply, temp, x, y));
nextNode = block.Operations.AddAfter(nextNode, new Operation(Instruction.Add, operation.Dest, temp, addend));
}
else
{
nextNode = block.Operations.AddAfter(node, new Operation(Instruction.Multiply, operation.Dest, x, y));
}
Optimizer.RemoveNode(block, node);
node = nextNode;
}
}
}
private static bool TryMatchXmadPattern(Operation operation, out Operand x, out Operand y, out Operand addend)
{
return TryMatchXmad32x32Pattern(operation, out x, out y, out addend) ||
TryMatchXmad32x16Pattern(operation, out x, out y, out addend);
}
private static bool TryMatchXmad32x32Pattern(Operation operation, out Operand x, out Operand y, out Operand addend)
{
x = null;
y = null;
addend = null;
if (operation.Inst != Instruction.Add)
{
return false;
}
Operand src1 = operation.GetSource(0);
Operand src2 = operation.GetSource(1);
if (!(src2.AsgOp is Operation addOp) || addOp.Inst != Instruction.Add)
{
return false;
}
Operand lowTimesLowResult = GetCopySource(addOp.GetSource(0));
if (!(lowTimesLowResult.AsgOp is Operation lowTimesLowOp))
{
return false;
}
if (!TryMatchLowTimesLow(lowTimesLowOp, out x, out y, out addend))
{
return false;
}
Operand lowTimesHighResult = GetCopySource(GetShifted16Source(addOp.GetSource(1), Instruction.ShiftLeft));
if (!(lowTimesHighResult.AsgOp is Operation lowTimesHighOp))
{
return false;
}
if (!TryMatchLowTimesHigh(lowTimesHighOp, x, y))
{
return false;
}
if (!(src1.AsgOp is Operation highTimesHighOp))
{
return false;
}
if (!TryMatchHighTimesHigh(highTimesHighOp, x, lowTimesHighResult))
{
return false;
}
return true;
}
private static bool TryMatchXmad32x16Pattern(Operation operation, out Operand x, out Operand y, out Operand addend)
{
x = null;
y = null;
addend = null;
if (operation.Inst != Instruction.Add)
{
return false;
}
Operand src1 = operation.GetSource(0);
Operand src2 = operation.GetSource(1);
Operand lowTimesLowResult = GetCopySource(src2);
if (!(lowTimesLowResult.AsgOp is Operation lowTimesLowOp))
{
return false;
}
if (!TryMatchLowTimesLow(lowTimesLowOp, out x, out y, out addend))
{
return false;
}
Operand highTimesLowResult = src1;
if (!(highTimesLowResult.AsgOp is Operation highTimesLowOp))
{
return false;
}
if (!TryMatchHighTimesLow(highTimesLowOp, x, y))
{
return false;
}
return y.Type == OperandType.Constant && (ushort)y.Value == y.Value;
}
private static bool TryMatchLowTimesLow(Operation operation, out Operand x, out Operand y, out Operand addend)
{
// x = x & 0xFFFF
// y = y & 0xFFFF
// lowTimesLow = x * y
x = null;
y = null;
addend = null;
if (operation.Inst == Instruction.Add)
{
if (!(operation.GetSource(0).AsgOp is Operation mulOp))
{
return false;
}
addend = operation.GetSource(1);
operation = mulOp;
}
if (operation.Inst != Instruction.Multiply)
{
return false;
}
Operand src1 = GetMasked16Source(operation.GetSource(0));
Operand src2 = GetMasked16Source(operation.GetSource(1));
if (src1 == null || src2 == null)
{
return false;
}
x = src1;
y = src2;
return true;
}
private static bool TryMatchLowTimesHigh(Operation operation, Operand x, Operand y)
{
// xLow = x & 0xFFFF
// yHigh = y >> 16
// lowTimesHigh = xLow * yHigh
// result = (lowTimesHigh & 0xFFFF) | (y << 16)
if (operation.Inst != Instruction.BitwiseOr)
{
return false;
}
Operand mulResult = GetMasked16Source(operation.GetSource(0));
if (mulResult == null)
{
return false;
}
mulResult = GetCopySource(mulResult);
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
{
return false;
}
if (GetMasked16Source(mulOp.GetSource(0)) != x)
{
return false;
}
if (GetShifted16Source(mulOp.GetSource(1), Instruction.ShiftRightU32) != y)
{
return false;
}
if (GetShifted16Source(operation.GetSource(1), Instruction.ShiftLeft) != y)
{
return false;
}
return true;
}
private static bool TryMatchHighTimesLow(Operation operation, Operand x, Operand y)
{
// xHigh = x >> 16
// yLow = y & 0xFFFF
// highTimesLow = xHigh * yLow
// result = highTimesLow << 16
if (operation.Inst != Instruction.ShiftLeft || !IsConst(operation.GetSource(1), 16))
{
return false;
}
Operand mulResult = operation.GetSource(0);
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
{
return false;
}
if (GetShifted16Source(mulOp.GetSource(0), Instruction.ShiftRightU32) != x)
{
return false;
}
Operand src2 = GetMasked16Source(mulOp.GetSource(1));
if (src2.Type != y.Type || src2.Value != y.Value)
{
return false;
}
return true;
}
private static bool TryMatchHighTimesHigh(Operation operation, Operand x, Operand lowTimesHighResult)
{
// xHigh = x >> 16
// lowTimesHighResultHigh = lowTimesHighResult >> 16
// highTimesHigh = xHigh * lowTimesHighResultHigh
// result = highTimesHigh << 16
if (operation.Inst != Instruction.ShiftLeft || !IsConst(operation.GetSource(1), 16))
{
return false;
}
Operand mulResult = operation.GetSource(0);
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
{
return false;
}
if (GetShifted16Source(mulOp.GetSource(0), Instruction.ShiftRightU32) != x)
{
return false;
}
if (GetCopySource(GetShifted16Source(mulOp.GetSource(1), Instruction.ShiftRightU32)) != lowTimesHighResult)
{
return false;
}
return true;
}
private static Operand GetMasked16Source(Operand value)
{
if (!(value.AsgOp is Operation maskOp))
{
return null;
}
if (maskOp.Inst != Instruction.BitwiseAnd || !IsConst(maskOp.GetSource(1), ushort.MaxValue))
{
return null;
}
return maskOp.GetSource(0);
}
private static Operand GetShifted16Source(Operand value, Instruction shiftInst)
{
if (!(value.AsgOp is Operation shiftOp))
{
return null;
}
if (shiftOp.Inst != shiftInst || !IsConst(shiftOp.GetSource(1), 16))
{
return null;
}
return shiftOp.GetSource(0);
}
private static Operand GetCopySource(Operand value)
{
while (value.AsgOp is Operation operation && IsCopy(operation))
{
value = operation.GetSource(0);
}
return value;
}
private static bool IsCopy(Operation operation)
{
return operation.Inst == Instruction.Copy || (operation.Inst == Instruction.Add && IsConst(operation.GetSource(1), 0));
}
private static bool IsConst(Operand operand, int value)
{
return operand.Type == OperandType.Constant && operand.Value == value;
}
}
}