summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
blob: c469a991fff64056bb6f143bc2dddd54c75902d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ParallelOp to nested scf.for ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "parallel-for-to-nested-fors"
using namespace mlir;

FailureOr<scf::LoopNest>
mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
                                   scf::ParallelOp parallelOp) {

  if (!parallelOp.getResults().empty())
    return rewriter.notifyMatchFailure(
        parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
                    "support scf.parallel with results.");

  rewriter.setInsertionPoint(parallelOp);

  Location loc = parallelOp.getLoc();
  SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
  SmallVector<Value> upperBounds = parallelOp.getUpperBound();
  SmallVector<Value> steps = parallelOp.getStep();

  assert(lowerBounds.size() == upperBounds.size() &&
         lowerBounds.size() == steps.size() &&
         "Mismatched parallel loop bounds");

  scf::LoopNest loopNest =
      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);

  SmallVector<Value> newInductionVars = llvm::map_to_vector(
      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
  Block *linearizedBody = loopNest.loops.back().getBody();
  Block *parallelBody = parallelOp.getBody();
  rewriter.eraseOp(parallelBody->getTerminator());
  rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(),
                             newInductionVars);
  rewriter.eraseOp(parallelOp);
  return loopNest;
}

namespace {
struct ParallelForToNestedFors final
    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
  void runOnOperation() override {
    Operation *parentOp = getOperation();
    IRRewriter rewriter(parentOp->getContext());

    parentOp->walk(
        [&](scf::ParallelOp parallelOp) {
          if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "Failed to convert scf.parallel to nested scf.for ops for:\n"
                << parallelOp << "\n");
            return WalkResult::advance();
          }
          return WalkResult::advance();
        });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
  return std::make_unique<ParallelForToNestedFors>();
}