diff options
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp')
| -rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp | 294 |
1 files changed, 73 insertions, 221 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp index f878bd9465d3..a8f6ad09fe28 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp @@ -200,6 +200,7 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constants.h" @@ -214,6 +215,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ReplaceConstant.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/AtomicOrdering.h" @@ -578,18 +580,14 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) { /// buffer fat pointer constant. static std::pair<Constant *, Constant *> splitLoweredFatBufferConst(Constant *C) { - if (auto *AZ = dyn_cast<ConstantAggregateZero>(C)) - return std::make_pair(AZ->getStructElement(0), AZ->getStructElement(1)); - if (auto *SC = dyn_cast<ConstantStruct>(C)) - return std::make_pair(SC->getOperand(0), SC->getOperand(1)); - llvm_unreachable("Conversion should've created a {p8, i32} struct"); + assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer"); + return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u)); } namespace { /// Handle the remapping of ptr addrspace(7) constants. class FatPtrConstMaterializer final : public ValueMaterializer { BufferFatPtrToStructTypeMap *TypeMap; - BufferFatPtrToIntTypeMap *IntTypeMap; // An internal mapper that is used to recurse into the arguments of constants. // While the documentation for `ValueMapper` specifies not to use it // recursively, examination of the logic in mapValue() shows that it can @@ -599,16 +597,12 @@ class FatPtrConstMaterializer final : public ValueMaterializer { Constant *materializeBufferFatPtrConst(Constant *C); - const DataLayout &DL; - public: // UnderlyingMap is the value map this materializer will be filling. FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap, - ValueToValueMapTy &UnderlyingMap, - BufferFatPtrToIntTypeMap *IntTypeMap, - const DataLayout &DL) - : TypeMap(TypeMap), IntTypeMap(IntTypeMap), - InternalMapper(UnderlyingMap, RF_None, TypeMap, this), DL(DL) {} + ValueToValueMapTy &UnderlyingMap) + : TypeMap(TypeMap), + InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {} virtual ~FatPtrConstMaterializer() = default; Value *materialize(Value *V) override; @@ -631,10 +625,6 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) { UndefValue::get(NewTy->getElementType(1))}); } - if (isa<GlobalValue>(C)) - report_fatal_error("Global values containing ptr addrspace(7) (buffer " - "fat pointer) values are not supported"); - if (auto *VC = dyn_cast<ConstantVector>(C)) { if (Constant *S = VC->getSplatValue()) { Constant *NewS = InternalMapper.mapConstant(*S); @@ -660,127 +650,14 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) { return ConstantStruct::get(NewTy, {RsrcVec, OffVec}); } - // Constant expressions. This code mirrors how we fix up the equivalent - // instructions later. - auto *CE = dyn_cast<ConstantExpr>(C); - if (!CE) - return nullptr; - if (auto *GEPO = dyn_cast<GEPOperator>(C)) { - Constant *RemappedPtr = - InternalMapper.mapConstant(*cast<Constant>(GEPO->getPointerOperand())); - auto [Rsrc, Off] = splitLoweredFatBufferConst(RemappedPtr); - Type *OffTy = Off->getType(); - bool InBounds = GEPO->isInBounds(); - - MapVector<Value *, APInt> VariableOffs; - APInt NewConstOffVal = APInt::getZero(BufferOffsetWidth); - if (!GEPO->collectOffset(DL, BufferOffsetWidth, VariableOffs, - NewConstOffVal)) - report_fatal_error( - "Scalable vector or unsized struct in fat pointer GEP"); - Constant *OffAccum = nullptr; - // Accumulate offsets together before adding to the base in order to - // preserve as many of the inbounds properties as possible. - for (auto [Arg, Multiple] : VariableOffs) { - Constant *NewArg = InternalMapper.mapConstant(*cast<Constant>(Arg)); - NewArg = ConstantFoldIntegerCast(NewArg, OffTy, /*IsSigned=*/true, DL); - if (!Multiple.isOne()) { - if (Multiple.isPowerOf2()) { - NewArg = ConstantExpr::getShl( - NewArg, - CE->getIntegerValue( - OffTy, APInt(BufferOffsetWidth, Multiple.logBase2())), - /*hasNUW=*/InBounds, /*HasNSW=*/InBounds); - } else { - NewArg = - ConstantExpr::getMul(NewArg, CE->getIntegerValue(OffTy, Multiple), - /*hasNUW=*/InBounds, /*hasNSW=*/InBounds); - } - } - if (OffAccum) { - OffAccum = ConstantExpr::getAdd(OffAccum, NewArg, /*hasNUW=*/InBounds, - /*hasNSW=*/InBounds); - } else { - OffAccum = NewArg; - } - } - Constant *NewConstOff = CE->getIntegerValue(OffTy, NewConstOffVal); - if (OffAccum) - OffAccum = ConstantExpr::getAdd(OffAccum, NewConstOff, - /*hasNUW=*/InBounds, /*hasNSW=*/InBounds); - else - OffAccum = NewConstOff; - bool HasNonNegativeOff = false; - if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) { - HasNonNegativeOff = !CI->isNegative(); - } - Constant *NewOff = ConstantExpr::getAdd( - Off, OffAccum, /*hasNUW=*/InBounds && HasNonNegativeOff, - /*hasNSW=*/false); - return ConstantStruct::get(NewTy, {Rsrc, NewOff}); - } - - if (auto *PI = dyn_cast<PtrToIntOperator>(CE)) { - Constant *Parts = - InternalMapper.mapConstant(*cast<Constant>(PI->getPointerOperand())); - auto [Rsrc, Off] = splitLoweredFatBufferConst(Parts); - // Here, we take advantage of the fact that ptrtoint has a built-in - // zero-extension behavior. - unsigned FatPtrWidth = - DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER); - Constant *RsrcInt = CE->getPtrToInt(Rsrc, SrcTy); - unsigned Width = SrcTy->getScalarSizeInBits(); - Constant *Shift = - CE->getIntegerValue(SrcTy, APInt(Width, BufferOffsetWidth)); - Constant *OffCast = - ConstantFoldIntegerCast(Off, SrcTy, /*IsSigned=*/false, DL); - Constant *RsrcHi = ConstantExpr::getShl( - RsrcInt, Shift, Width >= FatPtrWidth, Width > FatPtrWidth); - // This should be an or, but those got recently removed. - Constant *Result = ConstantExpr::getAdd(RsrcHi, OffCast, true, true); - return Result; - } + if (isa<GlobalValue>(C)) + report_fatal_error("Global values containing ptr addrspace(7) (buffer " + "fat pointer) values are not supported"); - if (CE->getOpcode() == Instruction::IntToPtr) { - auto *Arg = cast<Constant>(CE->getOperand(0)); - unsigned FatPtrWidth = - DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER); - unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE); - auto *WantedTy = Arg->getType()->getWithNewBitWidth(FatPtrWidth); - Arg = ConstantFoldIntegerCast(Arg, WantedTy, /*IsSigned=*/false, DL); - - Constant *Shift = - CE->getIntegerValue(WantedTy, APInt(FatPtrWidth, BufferOffsetWidth)); - Type *RsrcIntType = WantedTy->getWithNewBitWidth(RsrcPtrWidth); - Type *RsrcTy = NewTy->getElementType(0); - Type *OffTy = WantedTy->getWithNewBitWidth(BufferOffsetWidth); - Constant *RsrcInt = CE->getTrunc( - ConstantFoldBinaryOpOperands(Instruction::LShr, Arg, Shift, DL), - RsrcIntType); - Constant *Rsrc = CE->getIntToPtr(RsrcInt, RsrcTy); - Constant *Off = ConstantFoldIntegerCast(Arg, OffTy, /*isSigned=*/false, DL); - - return ConstantStruct::get(NewTy, {Rsrc, Off}); - } + if (isa<ConstantExpr>(C)) + report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer " + "fat pointer) values should have been expanded earlier"); - if (auto *AC = dyn_cast<AddrSpaceCastOperator>(CE)) { - unsigned SrcAS = AC->getSrcAddressSpace(); - unsigned DstAS = AC->getDestAddressSpace(); - auto *Arg = cast<Constant>(AC->getPointerOperand()); - auto *NewArg = InternalMapper.mapConstant(*Arg); - if (!NewArg) - return nullptr; - if (SrcAS == AMDGPUAS::BUFFER_FAT_POINTER && - DstAS == AMDGPUAS::BUFFER_FAT_POINTER) - return NewArg; - if (SrcAS == AMDGPUAS::BUFFER_RESOURCE && - DstAS == AMDGPUAS::BUFFER_FAT_POINTER) { - auto *NullOff = CE->getNullValue(NewTy->getElementType(1)); - return ConstantStruct::get(NewTy, {NewArg, NullOff}); - } - report_fatal_error( - "Unsupported address space cast for a buffer fat pointer"); - } return nullptr; } @@ -788,26 +665,6 @@ Value *FatPtrConstMaterializer::materialize(Value *V) { Constant *C = dyn_cast<Constant>(V); if (!C) return nullptr; - if (auto *GEPO = dyn_cast<GEPOperator>(C)) { - // As a special case, adjust GEP constants that have a ptr addrspace(7) in - // their source types here, since the earlier local changes didn't handle - // htis. - Type *SrcTy = GEPO->getSourceElementType(); - Type *NewSrcTy = IntTypeMap->remapType(SrcTy); - if (SrcTy != NewSrcTy) { - SmallVector<Constant *> Ops; - Ops.reserve(GEPO->getNumOperands()); - for (const Use &U : GEPO->operands()) - Ops.push_back(cast<Constant>(U.get())); - auto *NewGEP = ConstantExpr::getGetElementPtr( - NewSrcTy, Ops[0], ArrayRef<Constant *>(Ops).slice(1), - GEPO->getNoWrapFlags(), GEPO->getInRange()); - LLVM_DEBUG(dbgs() << "p7-getting GEP: " << *GEPO << " becomes " << *NewGEP - << "\n"); - Value *FurtherMap = materialize(NewGEP); - return FurtherMap ? FurtherMap : NewGEP; - } - } // Structs and other types that happen to contain fat pointers get remapped // by the mapValue() logic. if (!isBufferFatPtrConst(C)) @@ -1387,57 +1244,25 @@ PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) { } PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) { + using namespace llvm::PatternMatch; Value *Ptr = GEP.getPointerOperand(); if (!isSplitFatPtr(Ptr->getType())) return {nullptr, nullptr}; IRB.SetInsertPoint(&GEP); auto [Rsrc, Off] = getPtrParts(Ptr); - Type *OffTy = Off->getType(); const DataLayout &DL = GEP.getModule()->getDataLayout(); bool InBounds = GEP.isInBounds(); - // In order to call collectOffset() and thus not have to reimplement it, - // we need the GEP's pointer operand to have ptr addrspace(7) type - GEP.setOperand(GEP.getPointerOperandIndex(), - PoisonValue::get(IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER))); - MapVector<Value *, APInt> VariableOffs; - APInt ConstOffVal = APInt::getZero(BufferOffsetWidth); - if (!GEP.collectOffset(DL, BufferOffsetWidth, VariableOffs, ConstOffVal)) - report_fatal_error("Scalable vector or unsized struct in fat pointer GEP"); - GEP.setOperand(GEP.getPointerOperandIndex(), Ptr); - Value *OffAccum = nullptr; - // Accumulate offsets together before adding to the base in order to preserve - // as many of the inbounds properties as possible. - for (auto [Arg, Multiple] : VariableOffs) { - if (auto *OffVecTy = dyn_cast<VectorType>(OffTy)) - if (!Arg->getType()->isVectorTy()) - Arg = IRB.CreateVectorSplat(OffVecTy->getElementCount(), Arg); - Arg = IRB.CreateIntCast(Arg, OffTy, /*isSigned=*/true); - if (!Multiple.isOne()) { - if (Multiple.isPowerOf2()) - Arg = IRB.CreateShl(Arg, Multiple.logBase2(), "", /*hasNUW=*/InBounds, - /*HasNSW=*/InBounds); - else - Arg = IRB.CreateMul(Arg, ConstantExpr::getIntegerValue(OffTy, Multiple), - "", /*hasNUW=*/InBounds, /*hasNSW=*/InBounds); - } - if (OffAccum) - OffAccum = IRB.CreateAdd(OffAccum, Arg, "", /*hasNUW=*/InBounds, - /*hasNSW=*/InBounds); - else - OffAccum = Arg; - } - if (!ConstOffVal.isZero()) { - Constant *ConstOff = ConstantExpr::getIntegerValue(OffTy, ConstOffVal); - if (OffAccum) - OffAccum = IRB.CreateAdd(OffAccum, ConstOff, "", /*hasNUW=*/InBounds, - /*hasNSW=*/InBounds); - else - OffAccum = ConstOff; - } - - if (!OffAccum) { // Constant-zero offset + // In order to call emitGEPOffset() and thus not have to reimplement it, + // we need the GEP result to have ptr addrspace(7) type. + Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER); + if (auto *VT = dyn_cast<VectorType>(Off->getType())) + FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount()); + GEP.mutateType(FatPtrTy); + Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP); + GEP.mutateType(Ptr->getType()); + if (match(OffAccum, m_Zero())) { // Constant-zero offset SplitUsers.insert(&GEP); return {Rsrc, Off}; } @@ -1447,7 +1272,7 @@ PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) { HasNonNegativeOff = !CI->isNegative(); } Value *NewOff; - if (PatternMatch::match(Off, PatternMatch::is_zero())) { + if (match(Off, m_Zero())) { NewOff = OffAccum; } else { NewOff = IRB.CreateAdd(Off, OffAccum, "", @@ -1473,20 +1298,22 @@ PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) { const DataLayout &DL = PI.getModule()->getDataLayout(); unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER); - Value *RsrcInt; - if (Width <= BufferOffsetWidth) - RsrcInt = ConstantExpr::getIntegerValue(ResTy, APInt::getZero(Width)); - else - RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc"); - copyMetadata(RsrcInt, &PI); - - Value *Shl = IRB.CreateShl( - RsrcInt, - ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)), "", - Width >= FatPtrWidth, Width > FatPtrWidth); - Value *OffCast = - IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, PI.getName() + ".off"); - Value *Res = IRB.CreateOr(Shl, OffCast); + Value *Res; + if (Width <= BufferOffsetWidth) { + Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, + PI.getName() + ".off"); + } else { + Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc"); + Value *Shl = IRB.CreateShl( + RsrcInt, + ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)), + "", Width >= FatPtrWidth, Width > FatPtrWidth); + Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, + PI.getName() + ".off"); + Res = IRB.CreateOr(Shl, OffCast); + } + + copyMetadata(Res, &PI); Res->takeName(&PI); SplitUsers.insert(&PI); PI.replaceAllUsesWith(Res); @@ -1818,14 +1645,9 @@ public: static bool containsBufferFatPointers(const Function &F, BufferFatPtrToStructTypeMap *TypeMap) { bool HasFatPointers = false; - for (const BasicBlock &BB : F) { - for (const Instruction &I : BB) { + for (const BasicBlock &BB : F) + for (const Instruction &I : BB) HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType())); - for (const Use &U : I.operands()) - if (auto *C = dyn_cast<Constant>(U.get())) - HasFatPointers |= isBufferFatPtrConst(C); - } - } return HasFatPointers; } @@ -1924,6 +1746,36 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) { "buffer resource pointers (address space 8) instead."); } + { + // Collect all constant exprs and aggregates referenced by any function. + SmallVector<Constant *, 8> Worklist; + for (Function &F : M.functions()) + for (Instruction &I : instructions(F)) + for (Value *Op : I.operands()) + if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) + Worklist.push_back(cast<Constant>(Op)); + + // Recursively look for any referenced buffer pointer constants. + SmallPtrSet<Constant *, 8> Visited; + SetVector<Constant *> BufferFatPtrConsts; + while (!Worklist.empty()) { + Constant *C = Worklist.pop_back_val(); + if (!Visited.insert(C).second) + continue; + if (isBufferFatPtrOrVector(C->getType())) + BufferFatPtrConsts.insert(C); + for (Value *Op : C->operands()) + if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) + Worklist.push_back(cast<Constant>(Op)); + } + + // Expand all constant expressions using fat buffer pointers to + // instructions. + Changed |= convertUsersOfConstantsToInstructions( + BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr, + /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true); + } + StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext()); for (Function &F : M.functions()) { bool InterfaceChange = hasFatPointerInterface(F, &StructTM); @@ -1939,7 +1791,7 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) { SmallVector<Function *> Intrinsics; // Keep one big map so as to memoize constants across functions. ValueToValueMapTy CloneMap; - FatPtrConstMaterializer Materializer(&StructTM, CloneMap, &IntTM, DL); + FatPtrConstMaterializer Materializer(&StructTM, CloneMap); ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer); for (auto [F, InterfaceChange] : NeedsRemap) { |
