summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp432
1 files changed, 400 insertions, 32 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index a95f393b7560..61a0bbef9089 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -248,6 +248,22 @@ static InstrSignature instrToSignature(const MachineInstr &MI,
Register DefReg;
InstrSignature Signature{MI.getOpcode()};
for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
+ // The only decorations that can be applied more than once to a given <id>
+ // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442),
+ // and CacheControlStoreINTEL (6443). For all the rest of decorations, we
+ // will only add to the signature the Opcode, the id to which it applies,
+ // and the decoration id, disregarding any decoration flags. This will
+ // ensure that any subsequent decoration with the same id will be deemed as
+ // a duplicate. Then, at the call site, we will be able to handle duplicates
+ // in the best way.
+ unsigned Opcode = MI.getOpcode();
+ if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
+ unsigned DecorationID = MI.getOperand(1).getImm();
+ if (DecorationID != SPIRV::Decoration::UserSemantic &&
+ DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
+ DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
+ continue;
+ }
const MachineOperand &MO = MI.getOperand(i);
size_t h;
if (MO.isReg()) {
@@ -559,8 +575,54 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
MAI.setSkipEmission(&MI);
InstrSignature MISign = instrToSignature(MI, MAI, true);
auto FoundMI = IS.insert(std::move(MISign));
- if (!FoundMI.second)
+ if (!FoundMI.second) {
+ if (MI.getOpcode() == SPIRV::OpDecorate) {
+ assert(MI.getNumOperands() >= 2 &&
+ "Decoration instructions must have at least 2 operands");
+ assert(MSType == SPIRV::MB_Annotations &&
+ "Only OpDecorate instructions can be duplicates");
+ // For FPFastMathMode decoration, we need to merge the flags of the
+ // duplicate decoration with the original one, so we need to find the
+ // original instruction that has the same signature. For the rest of
+ // instructions, we will simply skip the duplicate.
+ if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
+ return; // Skip duplicates of other decorations.
+
+ const SPIRV::InstrList &Decorations = MAI.MS[MSType];
+ for (const MachineInstr *OrigMI : Decorations) {
+ if (instrToSignature(*OrigMI, MAI, true) == MISign) {
+ assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
+ "Original instruction must have the same number of operands");
+ assert(
+ OrigMI->getNumOperands() == 3 &&
+ "FPFastMathMode decoration must have 3 operands for OpDecorate");
+ unsigned OrigFlags = OrigMI->getOperand(2).getImm();
+ unsigned NewFlags = MI.getOperand(2).getImm();
+ if (OrigFlags == NewFlags)
+ return; // No need to merge, the flags are the same.
+
+ // Emit warning about possible conflict between flags.
+ unsigned FinalFlags = OrigFlags | NewFlags;
+ llvm::errs()
+ << "Warning: Conflicting FPFastMathMode decoration flags "
+ "in instruction: "
+ << *OrigMI << "Original flags: " << OrigFlags
+ << ", new flags: " << NewFlags
+ << ". They will be merged on a best effort basis, but not "
+ "validated. Final flags: "
+ << FinalFlags << "\n";
+ MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
+ MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
+ OrigFlagsOp =
+ MachineOperand::CreateImm(static_cast<unsigned>(FinalFlags));
+ return; // Merge done, so we found a duplicate; don't add it to MAI.MS
+ }
+ }
+ assert(false && "No original instruction found for the duplicate "
+ "OpDecorate, but we found one in IS.");
+ }
return; // insert failed, so we found a duplicate; don't add it to MAI.MS
+ }
// No duplicates, so add it.
if (Append)
MAI.MS[MSType].push_back(&MI);
@@ -934,6 +996,11 @@ static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
} else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
+ } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
+ if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ Reqs.addRequirements(SPIRV::Capability::FloatControls2);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
+ }
}
}
@@ -1133,6 +1200,23 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
return;
}
+ bool IsNonUniform =
+ hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
+
+ auto FirstIndexReg = Instr.getOperand(3).getReg();
+ bool FirstIndexIsConstant =
+ Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
+
+ if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
+ if (IsNonUniform)
+ Handler.addRequirements(
+ SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
+ else if (!FirstIndexIsConstant)
+ Handler.addRequirements(
+ SPIRV::Capability::StorageBufferArrayDynamicIndexing);
+ return;
+ }
+
Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
@@ -1141,27 +1225,25 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
return;
}
- bool IsNonUniform =
- hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
if (isUniformTexelBuffer(PointeeType)) {
if (IsNonUniform)
Handler.addRequirements(
SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
- else
+ else if (!FirstIndexIsConstant)
Handler.addRequirements(
SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
} else if (isInputAttachment(PointeeType)) {
if (IsNonUniform)
Handler.addRequirements(
SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
- else
+ else if (!FirstIndexIsConstant)
Handler.addRequirements(
SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
} else if (isStorageTexelBuffer(PointeeType)) {
if (IsNonUniform)
Handler.addRequirements(
SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
- else
+ else if (!FirstIndexIsConstant)
Handler.addRequirements(
SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
} else if (isSampledImage(PointeeType) ||
@@ -1170,14 +1252,14 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
if (IsNonUniform)
Handler.addRequirements(
SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
- else
+ else if (!FirstIndexIsConstant)
Handler.addRequirements(
SPIRV::Capability::SampledImageArrayDynamicIndexing);
} else if (isStorageImage(PointeeType)) {
if (IsNonUniform)
Handler.addRequirements(
SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
- else
+ else if (!FirstIndexIsConstant)
Handler.addRequirements(
SPIRV::Capability::StorageImageArrayDynamicIndexing);
}
@@ -1222,6 +1304,31 @@ static void AddDotProductRequirements(const MachineInstr &MI,
}
}
+void addPrintfRequirements(const MachineInstr &MI,
+ SPIRV::RequirementHandler &Reqs,
+ const SPIRVSubtarget &ST) {
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
+ if (PtrType) {
+ MachineOperand ASOp = PtrType->getOperand(1);
+ if (ASOp.isImm()) {
+ unsigned AddrSpace = ASOp.getImm();
+ if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::
+ SPV_EXT_relaxed_printf_string_address_space)) {
+ report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
+ "required because printf uses a format string not "
+ "in constant address space.",
+ false);
+ }
+ Reqs.addExtension(
+ SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
+ }
+ }
+ }
+}
+
static bool isBFloat16Type(const SPIRVType *TypeDef) {
return TypeDef && TypeDef->getNumOperands() == 3 &&
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
@@ -1230,8 +1337,9 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) {
}
void addInstrRequirements(const MachineInstr &MI,
- SPIRV::RequirementHandler &Reqs,
+ SPIRV::ModuleAnalysisInfo &MAI,
const SPIRVSubtarget &ST) {
+ SPIRV::RequirementHandler &Reqs = MAI.Reqs;
switch (MI.getOpcode()) {
case SPIRV::OpMemoryModel: {
int64_t Addr = MI.getOperand(0).getImm();
@@ -1321,6 +1429,12 @@ void addInstrRequirements(const MachineInstr &MI,
static_cast<int64_t>(
SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
+ break;
+ }
+ if (MI.getOperand(3).getImm() ==
+ static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
+ addPrintfRequirements(MI, Reqs, ST);
+ break;
}
break;
}
@@ -1781,15 +1895,45 @@ void addInstrRequirements(const MachineInstr &MI,
break;
case SPIRV::OpConvertHandleToImageINTEL:
case SPIRV::OpConvertHandleToSamplerINTEL:
- case SPIRV::OpConvertHandleToSampledImageINTEL:
+ case SPIRV::OpConvertHandleToSampledImageINTEL: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
"instructions require the following SPIR-V extension: "
"SPV_INTEL_bindless_images",
false);
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
+ SPIRVType *TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
+ if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
+ TyDef->getOpcode() != SPIRV::OpTypeImage) {
+ report_fatal_error("Incorrect return type for the instruction "
+ "OpConvertHandleToImageINTEL",
+ false);
+ } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
+ TyDef->getOpcode() != SPIRV::OpTypeSampler) {
+ report_fatal_error("Incorrect return type for the instruction "
+ "OpConvertHandleToSamplerINTEL",
+ false);
+ } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
+ TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
+ report_fatal_error("Incorrect return type for the instruction "
+ "OpConvertHandleToSampledImageINTEL",
+ false);
+ }
+ SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
+ unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
+ if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
+ !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
+ report_fatal_error(
+ "Parameter value must be a 32-bit scalar in case of "
+ "Physical32 addressing model or a 64-bit scalar in case of "
+ "Physical64 addressing model",
+ false);
+ }
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
break;
+ }
case SPIRV::OpSubgroup2DBlockLoadINTEL:
case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
@@ -1906,6 +2050,17 @@ void addInstrRequirements(const MachineInstr &MI,
// TODO: Add UntypedPointersKHR when implemented.
break;
}
+ case SPIRV::OpPredicatedLoadINTEL:
+ case SPIRV::OpPredicatedStoreINTEL: {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
+ report_fatal_error(
+ "OpPredicated[Load/Store]INTEL instructions require "
+ "the following SPIR-V extension: SPV_INTEL_predicated_io",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
+ Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
+ break;
+ }
default:
break;
@@ -1927,15 +2082,18 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
continue;
for (const MachineBasicBlock &MBB : *MF)
for (const MachineInstr &MI : MBB)
- addInstrRequirements(MI, MAI.Reqs, ST);
+ addInstrRequirements(MI, MAI, ST);
}
// Collect requirements for OpExecutionMode instructions.
auto Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
- bool RequireFloatControls = false, RequireFloatControls2 = false,
+ bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
+ RequireKHRFloatControls2 = false,
VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
- bool HasFloatControls2 =
+ bool HasIntelFloatControls2 =
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
+ bool HasKHRFloatControls2 =
+ ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
MDNode *MDN = cast<MDNode>(Node->getOperand(i));
const MDOperand &MDOp = MDN->getOperand(1);
@@ -1948,7 +2106,6 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
switch (EM) {
case SPIRV::ExecutionMode::DenormPreserve:
case SPIRV::ExecutionMode::DenormFlushToZero:
- case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
case SPIRV::ExecutionMode::RoundingModeRTE:
case SPIRV::ExecutionMode::RoundingModeRTZ:
RequireFloatControls = VerLower14;
@@ -1959,8 +2116,28 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
- if (HasFloatControls2) {
- RequireFloatControls2 = true;
+ if (HasIntelFloatControls2) {
+ RequireIntelFloatControls2 = true;
+ MAI.Reqs.getAndAddRequirements(
+ SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
+ }
+ break;
+ case SPIRV::ExecutionMode::FPFastMathDefault: {
+ if (HasKHRFloatControls2) {
+ RequireKHRFloatControls2 = true;
+ MAI.Reqs.getAndAddRequirements(
+ SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
+ }
+ break;
+ }
+ case SPIRV::ExecutionMode::ContractionOff:
+ case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
+ if (HasKHRFloatControls2) {
+ RequireKHRFloatControls2 = true;
+ MAI.Reqs.getAndAddRequirements(
+ SPIRV::OperandCategory::ExecutionModeOperand,
+ SPIRV::ExecutionMode::FPFastMathDefault, ST);
+ } else {
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
}
@@ -1975,8 +2152,10 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
if (RequireFloatControls &&
ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
- if (RequireFloatControls2)
+ if (RequireIntelFloatControls2)
MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
+ if (RequireKHRFloatControls2)
+ MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
@@ -1991,6 +2170,9 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
SPIRV::OperandCategory::ExecutionModeOperand,
SPIRV::ExecutionMode::LocalSize, ST);
}
+ if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
+ MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
+ }
if (F.getMetadata("work_group_size_hint"))
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand,
@@ -2016,8 +2198,11 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
}
}
-static unsigned getFastMathFlags(const MachineInstr &I) {
+static unsigned getFastMathFlags(const MachineInstr &I,
+ const SPIRVSubtarget &ST) {
unsigned Flags = SPIRV::FPFastMathMode::None;
+ bool CanUseKHRFloatControls2 =
+ ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
Flags |= SPIRV::FPFastMathMode::NotNaN;
if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
@@ -2026,12 +2211,45 @@ static unsigned getFastMathFlags(const MachineInstr &I) {
Flags |= SPIRV::FPFastMathMode::NSZ;
if (I.getFlag(MachineInstr::MIFlag::FmArcp))
Flags |= SPIRV::FPFastMathMode::AllowRecip;
- if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
- Flags |= SPIRV::FPFastMathMode::Fast;
+ if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
+ Flags |= SPIRV::FPFastMathMode::AllowContract;
+ if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
+ if (CanUseKHRFloatControls2)
+ // LLVM reassoc maps to SPIRV transform, see
+ // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
+ // Because we are enabling AllowTransform, we must enable AllowReassoc and
+ // AllowContract too, as required by SPIRV spec. Also, we used to map
+ // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
+ // replaced by turning all the other bits instead. Therefore, we're
+ // enabling every bit here except None and Fast.
+ Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+ SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
+ SPIRV::FPFastMathMode::AllowTransform |
+ SPIRV::FPFastMathMode::AllowReassoc |
+ SPIRV::FPFastMathMode::AllowContract;
+ else
+ Flags |= SPIRV::FPFastMathMode::Fast;
+ }
+
+ if (CanUseKHRFloatControls2) {
+ // Error out if SPIRV::FPFastMathMode::Fast is enabled.
+ assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
+ "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
+ "anymore.");
+
+ // Error out if AllowTransform is enabled without AllowReassoc and
+ // AllowContract.
+ assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
+ ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
+ Flags & SPIRV::FPFastMathMode::AllowContract))) &&
+ "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
+ "AllowContract flags to be enabled as well.");
+ }
+
return Flags;
}
-static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
+static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
if (ST.isKernel())
return true;
if (ST.getSPIRVVersion() < VersionTuple(1, 2))
@@ -2039,9 +2257,10 @@ static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
}
-static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
- const SPIRVInstrInfo &TII,
- SPIRV::RequirementHandler &Reqs) {
+static void handleMIFlagDecoration(
+ MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
+ SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
SPIRV::Decoration::NoSignedWrap, ST, Reqs)
@@ -2057,13 +2276,53 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
buildOpDecorate(I.getOperand(0).getReg(), I, TII,
SPIRV::Decoration::NoUnsignedWrap, {});
}
- if (!TII.canUseFastMathFlags(I))
- return;
- unsigned FMFlags = getFastMathFlags(I);
- if (FMFlags == SPIRV::FPFastMathMode::None)
+ if (!TII.canUseFastMathFlags(
+ I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)))
return;
- if (isFastMathMathModeAvailable(ST)) {
+ unsigned FMFlags = getFastMathFlags(I, ST);
+ if (FMFlags == SPIRV::FPFastMathMode::None) {
+ // We also need to check if any FPFastMathDefault info was set for the
+ // types used in this instruction.
+ if (FPFastMathDefaultInfoVec.empty())
+ return;
+
+ // There are three types of instructions that can use fast math flags:
+ // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
+ // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
+ // 3. Extended instructions (ExtInst)
+ // For arithmetic instructions, the floating point type can be in the
+ // result type or in the operands, but they all must be the same.
+ // For the relational and logical instructions, the floating point type
+ // can only be in the operands 1 and 2, not the result type. Also, the
+ // operands must have the same type. For the extended instructions, the
+ // floating point type can be in the result type or in the operands. It's
+ // unclear if the operands and the result type must be the same. Let's
+ // assume they must be. Therefore, for 1. and 2., we can check the first
+ // operand type, and for 3. we can check the result type.
+ assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
+ Register ResReg = I.getOpcode() == SPIRV::OpExtInst
+ ? I.getOperand(1).getReg()
+ : I.getOperand(2).getReg();
+ SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
+ const Type *Ty = GR->getTypeForSPIRVType(ResType);
+ Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
+
+ // Match instruction type with the FPFastMathDefaultInfoVec.
+ bool Emit = false;
+ for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
+ if (Ty == Elem.Ty) {
+ FMFlags = Elem.FastMathFlags;
+ Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
+ Elem.FPFastMathDefault;
+ break;
+ }
+ }
+
+ if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
+ return;
+ }
+ if (isFastMathModeAvailable(ST)) {
Register DstReg = I.getOperand(0).getReg();
buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
{FMFlags});
@@ -2073,14 +2332,17 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
- SPIRV::ModuleAnalysisInfo &MAI) {
+ SPIRV::ModuleAnalysisInfo &MAI,
+ const SPIRVGlobalRegistry *GR) {
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
MachineFunction *MF = MMI->getMachineFunction(*F);
if (!MF)
continue;
+
for (auto &MBB : *MF)
for (auto &MI : MBB)
- handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
+ handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
+ MAI.FPFastMathDefaultInfoMap[&(*F)]);
}
}
@@ -2126,6 +2388,111 @@ static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
}
}
+static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(
+ const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
+ auto it = MAI.FPFastMathDefaultInfoMap.find(F);
+ if (it != MAI.FPFastMathDefaultInfoMap.end())
+ return it->second;
+
+ // If the map does not contain the entry, create a new one. Initialize it to
+ // contain all 3 elements sorted by bit width of target type: {half, float,
+ // double}.
+ SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
+ FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
+ SPIRV::FPFastMathMode::None);
+ FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
+ SPIRV::FPFastMathMode::None);
+ FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
+ SPIRV::FPFastMathMode::None);
+ return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
+}
+
+static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
+ const Type *Ty) {
+ size_t BitWidth = Ty->getScalarSizeInBits();
+ int Index =
+ SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
+ BitWidth);
+ assert(Index >= 0 && Index < 3 &&
+ "Expected FPFastMathDefaultInfo for half, float, or double");
+ assert(FPFastMathDefaultInfoVec.size() == 3 &&
+ "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
+ return FPFastMathDefaultInfoVec[Index];
+}
+
+static void collectFPFastMathDefaults(const Module &M,
+ SPIRV::ModuleAnalysisInfo &MAI,
+ const SPIRVSubtarget &ST) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
+ return;
+
+ // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
+ // We need the entry point (function) as the key, and the target
+ // type and flags as the value.
+ // We also need to check ContractionOff and SignedZeroInfNanPreserve
+ // execution modes, as they are now deprecated and must be replaced
+ // with FPFastMathDefaultInfo.
+ auto Node = M.getNamedMetadata("spirv.ExecutionMode");
+ if (!Node)
+ return;
+
+ for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+ MDNode *MDN = cast<MDNode>(Node->getOperand(i));
+ assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
+ const Function *F = cast<Function>(
+ cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
+ const auto EM =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
+ ->getZExtValue();
+ if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
+ assert(MDN->getNumOperands() == 4 &&
+ "Expected 4 operands for FPFastMathDefault");
+
+ const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
+ unsigned Flags =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
+ ->getZExtValue();
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+ getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
+ SPIRV::FPFastMathDefaultInfo &Info =
+ getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
+ Info.FastMathFlags = Flags;
+ Info.FPFastMathDefault = true;
+ } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
+ assert(MDN->getNumOperands() == 2 &&
+ "Expected no operands for ContractionOff");
+
+ // We need to save this info for every possible FP type, i.e. {half,
+ // float, double, fp128}.
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+ getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
+ for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
+ Info.ContractionOff = true;
+ }
+ } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
+ assert(MDN->getNumOperands() == 3 &&
+ "Expected 1 operand for SignedZeroInfNanPreserve");
+ unsigned TargetWidth =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
+ ->getZExtValue();
+ // We need to save this info only for the FP type with TargetWidth.
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+ getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
+ int Index = SPIRV::FPFastMathDefaultInfoVector::
+ computeFPFastMathDefaultInfoVecIndex(TargetWidth);
+ assert(Index >= 0 && Index < 3 &&
+ "Expected FPFastMathDefaultInfo for half, float, or double");
+ assert(FPFastMathDefaultInfoVec.size() == 3 &&
+ "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
+ FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
+ }
+ }
+}
+
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
@@ -2147,7 +2514,8 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {
patchPhis(M, GR, *TII, MMI);
addMBBNames(M, *TII, MMI, *ST, MAI);
- addDecorations(M, *TII, MMI, *ST, MAI);
+ collectFPFastMathDefaults(M, MAI, *ST);
+ addDecorations(M, *TII, MMI, *ST, MAI, GR);
collectReqs(M, MAI, MMI, *ST);