diff options
Diffstat (limited to 'mlir/lib/Analysis/Presburger/IntegerRelation.cpp')
| -rw-r--r-- | mlir/lib/Analysis/Presburger/IntegerRelation.cpp | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 17e48e0d069b..5c4d4d13580a 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -2481,6 +2481,44 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) { void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); } +IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) { + /// R1: (i, j) -> k : f(i, j, k) = 0 + /// R2: (i, j) -> l : g(i, j, l) = 0 + /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0 + assert(getNumDomainVars() == rel.getNumDomainVars() && + "Range product is only defined for relations with equal domains"); + + // explicit copy of `this` + IntegerRelation result = *this; + unsigned relRangeVarStart = rel.getVarKindOffset(VarKind::Range); + unsigned numThisRangeVars = getNumRangeVars(); + unsigned numNewSymbolVars = result.getNumSymbolVars() - getNumSymbolVars(); + + result.appendVar(VarKind::Range, rel.getNumRangeVars()); + + // Copy each equality from `rel` and update the copy to account for range + // variables from `this`. The `rel` equality is a list of coefficients of the + // variables from `rel`, and so the range variables need to be shifted right + // by the number of `this` range variables and symbols. + for (unsigned i = 0; i < rel.getNumEqualities(); ++i) { + SmallVector<DynamicAPInt> copy = + SmallVector<DynamicAPInt>(rel.getEquality(i)); + copy.insert(copy.begin() + relRangeVarStart, + numThisRangeVars + numNewSymbolVars, DynamicAPInt(0)); + result.addEquality(copy); + } + + for (unsigned i = 0; i < rel.getNumInequalities(); ++i) { + SmallVector<DynamicAPInt> copy = + SmallVector<DynamicAPInt>(rel.getInequality(i)); + copy.insert(copy.begin() + relRangeVarStart, + numThisRangeVars + numNewSymbolVars, DynamicAPInt(0)); + result.addInequality(copy); + } + + return result; +} + void IntegerRelation::printSpace(raw_ostream &os) const { space.print(os); os << getNumConstraints() << " constraints\n"; |
