diff options
Diffstat (limited to 'mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp')
| -rw-r--r-- | mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 45 |
1 files changed, 26 insertions, 19 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index d8150aeb828a..6656be830989 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -40,31 +40,35 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); - Value ratio = b.create<arith::DivFOp>(min, max, fmf); - Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf); - Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf); + + // The lowering below requires NaNs and infinities to work correctly. + arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( + fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); + Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); + Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); + Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf); - min = b.create<math::RsqrtOp>(min, fmf); - max = b.create<math::RsqrtOp>(max, fmf); + ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); + min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); + max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { Value quarter = b.create<arith::ConstantOp>( real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. - Value sqrt = b.create<math::SqrtOp>(max, fmf); - Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf); - result = b.create<arith::MulFOp>(sqrt, p025, fmf); + Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); + Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); + result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); } else { - Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf); - result = b.create<arith::MulFOp>(max, sqrt, fmf); + Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); + result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); } - Value isNaN = - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf); + Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, + result, fmfWithNaNInf); return b.create<arith::SelectOp>(isNaN, min, result); } @@ -595,17 +599,20 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); Value maxAbsOfRealPlusOneAndImagMinusOne = b.create<arith::SelectOp>(useReal, real, maxMinusOne); - Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf); + arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( + fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); + Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); Value logOfMaxAbsOfRealPlusOneAndImag = b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); Value logOfSqrtPart = b.create<math::Log1pOp>( - b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf); + b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), + fmfWithNaNInf); Value r = b.create<arith::AddFOp>( - b.create<arith::MulFOp>(half, logOfSqrtPart, fmf), - logOfMaxAbsOfRealPlusOneAndImag, fmf); + b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), + logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); Value resultReal = b.create<arith::SelectOp>( - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs, - r); + b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), + minAbs, r); Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); |
