summaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp28
1 files changed, 22 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2dc42f0a85e6..054827d40f0f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -25,8 +25,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/MathExtras.h"
#include <optional>
namespace mlir {
@@ -971,8 +971,8 @@ struct MemorySpaceCastOpLowering
resultUnderlyingDesc, resultElemPtrType);
int64_t bytesToSkip =
- 2 *
- ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
+ 2 * llvm::divideCeil(
+ getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
Value copySize = rewriter.create<LLVM::SubOp>(
@@ -1590,10 +1590,26 @@ public:
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefDescriptor desc(adaptor.getSource());
+ BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+ Value alignedPtr;
+ if (sourceTy.hasRank()) {
+ MemRefDescriptor desc(adaptor.getSource());
+ alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+ } else {
+ auto elementPtrTy = LLVM::LLVMPointerType::get(
+ rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+ UnrankedMemRefDescriptor desc(adaptor.getSource());
+ Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+ alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+ rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+ elementPtrTy);
+ }
+
rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
- extractOp, getTypeConverter()->getIndexType(),
- desc.alignedPtr(rewriter, extractOp->getLoc()));
+ extractOp, getTypeConverter()->getIndexType(), alignedPtr);
return success();
}
};