summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/CUFOpConversion.cpp')
-rw-r--r--flang/lib/Optimizer/Transforms/CUFOpConversion.cpp38
1 files changed, 5 insertions, 33 deletions
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 9834b0499b93..609a1fc9fb02 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -557,8 +557,8 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
mlir::Value src = op.getSrc();
if (srcTy.isInteger(1)) {
// i1 is not a supported type in the descriptor and it is actually coming
- // from a LOGICAL constant. Store it as a fir.logical.
- srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
+ // from a LOGICAL constant. Use the destination type to avoid mismatch.
+ srcTy = dstEleTy;
src = createConvertOp(rewriter, loc, srcTy, src);
addr = builder.createTemporary(loc, srcTy);
fir::StoreOp::create(builder, loc, src, addr);
@@ -650,7 +650,7 @@ struct CUFDataTransferOpConversion
if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
// Initialization of an array from a scalar value should be implemented
- // via a kernel launch. Use the flan runtime via the Assign function
+ // via a kernel launch. Use the flang runtime via the Assign function
// until we have more infrastructure.
mlir::Value src = emboxSrc(rewriter, op, symtab);
mlir::Value dst = emboxDst(rewriter, op, symtab);
@@ -928,34 +928,6 @@ struct CUFSyncDescriptorOpConversion
}
};
-struct CUFSetAllocatorIndexOpConversion
- : public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::SetAllocatorIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- int idx = kDefaultAllocator;
- if (op.getDataAttr() == cuf::DataAttribute::Device) {
- idx = kDeviceAllocatorPos;
- } else if (op.getDataAttr() == cuf::DataAttribute::Managed) {
- idx = kManagedAllocatorPos;
- } else if (op.getDataAttr() == cuf::DataAttribute::Unified) {
- idx = kUnifiedAllocatorPos;
- } else if (op.getDataAttr() == cuf::DataAttribute::Pinned) {
- idx = kPinnedAllocatorPos;
- }
- mlir::Value index =
- builder.createIntegerConstant(loc, builder.getI32Type(), idx);
- fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index);
- op.erase();
- return mlir::success();
- }
-};
-
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
@@ -1017,8 +989,8 @@ void cuf::populateCUFToFIRConversionPatterns(
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
- CUFFreeOpConversion, CUFSyncDescriptorOpConversion,
- CUFSetAllocatorIndexOpConversion>(patterns.getContext());
+ CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
+ patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(