diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp')
| -rw-r--r-- | mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 66 |
1 files changed, 63 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 92168cfa3614..8baa31a23595 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -618,6 +618,66 @@ struct VectorInterleaveOpConvert final } }; +struct VectorDeinterleaveOpConvert final + : public OpConversionPattern<vector::DeinterleaveOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check the result vector type. + VectorType oldResultType = deinterleaveOp.getResultVectorType(); + Type newResultType = getTypeConverter()->convertType(oldResultType); + if (!newResultType) + return rewriter.notifyMatchFailure(deinterleaveOp, + "unsupported result vector type"); + + Location loc = deinterleaveOp->getLoc(); + + // Deinterleave the indices. + Value sourceVector = adaptor.getSource(); + VectorType sourceType = deinterleaveOp.getSourceVectorType(); + int n = sourceType.getNumElements(); + + // Output vectors of size 1 are converted to scalars by the type converter. + // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to + // use `spirv::CompositeExtractOp`. + if (n == 2) { + auto elem0 = rewriter.create<spirv::CompositeExtractOp>( + loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); + + auto elem1 = rewriter.create<spirv::CompositeExtractOp>( + loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); + + rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); + return success(); + } + + // Indices for `shuffleEven` (result 0). + auto seqEven = llvm::seq<int64_t>(n / 2); + auto indicesEven = + llvm::map_to_vector(seqEven, [](int i) { return i * 2; }); + + // Indices for `shuffleOdd` (result 1). + auto seqOdd = llvm::seq<int64_t>(n / 2); + auto indicesOdd = + llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); + + // Create two SPIR-V shuffles. + auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>( + loc, newResultType, sourceVector, sourceVector, + rewriter.getI32ArrayAttr(indicesEven)); + + auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>( + loc, newResultType, sourceVector, sourceVector, + rewriter.getI32ArrayAttr(indicesOdd)); + + rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd}); + return success(); + } +}; + struct VectorLoadOpConverter final : public OpConversionPattern<vector::LoadOp> { using OpConversionPattern::OpConversionPattern; @@ -862,9 +922,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter, - VectorStoreOpConverter>(typeConverter, patterns.getContext(), - PatternBenefit(1)); + VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, + VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( + typeConverter, patterns.getContext(), PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. |
