diff options
Diffstat (limited to 'llvm/lib/Analysis/HashRecognize.cpp')
| -rw-r--r-- | llvm/lib/Analysis/HashRecognize.cpp | 364 |
1 files changed, 102 insertions, 262 deletions
diff --git a/llvm/lib/Analysis/HashRecognize.cpp b/llvm/lib/Analysis/HashRecognize.cpp index 92c9e37dbb48..5d7ee1fe8eb1 100644 --- a/llvm/lib/Analysis/HashRecognize.cpp +++ b/llvm/lib/Analysis/HashRecognize.cpp @@ -8,8 +8,10 @@ // // The HashRecognize analysis recognizes unoptimized polynomial hash functions // with operations over a Galois field of characteristic 2, also called binary -// fields, or GF(2^n): this class of hash functions can be optimized using a -// lookup-table-driven implementation, or with target-specific instructions. +// fields, or GF(2^n). 2^n is termed the order of the Galois field. This class +// of hash functions can be optimized using a lookup-table-driven +// implementation, or with target-specific instructions. +// // Examples: // // 1. Cyclic redundancy check (CRC), which is a polynomial division in GF(2). @@ -24,12 +26,10 @@ // // c_m * x^m + c_(m-1) * x^(m-1) + ... + c_0 * x^0 // -// where each coefficient c is can take values in GF(2^n), where 2^n is termed -// the order of the Galois field. For GF(2), each coefficient can take values -// either 0 or 1, and the polynomial is simply represented by m+1 bits, -// corresponding to the coefficients. The different variants of CRC are named by -// degree of generating polynomial used: so CRC-32 would use a polynomial of -// degree 32. +// where each coefficient c can take values 0 or 1. The polynomial is simply +// represented by m+1 bits, corresponding to the coefficients. The different +// variants of CRC are named by degree of generating polynomial used: so CRC-32 +// would use a polynomial of degree 32. // // The reason algorithms on GF(2^n) can be optimized with a lookup-table is the // following: in such fields, polynomial addition and subtraction are identical @@ -73,202 +73,31 @@ using namespace SCEVPatternMatch; #define DEBUG_TYPE "hash-recognize" -// KnownBits for a PHI node. 
There are at most two PHI nodes, corresponding to -// the Simple Recurrence and Conditional Recurrence. The IndVar PHI is not -// relevant. -using KnownPhiMap = SmallDenseMap<const PHINode *, KnownBits, 2>; - -// A pair of a PHI node along with its incoming value from within a loop. -using PhiStepPair = std::pair<const PHINode *, const Instruction *>; - -/// A much simpler version of ValueTracking, in that it computes KnownBits of -/// values, except that it computes the evolution of KnownBits in a loop with a -/// given trip count, and predication is specialized for a significant-bit -/// check. -class ValueEvolution { - const unsigned TripCount; - const bool ByteOrderSwapped; - APInt GenPoly; - StringRef ErrStr; - - // Compute the KnownBits of a BinaryOperator. - KnownBits computeBinOp(const BinaryOperator *I); - - // Compute the KnownBits of an Instruction. - KnownBits computeInstr(const Instruction *I); - - // Compute the KnownBits of a Value. - KnownBits compute(const Value *V); - -public: - // ValueEvolution is meant to be constructed with the TripCount of the loop, - // and a boolean indicating whether the polynomial algorithm is big-endian - // (for the significant-bit check). - ValueEvolution(unsigned TripCount, bool ByteOrderSwapped); - - // Given a list of PHI nodes along with their incoming value from within the - // loop, computeEvolutions computes the KnownBits of each of the PHI nodes on - // the final iteration. Returns true on success and false on error. - bool computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions); - - // In case ValueEvolution encounters an error, this is meant to be used for a - // precise error message. - StringRef getError() const { return ErrStr; } - - // A set of Instructions visited by ValueEvolution. The only unvisited - // instructions will be ones not on the use-def chain of the PHIs' evolutions. 
+/// Checks if there's a stray instruction in the loop \p L outside of the +/// use-def chains from \p Roots, or if we escape the loop during the use-def +/// walk. +static bool containsUnreachable(const Loop &L, + ArrayRef<const Instruction *> Roots) { SmallPtrSet<const Instruction *, 16> Visited; + BasicBlock *Latch = L.getLoopLatch(); - // The computed KnownBits for each PHI node, which is populated after - // computeEvolutions is called. - KnownPhiMap KnownPhis; -}; - -ValueEvolution::ValueEvolution(unsigned TripCount, bool ByteOrderSwapped) - : TripCount(TripCount), ByteOrderSwapped(ByteOrderSwapped) {} - -KnownBits ValueEvolution::computeBinOp(const BinaryOperator *I) { - KnownBits KnownL(compute(I->getOperand(0))); - KnownBits KnownR(compute(I->getOperand(1))); - - switch (I->getOpcode()) { - case Instruction::BinaryOps::And: - return KnownL & KnownR; - case Instruction::BinaryOps::Or: - return KnownL | KnownR; - case Instruction::BinaryOps::Xor: - return KnownL ^ KnownR; - case Instruction::BinaryOps::Shl: { - auto *OBO = cast<OverflowingBinaryOperator>(I); - return KnownBits::shl(KnownL, KnownR, OBO->hasNoUnsignedWrap(), - OBO->hasNoSignedWrap()); - } - case Instruction::BinaryOps::LShr: - return KnownBits::lshr(KnownL, KnownR); - case Instruction::BinaryOps::AShr: - return KnownBits::ashr(KnownL, KnownR); - case Instruction::BinaryOps::Add: { - auto *OBO = cast<OverflowingBinaryOperator>(I); - return KnownBits::add(KnownL, KnownR, OBO->hasNoUnsignedWrap(), - OBO->hasNoSignedWrap()); - } - case Instruction::BinaryOps::Sub: { - auto *OBO = cast<OverflowingBinaryOperator>(I); - return KnownBits::sub(KnownL, KnownR, OBO->hasNoUnsignedWrap(), - OBO->hasNoSignedWrap()); - } - case Instruction::BinaryOps::Mul: { - Value *Op0 = I->getOperand(0); - Value *Op1 = I->getOperand(1); - bool SelfMultiply = Op0 == Op1 && isGuaranteedNotToBeUndef(Op0); - return KnownBits::mul(KnownL, KnownR, SelfMultiply); - } - case Instruction::BinaryOps::UDiv: - return 
KnownBits::udiv(KnownL, KnownR); - case Instruction::BinaryOps::SDiv: - return KnownBits::sdiv(KnownL, KnownR); - case Instruction::BinaryOps::URem: - return KnownBits::urem(KnownL, KnownR); - case Instruction::BinaryOps::SRem: - return KnownBits::srem(KnownL, KnownR); - default: - ErrStr = "Unknown BinaryOperator"; - unsigned BitWidth = I->getType()->getScalarSizeInBits(); - return {BitWidth}; - } -} - -KnownBits ValueEvolution::computeInstr(const Instruction *I) { - unsigned BitWidth = I->getType()->getScalarSizeInBits(); - - // computeInstr is the only entry-point that needs to update the Visited set. - Visited.insert(I); + SmallVector<const Instruction *, 16> Worklist(Roots); + while (!Worklist.empty()) { + const Instruction *I = Worklist.pop_back_val(); + Visited.insert(I); - // We look up in the map that contains the KnownBits of the PHI from the - // previous iteration. - if (const PHINode *P = dyn_cast<PHINode>(I)) - return KnownPhis.lookup_or(P, BitWidth); + if (isa<PHINode>(I)) + continue; - // Compute the KnownBits for a Select(Cmp()), forcing it to take the branch - // that is predicated on the (least|most)-significant-bit check. - CmpPredicate Pred; - Value *L, *R; - Instruction *TV, *FV; - if (match(I, m_Select(m_ICmp(Pred, m_Value(L), m_Value(R)), m_Instruction(TV), - m_Instruction(FV)))) { - Visited.insert(cast<Instruction>(I->getOperand(0))); - - // We need to check LCR against [0, 2) in the little-endian case, because - // the RCR check is insufficient: it is simply [0, 1). 
- if (!ByteOrderSwapped) { - KnownBits KnownL = compute(L); - unsigned ICmpBW = KnownL.getBitWidth(); - auto LCR = ConstantRange::fromKnownBits(KnownL, false); - auto CheckLCR = ConstantRange(APInt::getZero(ICmpBW), APInt(ICmpBW, 2)); - if (LCR != CheckLCR) { - ErrStr = "Bad LHS of significant-bit-check"; - return {BitWidth}; + for (const Use &U : I->operands()) { + if (auto *UI = dyn_cast<Instruction>(U)) { + if (!L.contains(UI)) + return true; + Worklist.push_back(UI); } } - - // Check that the predication is on (most|least) significant bit. - KnownBits KnownR = compute(R); - unsigned ICmpBW = KnownR.getBitWidth(); - auto RCR = ConstantRange::fromKnownBits(KnownR, false); - auto AllowedR = ConstantRange::makeAllowedICmpRegion(Pred, RCR); - ConstantRange CheckRCR(APInt::getZero(ICmpBW), - ByteOrderSwapped ? APInt::getSignedMinValue(ICmpBW) - : APInt(ICmpBW, 1)); - - // We only compute KnownBits of either TV or FV, as the other value would - // just be a bit-shift as checked by isBigEndianBitShift. 
- if (AllowedR == CheckRCR) { - Visited.insert(FV); - return compute(TV); - } - if (AllowedR.inverse() == CheckRCR) { - Visited.insert(TV); - return compute(FV); - } - - ErrStr = "Bad RHS of significant-bit-check"; - return {BitWidth}; - } - - if (auto *BO = dyn_cast<BinaryOperator>(I)) - return computeBinOp(BO); - - switch (I->getOpcode()) { - case Instruction::CastOps::Trunc: - return compute(I->getOperand(0)).trunc(BitWidth); - case Instruction::CastOps::ZExt: - return compute(I->getOperand(0)).zext(BitWidth); - case Instruction::CastOps::SExt: - return compute(I->getOperand(0)).sext(BitWidth); - default: - ErrStr = "Unknown Instruction"; - return {BitWidth}; } -} - -KnownBits ValueEvolution::compute(const Value *V) { - if (auto *CI = dyn_cast<ConstantInt>(V)) - return KnownBits::makeConstant(CI->getValue()); - - if (auto *I = dyn_cast<Instruction>(V)) - return computeInstr(I); - - ErrStr = "Unknown Value"; - unsigned BitWidth = V->getType()->getScalarSizeInBits(); - return {BitWidth}; -} - -bool ValueEvolution::computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions) { - for (unsigned I = 0; I < TripCount; ++I) - for (auto [Phi, Step] : PhiEvolutions) - KnownPhis.emplace_or_assign(Phi, computeInstr(Step)); - - return ErrStr.empty(); + return std::distance(Latch->begin(), Latch->end()) != Visited.size(); } /// A structure that can hold either a Simple Recurrence or a Conditional @@ -320,6 +149,62 @@ private: Instruction::BinaryOps BOWithConstOpToMatch = Instruction::BinaryOpsEnd); }; +/// Check the well-formedness of the (most|least) significant bit check given \p +/// ConditionalRecurrence, \p SimpleRecurrence, depending on \p +/// ByteOrderSwapped. We check that ConditionalRecurrence.Step is a +/// Select(Cmp()) where the compare is `>= 0` in the big-endian case, and `== 0` +/// in the little-endian case (or the inverse, in which case the branches of the +/// compare are swapped). 
We check that the LHS is (ConditionalRecurrence.Phi +/// [xor SimpleRecurrence.Phi]) in the big-endian case, and additionally check +/// for an AND with one in the little-endian case. We then check AllowedByR +/// against CheckAllowedByR, which is [0, smin) in the big-endian case, and is +/// [0, 1) in the little-endian case. CheckAllowedByR checks for +/// significant-bit-clear, and we match the corresponding arms of the select +/// against bit-shift and bit-shift-and-xor-gen-poly. +static bool +isSignificantBitCheckWellFormed(const RecurrenceInfo &ConditionalRecurrence, + const RecurrenceInfo &SimpleRecurrence, + bool ByteOrderSwapped) { + auto *SI = cast<SelectInst>(ConditionalRecurrence.Step); + CmpPredicate Pred; + const Value *L; + const APInt *R; + Instruction *TV, *FV; + if (!match(SI, m_Select(m_ICmp(Pred, m_Value(L), m_APInt(R)), + m_Instruction(TV), m_Instruction(FV)))) + return false; + + // Match predicate with or without a SimpleRecurrence (the corresponding data + // is LHSAux). + auto MatchPred = m_CombineOr( + m_Specific(ConditionalRecurrence.Phi), + m_c_Xor(m_ZExtOrTruncOrSelf(m_Specific(ConditionalRecurrence.Phi)), + m_ZExtOrTruncOrSelf(m_Specific(SimpleRecurrence.Phi)))); + bool LWellFormed = ByteOrderSwapped ? match(L, MatchPred) + : match(L, m_c_And(MatchPred, m_One())); + if (!LWellFormed) + return false; + + KnownBits KnownR = KnownBits::makeConstant(*R); + unsigned BW = KnownR.getBitWidth(); + auto RCR = ConstantRange::fromKnownBits(KnownR, false); + auto AllowedByR = ConstantRange::makeAllowedICmpRegion(Pred, RCR); + ConstantRange CheckAllowedByR(APInt::getZero(BW), + ByteOrderSwapped ? 
APInt::getSignedMinValue(BW) + : APInt(BW, 1)); + + BinaryOperator *BitShift = ConditionalRecurrence.BO; + if (AllowedByR == CheckAllowedByR) + return TV == BitShift && + match(FV, m_c_Xor(m_Specific(BitShift), + m_SpecificInt(*ConditionalRecurrence.ExtraConst))); + if (AllowedByR.inverse() == CheckAllowedByR) + return FV == BitShift && + match(TV, m_c_Xor(m_Specific(BitShift), + m_SpecificInt(*ConditionalRecurrence.ExtraConst))); + return false; +} + /// Wraps llvm::matchSimpleRecurrence. Match a simple first order recurrence /// cycle of the form: /// @@ -336,8 +221,11 @@ private: /// %BO = binop %step, %rec /// bool RecurrenceInfo::matchSimpleRecurrence(const PHINode *P) { - Phi = P; - return llvm::matchSimpleRecurrence(Phi, BO, Start, Step); + if (llvm::matchSimpleRecurrence(P, BO, Start, Step)) { + Phi = P; + return true; + } + return false; } /// Digs for a recurrence starting with \p V hitting the PHI node in a use-def @@ -459,26 +347,6 @@ PolynomialInfo::PolynomialInfo(unsigned TripCount, Value *LHS, const APInt &RHS, : TripCount(TripCount), LHS(LHS), RHS(RHS), ComputedValue(ComputedValue), ByteOrderSwapped(ByteOrderSwapped), LHSAux(LHSAux) {} -/// In the big-endian case, checks the bottom N bits against CheckFn, and that -/// the rest are unknown. In the little-endian case, checks the top N bits -/// against CheckFn, and that the rest are unknown. Callers usually call this -/// function with N = TripCount, and CheckFn checking that the remainder bits of -/// the CRC polynomial division are zero. -static bool checkExtractBits(const KnownBits &Known, unsigned N, - function_ref<bool(const KnownBits &)> CheckFn, - bool ByteOrderSwapped) { - // Check that the entire thing is a constant. - if (N == Known.getBitWidth()) - return CheckFn(Known.extractBits(N, 0)); - - // Check that the {top, bottom} N bits are not unknown and that the {bottom, - // top} N bits are known. - unsigned BitPos = ByteOrderSwapped ? 
0 : Known.getBitWidth() - N; - unsigned SwappedBitPos = ByteOrderSwapped ? N : 0; - return CheckFn(Known.extractBits(N, BitPos)) && - Known.extractBits(Known.getBitWidth() - N, SwappedBitPos).isUnknown(); -} - /// Generate a lookup table of 256 entries by interleaving the generating /// polynomial. The optimization technique of table-lookup for CRC is also /// called the Sarwate algorithm. @@ -511,8 +379,6 @@ CRCTable HashRecognize::genSarwateTable(const APInt &GenPoly, /// Checks that \p P1 and \p P2 are used together in an XOR in the use-def chain /// of \p SI's condition, ignoring any casts. The purpose of this function is to /// ensure that LHSAux from the SimpleRecurrence is used correctly in the CRC -/// computation. We cannot check the correctness of casts at this point, and -/// rely on the KnownBits propagation to check correctness of the CRC /// computation. /// /// In other words, it checks for the following pattern: @@ -540,8 +406,8 @@ static bool isConditionalOnXorOfPHIs(const SelectInst *SI, const PHINode *P1, continue; // If we match an XOR of the two PHIs ignoring casts, we're done. - if (match(I, m_c_Xor(m_CastOrSelf(m_Specific(P1)), - m_CastOrSelf(m_Specific(P2))))) + if (match(I, m_c_Xor(m_ZExtOrTruncOrSelf(m_Specific(P1)), + m_ZExtOrTruncOrSelf(m_Specific(P2))))) return true; // Continue along the use-def chain. @@ -570,10 +436,8 @@ static std::optional<bool> isBigEndianBitShift(Value *V, ScalarEvolution &SE) { } /// The main entry point for analyzing a loop and recognizing the CRC algorithm. -/// Returns a PolynomialInfo on success, and either an ErrBits or a StringRef on -/// failure. -std::variant<PolynomialInfo, ErrBits, StringRef> -HashRecognize::recognizeCRC() const { +/// Returns a PolynomialInfo on success, and a StringRef on failure. 
+std::variant<PolynomialInfo, StringRef> HashRecognize::recognizeCRC() const { if (!L.isInnermost()) return "Loop is not innermost"; BasicBlock *Latch = L.getLoopLatch(); @@ -582,7 +446,7 @@ HashRecognize::recognizeCRC() const { if (!Latch || !Exit || !IndVar || L.getNumBlocks() != 1) return "Loop not in canonical form"; unsigned TC = SE.getSmallConstantTripCount(&L); - if (!TC || TC > 256 || TC % 8) + if (!TC || TC % 8) return "Unable to find a small constant byte-multiple trip count"; auto R = getRecurrences(Latch, IndVar, L); @@ -637,36 +501,19 @@ HashRecognize::recognizeCRC() const { "Expected ExtraConst in conditional recurrence"); const APInt &GenPoly = *ConditionalRecurrence.ExtraConst; - // PhiEvolutions are pairs of PHINodes along with their incoming value from - // within the loop, which we term as their step. Note that in the case of a - // Simple Recurrence, Step is an operand of the BO, while in a Conditional - // Recurrence, it is a SelectInst. - SmallVector<PhiStepPair, 2> PhiEvolutions; - PhiEvolutions.emplace_back(ConditionalRecurrence.Phi, ComputedValue); + if (!isSignificantBitCheckWellFormed(ConditionalRecurrence, SimpleRecurrence, + *ByteOrderSwapped)) + return "Malformed significant-bit check"; + + SmallVector<const Instruction *> Roots( + {ComputedValue, + cast<Instruction>(IndVar->getIncomingValueForBlock(Latch)), + L.getLatchCmpInst(), Latch->getTerminator()}); if (SimpleRecurrence) - PhiEvolutions.emplace_back(SimpleRecurrence.Phi, SimpleRecurrence.BO); - - ValueEvolution VE(TC, *ByteOrderSwapped); - if (!VE.computeEvolutions(PhiEvolutions)) - return VE.getError(); - KnownBits ResultBits = VE.KnownPhis.at(ConditionalRecurrence.Phi); - - // There must be exactly four unvisited instructions, corresponding to the - // IndVar PHI. Any other unvisited instructions from the KnownBits propagation - // can complicate the optimization, which replaces the entire loop with the - // table-lookup version of the hash algorithm. 
- std::initializer_list<const Instruction *> AugmentVisited = { - IndVar, Latch->getTerminator(), L.getLatchCmpInst(), - cast<Instruction>(IndVar->getIncomingValueForBlock(Latch))}; - VE.Visited.insert_range(AugmentVisited); - if (std::distance(Latch->begin(), Latch->end()) != VE.Visited.size()) + Roots.push_back(SimpleRecurrence.BO); + if (containsUnreachable(L, Roots)) return "Found stray unvisited instructions"; - unsigned N = std::min(TC, ResultBits.getBitWidth()); - auto IsZero = [](const KnownBits &K) { return K.isZero(); }; - if (!checkExtractBits(ResultBits, N, IsZero, *ByteOrderSwapped)) - return ErrBits(ResultBits, TC, *ByteOrderSwapped); - return PolynomialInfo(TC, LHS, GenPoly, ComputedValue, *ByteOrderSwapped, LHSAux); } @@ -693,13 +540,6 @@ void HashRecognize::print(raw_ostream &OS) const { OS << "Did not find a hash algorithm\n"; if (std::holds_alternative<StringRef>(Ret)) OS << "Reason: " << std::get<StringRef>(Ret) << "\n"; - if (std::holds_alternative<ErrBits>(Ret)) { - auto [Actual, Iter, ByteOrderSwapped] = std::get<ErrBits>(Ret); - OS << "Reason: Expected " << (ByteOrderSwapped ? "bottom " : "top ") - << Iter << " bits zero ("; - Actual.print(OS); - OS << ")\n"; - } return; } |
