//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ----*- C++ //-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding // intrinsic by replacing it with a call to // @llvm.spv.resource.handlefrombinding. // //===----------------------------------------------------------------------===// #include "SPIRV.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsSPIRV.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include #include using namespace llvm; namespace { class SPIRVLegalizeImplicitBinding : public ModulePass { public: static char ID; SPIRVLegalizeImplicitBinding() : ModulePass(ID) {} StringRef getPassName() const override { return "SPIRV Legalize Implicit Binding"; } bool runOnModule(Module &M) override; private: void collectBindingInfo(Module &M); uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet); void replaceImplicitBindingCalls(Module &M); void replaceResourceHandleCall(Module &M, CallInst *OldCI, uint32_t NewBinding); void replaceCounterHandleCall(Module &M, CallInst *OldCI, uint32_t NewBinding); void verifyUniqueOrderIdPerResource(SmallVectorImpl &Calls); // A map from descriptor set to a bit vector of used binding numbers. std::vector UsedBindings; // A list of all implicit binding calls, to be sorted by order ID. SmallVector ImplicitBindingCalls; }; struct BindingInfoCollector : public InstVisitor { std::vector &UsedBindings; SmallVector &ImplicitBindingCalls; BindingInfoCollector(std::vector &UsedBindings, SmallVector &ImplicitBindingCalls) : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) { } void addBinding(uint32_t DescSet, uint32_t Binding) { if (UsedBindings.size() <= DescSet) { UsedBindings.resize(DescSet + 1); UsedBindings[DescSet].resize(64); } if (UsedBindings[DescSet].size() <= Binding) { UsedBindings[DescSet].resize(2 * Binding + 1); } UsedBindings[DescSet].set(Binding); } void visitCallInst(CallInst &CI) { if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) { const uint32_t DescSet = cast(CI.getArgOperand(0))->getZExtValue(); const uint32_t Binding = cast(CI.getArgOperand(1))->getZExtValue(); addBinding(DescSet, Binding); } else if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefromimplicitbinding) { ImplicitBindingCalls.push_back(&CI); } else if (CI.getIntrinsicID() == Intrinsic::spv_resource_counterhandlefrombinding) { const uint32_t DescSet = cast(CI.getArgOperand(2))->getZExtValue(); const uint32_t Binding = cast(CI.getArgOperand(1))->getZExtValue(); addBinding(DescSet, Binding); } else if (CI.getIntrinsicID() == Intrinsic::spv_resource_counterhandlefromimplicitbinding) { ImplicitBindingCalls.push_back(&CI); } } }; static uint32_t getOrderId(const CallInst *CI) { uint32_t OrderIdArgIdx = 0; switch (CI->getIntrinsicID()) { case Intrinsic::spv_resource_handlefromimplicitbinding: OrderIdArgIdx = 0; break; case Intrinsic::spv_resource_counterhandlefromimplicitbinding: OrderIdArgIdx = 1; break; default: llvm_unreachable("CallInst is not an implicit binding intrinsic"); } return cast(CI->getArgOperand(OrderIdArgIdx))->getZExtValue(); } static uint32_t getDescSet(const CallInst *CI) { uint32_t DescSetArgIdx; switch (CI->getIntrinsicID()) { case Intrinsic::spv_resource_handlefromimplicitbinding: case Intrinsic::spv_resource_handlefrombinding: DescSetArgIdx = 1; break; case Intrinsic::spv_resource_counterhandlefromimplicitbinding: case Intrinsic::spv_resource_counterhandlefrombinding: DescSetArgIdx = 2; break; default: llvm_unreachable("CallInst is not an implicit binding intrinsic"); } return cast(CI->getArgOperand(DescSetArgIdx))->getZExtValue(); } void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) { BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls); InfoCollector.visit(M); // Sort the collected calls by their order ID. std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(), [](const CallInst *A, const CallInst *B) { return getOrderId(A) < getOrderId(B); }); } void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource( SmallVectorImpl &Calls) { // Check that the order Id is unique per resource. for (uint32_t i = 1; i < Calls.size(); ++i) { const uint32_t OrderA = getOrderId(Calls[i - 1]); const uint32_t OrderB = getOrderId(Calls[i]); if (OrderA == OrderB) { const uint32_t DescSetA = getDescSet(Calls[i - 1]); const uint32_t DescSetB = getDescSet(Calls[i]); if (DescSetA != DescSetB) { report_fatal_error("Implicit binding calls with the same order ID must " "have the same descriptor set"); } } } } uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding( uint32_t DescSet) { if (UsedBindings.size() <= DescSet) { UsedBindings.resize(DescSet + 1); UsedBindings[DescSet].resize(64); } int NewBinding = UsedBindings[DescSet].find_first_unset(); if (NewBinding == -1) { NewBinding = UsedBindings[DescSet].size(); UsedBindings[DescSet].resize(2 * NewBinding + 1); } UsedBindings[DescSet].set(NewBinding); return NewBinding; } void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) { uint32_t lastOrderId = -1; uint32_t lastBindingNumber = -1; for (CallInst *OldCI : ImplicitBindingCalls) { const uint32_t OrderId = getOrderId(OldCI); uint32_t BindingNumber; if (OrderId == lastOrderId) { BindingNumber = lastBindingNumber; } else { const uint32_t DescSet = getDescSet(OldCI); BindingNumber = getAndReserveFirstUnusedBinding(DescSet); } if (OldCI->getIntrinsicID() == Intrinsic::spv_resource_handlefromimplicitbinding) { replaceResourceHandleCall(M, OldCI, BindingNumber); } else { assert(OldCI->getIntrinsicID() == Intrinsic::spv_resource_counterhandlefromimplicitbinding && "Unexpected implicit binding intrinsic"); replaceCounterHandleCall(M, OldCI, BindingNumber); } lastOrderId = OrderId; lastBindingNumber = BindingNumber; } } bool SPIRVLegalizeImplicitBinding::runOnModule(Module &M) { collectBindingInfo(M); if (ImplicitBindingCalls.empty()) { return false; } verifyUniqueOrderIdPerResource(ImplicitBindingCalls); replaceImplicitBindingCalls(M); return true; } } // namespace char SPIRVLegalizeImplicitBinding::ID = 0; INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding", "Legalize SPIR-V implicit bindings", false, false) ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() { return new SPIRVLegalizeImplicitBinding(); } void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall( Module &M, CallInst *OldCI, uint32_t NewBinding) { IRBuilder<> Builder(OldCI); const uint32_t DescSet = cast(OldCI->getArgOperand(1))->getZExtValue(); SmallVector Args; Args.push_back(Builder.getInt32(DescSet)); Args.push_back(Builder.getInt32(NewBinding)); // Copy the remaining arguments from the old call. for (uint32_t i = 2; i < OldCI->arg_size(); ++i) { Args.push_back(OldCI->getArgOperand(i)); } Function *NewFunc = Intrinsic::getOrInsertDeclaration( &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType()); CallInst *NewCI = Builder.CreateCall(NewFunc, Args); NewCI->setCallingConv(OldCI->getCallingConv()); OldCI->replaceAllUsesWith(NewCI); OldCI->eraseFromParent(); } void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall( Module &M, CallInst *OldCI, uint32_t NewBinding) { IRBuilder<> Builder(OldCI); const uint32_t DescSet = cast(OldCI->getArgOperand(2))->getZExtValue(); SmallVector Args; Args.push_back(OldCI->getArgOperand(0)); Args.push_back(Builder.getInt32(NewBinding)); Args.push_back(Builder.getInt32(DescSet)); Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()}; Function *NewFunc = Intrinsic::getOrInsertDeclaration( &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys); CallInst *NewCI = Builder.CreateCall(NewFunc, Args); NewCI->setCallingConv(OldCI->getCallingConv()); OldCI->replaceAllUsesWith(NewCI); OldCI->eraseFromParent(); }