summaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp')
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp45
1 files changed, 26 insertions, 19 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a..6656be830989 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -40,31 +40,35 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
- Value ratio = b.create<arith::DivFOp>(min, max, fmf);
- Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
- Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+
+ // The lowering below requires NaNs and infinities to work correctly.
+ arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
+ fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
+ Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
+ Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
+ Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
Value result;
if (fn == AbsFn::rsqrt) {
- ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
- min = b.create<math::RsqrtOp>(min, fmf);
- max = b.create<math::RsqrtOp>(max, fmf);
+ ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
+ min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
+ max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
}
if (fn == AbsFn::sqrt) {
Value quarter = b.create<arith::ConstantOp>(
real.getType(), b.getFloatAttr(real.getType(), 0.25));
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
- Value sqrt = b.create<math::SqrtOp>(max, fmf);
- Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
- result = b.create<arith::MulFOp>(sqrt, p025, fmf);
+ Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
+ Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
+ result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
} else {
- Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
- result = b.create<arith::MulFOp>(max, sqrt, fmf);
+ Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
+ result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
}
- Value isNaN =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+ Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
+ result, fmfWithNaNInf);
return b.create<arith::SelectOp>(isNaN, min, result);
}
@@ -595,17 +599,20 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
Value maxAbsOfRealPlusOneAndImagMinusOne =
b.create<arith::SelectOp>(useReal, real, maxMinusOne);
- Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
+ arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
+ fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
+ Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
Value logOfMaxAbsOfRealPlusOneAndImag =
b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
Value logOfSqrtPart = b.create<math::Log1pOp>(
- b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
+ b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
+ fmfWithNaNInf);
Value r = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
- logOfMaxAbsOfRealPlusOneAndImag, fmf);
+ b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
+ logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
Value resultReal = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
- r);
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
+ minAbs, r);
Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);