mirror of
https://git.ryujinx.app/ryubing/ryujinx.git
synced 2025-06-07 08:57:13 +00:00
Optimize XMAD instruction sequence into a single 32-bit multiply when possible
See merge request [ryubing/ryujinx!24](https://git.ryujinx.app/ryubing/ryujinx/-/merge_requests/24)
This commit is contained in:
parent
be7285f7fc
commit
1c411082db
@ -9,6 +9,11 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
|
||||
{
|
||||
public static void RunPass(TransformContext context)
|
||||
{
|
||||
for (int blkIndex = 0; blkIndex < context.Blocks.Length; blkIndex++)
|
||||
{
|
||||
XmadOptimizer.RunPass(context.Blocks[blkIndex]);
|
||||
}
|
||||
|
||||
RunOptimizationPasses(context.Blocks, context.ResourceManager);
|
||||
|
||||
// TODO: Some of those are not optimizations and shouldn't be here.
|
||||
@ -355,7 +360,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
|
||||
operation.TurnIntoCopy(attrMulLhs);
|
||||
}
|
||||
|
||||
private static void RemoveNode(BasicBlock block, LinkedListNode<INode> llNode)
|
||||
public static void RemoveNode(BasicBlock block, LinkedListNode<INode> llNode)
|
||||
{
|
||||
// Remove a node from the nodes list, and also remove itself
|
||||
// from all the use lists on the operands that this node uses.
|
||||
|
@ -0,0 +1,342 @@
|
||||
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
|
||||
{
|
||||
static class XmadOptimizer
|
||||
{
|
||||
public static void RunPass(BasicBlock block)
|
||||
{
|
||||
for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next)
|
||||
{
|
||||
if (!(node.Value is Operation operation))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (TryMatchXmadPattern(operation, out Operand x, out Operand y, out Operand addend))
|
||||
{
|
||||
LinkedListNode<INode> nextNode;
|
||||
|
||||
if (addend != null)
|
||||
{
|
||||
Operand temp = OperandHelper.Local();
|
||||
|
||||
nextNode = block.Operations.AddAfter(node, new Operation(Instruction.Multiply, temp, x, y));
|
||||
nextNode = block.Operations.AddAfter(nextNode, new Operation(Instruction.Add, operation.Dest, temp, addend));
|
||||
}
|
||||
else
|
||||
{
|
||||
nextNode = block.Operations.AddAfter(node, new Operation(Instruction.Multiply, operation.Dest, x, y));
|
||||
}
|
||||
|
||||
Optimizer.RemoveNode(block, node);
|
||||
node = nextNode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static bool TryMatchXmadPattern(Operation operation, out Operand x, out Operand y, out Operand addend)
|
||||
{
|
||||
return TryMatchXmad32x32Pattern(operation, out x, out y, out addend) ||
|
||||
TryMatchXmad32x16Pattern(operation, out x, out y, out addend);
|
||||
}
|
||||
|
||||
private static bool TryMatchXmad32x32Pattern(Operation operation, out Operand x, out Operand y, out Operand addend)
|
||||
{
|
||||
x = null;
|
||||
y = null;
|
||||
addend = null;
|
||||
|
||||
if (operation.Inst != Instruction.Add)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand src1 = operation.GetSource(0);
|
||||
Operand src2 = operation.GetSource(1);
|
||||
|
||||
if (!(src2.AsgOp is Operation addOp) || addOp.Inst != Instruction.Add)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand lowTimesLowResult = GetCopySource(addOp.GetSource(0));
|
||||
|
||||
if (!(lowTimesLowResult.AsgOp is Operation lowTimesLowOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!TryMatchLowTimesLow(lowTimesLowOp, out x, out y, out addend))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand lowTimesHighResult = GetCopySource(GetShifted16Source(addOp.GetSource(1), Instruction.ShiftLeft));
|
||||
|
||||
if (!(lowTimesHighResult.AsgOp is Operation lowTimesHighOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!TryMatchLowTimesHigh(lowTimesHighOp, x, y))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(src1.AsgOp is Operation highTimesHighOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!TryMatchHighTimesHigh(highTimesHighOp, x, lowTimesHighResult))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TryMatchXmad32x16Pattern(Operation operation, out Operand x, out Operand y, out Operand addend)
|
||||
{
|
||||
x = null;
|
||||
y = null;
|
||||
addend = null;
|
||||
|
||||
if (operation.Inst != Instruction.Add)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand src1 = operation.GetSource(0);
|
||||
Operand src2 = operation.GetSource(1);
|
||||
|
||||
Operand lowTimesLowResult = GetCopySource(src2);
|
||||
|
||||
if (!(lowTimesLowResult.AsgOp is Operation lowTimesLowOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!TryMatchLowTimesLow(lowTimesLowOp, out x, out y, out addend))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand highTimesLowResult = src1;
|
||||
|
||||
if (!(highTimesLowResult.AsgOp is Operation highTimesLowOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!TryMatchHighTimesLow(highTimesLowOp, x, y))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return y.Type == OperandType.Constant && (ushort)y.Value == y.Value;
|
||||
}
|
||||
|
||||
private static bool TryMatchLowTimesLow(Operation operation, out Operand x, out Operand y, out Operand addend)
|
||||
{
|
||||
// x = x & 0xFFFF
|
||||
// y = y & 0xFFFF
|
||||
// lowTimesLow = x * y
|
||||
|
||||
x = null;
|
||||
y = null;
|
||||
addend = null;
|
||||
|
||||
if (operation.Inst == Instruction.Add)
|
||||
{
|
||||
if (!(operation.GetSource(0).AsgOp is Operation mulOp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
addend = operation.GetSource(1);
|
||||
operation = mulOp;
|
||||
}
|
||||
|
||||
if (operation.Inst != Instruction.Multiply)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand src1 = GetMasked16Source(operation.GetSource(0));
|
||||
Operand src2 = GetMasked16Source(operation.GetSource(1));
|
||||
|
||||
if (src1 == null || src2 == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
x = src1;
|
||||
y = src2;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TryMatchLowTimesHigh(Operation operation, Operand x, Operand y)
|
||||
{
|
||||
// xLow = x & 0xFFFF
|
||||
// yHigh = y >> 16
|
||||
// lowTimesHigh = xLow * yHigh
|
||||
// result = (lowTimesHigh & 0xFFFF) | (y << 16)
|
||||
|
||||
if (operation.Inst != Instruction.BitwiseOr)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand mulResult = GetMasked16Source(operation.GetSource(0));
|
||||
|
||||
if (mulResult == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
mulResult = GetCopySource(mulResult);
|
||||
|
||||
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetMasked16Source(mulOp.GetSource(0)) != x)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetShifted16Source(mulOp.GetSource(1), Instruction.ShiftRightU32) != y)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetShifted16Source(operation.GetSource(1), Instruction.ShiftLeft) != y)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TryMatchHighTimesLow(Operation operation, Operand x, Operand y)
|
||||
{
|
||||
// xHigh = x >> 16
|
||||
// yLow = y & 0xFFFF
|
||||
// highTimesLow = xHigh * yLow
|
||||
// result = highTimesLow << 16
|
||||
|
||||
if (operation.Inst != Instruction.ShiftLeft || !IsConst(operation.GetSource(1), 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand mulResult = operation.GetSource(0);
|
||||
|
||||
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetShifted16Source(mulOp.GetSource(0), Instruction.ShiftRightU32) != x)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand src2 = GetMasked16Source(mulOp.GetSource(1));
|
||||
|
||||
if (src2.Type != y.Type || src2.Value != y.Value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TryMatchHighTimesHigh(Operation operation, Operand x, Operand lowTimesHighResult)
|
||||
{
|
||||
// xHigh = x >> 16
|
||||
// lowTimesHighResultHigh = lowTimesHighResult >> 16
|
||||
// highTimesHigh = xHigh * lowTimesHighResultHigh
|
||||
// result = highTimesHigh << 16
|
||||
|
||||
if (operation.Inst != Instruction.ShiftLeft || !IsConst(operation.GetSource(1), 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Operand mulResult = operation.GetSource(0);
|
||||
|
||||
if (!(mulResult.AsgOp is Operation mulOp) || mulOp.Inst != Instruction.Multiply)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetShifted16Source(mulOp.GetSource(0), Instruction.ShiftRightU32) != x)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetCopySource(GetShifted16Source(mulOp.GetSource(1), Instruction.ShiftRightU32)) != lowTimesHighResult)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static Operand GetMasked16Source(Operand value)
|
||||
{
|
||||
if (!(value.AsgOp is Operation maskOp))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (maskOp.Inst != Instruction.BitwiseAnd || !IsConst(maskOp.GetSource(1), ushort.MaxValue))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return maskOp.GetSource(0);
|
||||
}
|
||||
|
||||
private static Operand GetShifted16Source(Operand value, Instruction shiftInst)
|
||||
{
|
||||
if (!(value.AsgOp is Operation shiftOp))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (shiftOp.Inst != shiftInst || !IsConst(shiftOp.GetSource(1), 16))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return shiftOp.GetSource(0);
|
||||
}
|
||||
|
||||
private static Operand GetCopySource(Operand value)
|
||||
{
|
||||
while (value.AsgOp is Operation operation && IsCopy(operation))
|
||||
{
|
||||
value = operation.GetSource(0);
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
private static bool IsCopy(Operation operation)
|
||||
{
|
||||
return operation.Inst == Instruction.Copy || (operation.Inst == Instruction.Add && IsConst(operation.GetSource(1), 0));
|
||||
}
|
||||
|
||||
private static bool IsConst(Operand operand, int value)
|
||||
{
|
||||
return operand.Type == OperandType.Constant && operand.Value == value;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user