diff options
| -rw-r--r-- | clang/docs/ReleaseNotes.rst | 2 | ||||
| -rw-r--r-- | clang/lib/Sema/SemaExpr.cpp | 2 | ||||
| -rw-r--r-- | clang/test/SemaCXX/reinterpret-cast.cpp | 4 | ||||
| -rw-r--r-- | llvm/include/llvm/Analysis/LoopInfo.h | 5 | ||||
| -rw-r--r-- | llvm/include/llvm/Transforms/Utils/LoopUtils.h | 3 | ||||
| -rw-r--r-- | llvm/lib/Analysis/LoopInfo.cpp | 22 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LICM.cpp | 8 | ||||
| -rw-r--r-- | llvm/test/Transforms/LICM/licm-coroutine.ll | 78 | ||||
| -rw-r--r-- | mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 111 | ||||
| -rw-r--r-- | mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir | 77 | ||||
| -rw-r--r-- | mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir | 72 |
11 files changed, 193 insertions, 191 deletions
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index e1e497ccdbcc..060f3d982b85 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -283,6 +283,8 @@ Improvements to Clang's diagnostics pointers under ``-Wthread-safety-beta`` (still experimental), which reduces both false positives but also false negatives through more precise analysis. +- Clang now looks through parenthesis for ``-Wundefined-reinterpret-cast`` diagnostic. + Improvements to Clang's time-trace ---------------------------------- diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index aba00dc8ff9b..bd62ac623418 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -14784,7 +14784,7 @@ static QualType CheckIndirectionOperand(Sema &S, Expr *Op, ExprValueKind &VK, QualType OpTy = Op->getType(); QualType Result; - if (isa<CXXReinterpretCastExpr>(Op)) { + if (isa<CXXReinterpretCastExpr>(Op->IgnoreParens())) { QualType OpOrigType = Op->IgnoreParenCasts()->getType(); S.CheckCompatibleReinterpretCast(OpOrigType, OpTy, /*IsDereference*/true, Op->getSourceRange()); diff --git a/clang/test/SemaCXX/reinterpret-cast.cpp b/clang/test/SemaCXX/reinterpret-cast.cpp index bfb808773b90..10b2ed183e2a 100644 --- a/clang/test/SemaCXX/reinterpret-cast.cpp +++ b/clang/test/SemaCXX/reinterpret-cast.cpp @@ -167,6 +167,10 @@ void dereference_reinterpret_cast() { (void)reinterpret_cast<float&>(d); // expected-warning {{reinterpret_cast from 'double' to 'float &' has undefined behavior}} (void)*reinterpret_cast<float*>(&d); // expected-warning {{dereference of type 'float *' that was reinterpret_cast from type 'double *' has undefined behavior}} + // Look through parens + (void)*(reinterpret_cast<double*>(&l)); // expected-warning {{dereference of type 'double *' that was reinterpret_cast from type 'long *' has undefined behavior}} + (void)*((reinterpret_cast<double*>((&l)))); // expected-warning {{dereference of type 'double *' that was reinterpret_cast from type 'long *' has undefined behavior}} + // TODO: add warning for tag types (void)reinterpret_cast<A&>(b); (void)*reinterpret_cast<A*>(&b); diff --git a/llvm/include/llvm/Analysis/LoopInfo.h b/llvm/include/llvm/Analysis/LoopInfo.h index f80744e70f7a..a7a6a2753709 100644 --- a/llvm/include/llvm/Analysis/LoopInfo.h +++ b/llvm/include/llvm/Analysis/LoopInfo.h @@ -59,12 +59,11 @@ public: }; /// Return true if the specified value is loop invariant. - bool isLoopInvariant(const Value *V, bool HasCoroSuspendInst = false) const; + bool isLoopInvariant(const Value *V) const; /// Return true if all the operands of the specified instruction are loop /// invariant. - bool hasLoopInvariantOperands(const Instruction *I, - bool HasCoroSuspendInst = false) const; + bool hasLoopInvariantOperands(const Instruction *I) const; /// If the given value is an instruction inside of the loop and it can be /// hoisted, do so to make it trivially loop-invariant. diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index 5bef67eb021c..c5dbb2bdd1dd 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -185,8 +185,7 @@ LLVM_ABI bool hoistRegion(DomTreeNode *, AAResults *, LoopInfo *, TargetLibraryInfo *, Loop *, MemorySSAUpdater &, ScalarEvolution *, ICFLoopSafetyInfo *, SinkAndHoistLICMFlags &, OptimizationRemarkEmitter *, - bool, bool AllowSpeculation, - bool HasCoroSuspendInst = false); + bool, bool AllowSpeculation); /// Return true if the induction variable \p IV in a Loop whose latch is /// \p LatchBlock would become dead if the exit test \p Cond were removed. diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index 6ba6073cce95..a8c3173bb179 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -58,26 +58,14 @@ static cl::opt<bool, true> // Loop implementation // -bool Loop::isLoopInvariant(const Value *V, bool HasCoroSuspendInst) const { - if (const Instruction *I = dyn_cast<Instruction>(V)) { - // FIXME: this is semantically inconsistent. We're tracking a proper fix in - // issue #149604. - // If V is a pointer to stack object and L contains a coro.suspend function - // call, then V may not be loop invariant because the ramp function and - // resume function have different stack frames. - if (HasCoroSuspendInst && isa<AllocaInst>(I)) - return false; - else - return !contains(I); - } +bool Loop::isLoopInvariant(const Value *V) const { + if (const Instruction *I = dyn_cast<Instruction>(V)) + return !contains(I); return true; // All non-instructions are loop invariant } -bool Loop::hasLoopInvariantOperands(const Instruction *I, - bool HasCoroSuspendInst) const { - return all_of(I->operands(), [&](Value *V) { - return isLoopInvariant(V, HasCoroSuspendInst); - }); +bool Loop::hasLoopInvariantOperands(const Instruction *I) const { + return all_of(I->operands(), [&](Value *V) { return isLoopInvariant(V); }); } bool Loop::makeLoopInvariant(Value *V, bool &Changed, Instruction *InsertPt, diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index e157cc921276..40104e8fb424 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -474,7 +474,7 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, AC, TLI, L, MSSAU, SE, &SafetyInfo, Flags, ORE, LoopNestMode, - LicmAllowSpeculation, HasCoroSuspendInst); + LicmAllowSpeculation); // Now that all loop invariants have been removed from the loop, promote any // memory references to scalars that we can. @@ -892,7 +892,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE, bool LoopNestMode, - bool AllowSpeculation, bool HasCoroSuspendInst) { + bool AllowSpeculation) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && CurLoop != nullptr && SafetyInfo != nullptr && @@ -925,7 +925,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // TODO: It may be safe to hoist if we are hoisting to a conditional block // and we have accurately duplicated the control flow from the loop header // to that block. - if (CurLoop->hasLoopInvariantOperands(&I, HasCoroSuspendInst) && + if (CurLoop->hasLoopInvariantOperands(&I) && canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE) && isSafeToExecuteUnconditionally(I, DT, TLI, CurLoop, SafetyInfo, ORE, Preheader->getTerminator(), AC, @@ -975,7 +975,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, SafetyInfo->doesNotWriteMemoryBefore(I, CurLoop); }; if ((IsInvariantStart(I) || isGuard(&I)) && - CurLoop->hasLoopInvariantOperands(&I, HasCoroSuspendInst) && + CurLoop->hasLoopInvariantOperands(&I) && MustExecuteWithoutWritesBefore(I)) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, MSSAU, SE, ORE); diff --git a/llvm/test/Transforms/LICM/licm-coroutine.ll b/llvm/test/Transforms/LICM/licm-coroutine.ll deleted file mode 100644 index a4765acfb93f..000000000000 --- a/llvm/test/Transforms/LICM/licm-coroutine.ll +++ /dev/null @@ -1,78 +0,0 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -; RUN: opt < %s -passes=licm -S | FileCheck %s - -; %fca.0 and %fca.1 should not be hoisted out of the loop because the ramp -; function and resume function have different stack frames, so %pointer1 and -; %pointer2 have different values before and after @llvm.coro.suspend. - -define ptr @f(i32 %n) presplitcoroutine { -; CHECK-LABEL: define ptr @f( -; CHECK-SAME: i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] { -; CHECK-NEXT: [[ENTRY:.*]]: -; CHECK-NEXT: [[POINTER1:%.*]] = alloca ptr, align 8 -; CHECK-NEXT: [[POINTER2:%.*]] = alloca ptr, align 8 -; CHECK-NEXT: [[ID:%.*]] = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) -; CHECK-NEXT: [[SIZE:%.*]] = call i32 @llvm.coro.size.i32() -; CHECK-NEXT: [[ALLOC:%.*]] = call ptr @malloc(i32 [[SIZE]]) -; CHECK-NEXT: [[HDL:%.*]] = call noalias ptr @llvm.coro.begin(token [[ID]], ptr [[ALLOC]]) -; CHECK-NEXT: br label %[[LOOP:.*]] -; CHECK: [[LOOP]]: -; CHECK-NEXT: [[N_VAL:%.*]] = phi i32 [ [[N]], %[[ENTRY]] ], [ [[INC:%.*]], %[[RESUME:.*]] ] -; CHECK-NEXT: [[INC]] = add nsw i32 [[N_VAL]], 1 -; CHECK-NEXT: call void @print(i32 [[N_VAL]]) -; CHECK-NEXT: [[TMP0:%.*]] = call i8 @llvm.coro.suspend(token none, i1 false) -; CHECK-NEXT: switch i8 [[TMP0]], label %[[SUSPEND_LOOPEXIT:.*]] [ -; CHECK-NEXT: i8 0, label %[[RESUME]] -; CHECK-NEXT: i8 1, label %[[CLEANUP:.*]] -; CHECK-NEXT: ] -; CHECK: [[RESUME]]: -; CHECK-NEXT: [[FCA_0:%.*]] = insertvalue [2 x ptr] poison, ptr [[POINTER1]], 0 -; CHECK-NEXT: [[FCA_1:%.*]] = insertvalue [2 x ptr] [[FCA_0]], ptr [[POINTER2]], 1 -; CHECK-NEXT: call void @foo([2 x ptr] [[FCA_1]]) -; CHECK-NEXT: br label %[[LOOP]] -; CHECK: [[CLEANUP]]: -; CHECK-NEXT: [[MEM:%.*]] = call ptr @llvm.coro.free(token [[ID]], ptr [[HDL]]) -; CHECK-NEXT: call void @free(ptr [[MEM]]) -; CHECK-NEXT: br label %[[SUSPEND:.*]] -; CHECK: [[SUSPEND_LOOPEXIT]]: -; CHECK-NEXT: br label %[[SUSPEND]] -; CHECK: [[SUSPEND]]: -; CHECK-NEXT: [[UNUSED:%.*]] = call i1 @llvm.coro.end(ptr [[HDL]], i1 false, token none) -; CHECK-NEXT: ret ptr [[HDL]] -; -entry: - %pointer1 = alloca ptr - %pointer2 = alloca ptr - %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) - %size = call i32 @llvm.coro.size.i32() - %alloc = call ptr @malloc(i32 %size) - %hdl = call noalias ptr @llvm.coro.begin(token %id, ptr %alloc) - br label %loop - -loop: - %n.val = phi i32 [ %n, %entry ], [ %inc, %resume ] - %inc = add nsw i32 %n.val, 1 - call void @print(i32 %n.val) - %0 = call i8 @llvm.coro.suspend(token none, i1 false) - switch i8 %0, label %suspend [i8 0, label %resume - i8 1, label %cleanup] - -resume: - %fca.0 = insertvalue [2 x ptr] poison, ptr %pointer1, 0 - %fca.1 = insertvalue [2 x ptr] %fca.0, ptr %pointer2, 1 - call void @foo([2 x ptr] %fca.1) - br label %loop - -cleanup: - %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) - call void @free(ptr %mem) - br label %suspend -suspend: - %unused = call i1 @llvm.coro.end(ptr %hdl, i1 false, token none) - ret ptr %hdl -} - -declare void @free(ptr) -declare ptr @malloc(i32) -declare void @print(i32) -declare void @foo([2 x ptr]) diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 819c2e5973ff..852c322cc646 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap, strides = applyPermutation(strides, perms64); } -// Computes memory strides for vector transfer operations, handling both -// static and dynamic memrefs while applying permutation transformations -// for XeGPU lowering. -static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter) { +// Computes memory strides and a memref offset for vector transfer operations, +// handling both static and dynamic memrefs while applying permutation +// transformations for XeGPU lowering. +static std::pair<SmallVector<Value>, Value> +computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) { SmallVector<Value> strides; Value baseMemref = xferOp.getBase(); AffineMap permMap = xferOp.getPermutationMap(); MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType()); Location loc = xferOp.getLoc(); + Value offsetVal = nullptr; if (memrefType.hasStaticShape()) { int64_t offset; SmallVector<int64_t> intStrides; if (failed(memrefType.getStridesAndOffset(intStrides, offset))) - return {}; + return {{}, offsetVal}; // Wrap static strides as MLIR values for (int64_t s : intStrides) strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s)); - } else { + if (!ShapedType::isDynamic(offset)) + offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset); + } + + if (strides.empty() || !offsetVal) { // For dynamic shape memref, use memref.extract_strided_metadata to get // stride values unsigned rank = memrefType.getRank(); @@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp, auto meta = memref::ExtractStridedMetadataOp::create( rewriter, loc, resultTypes, baseMemref); - strides.append(meta.getStrides().begin(), meta.getStrides().end()); + + if (strides.empty()) + strides.append(meta.getStrides().begin(), meta.getStrides().end()); + + if (!offsetVal) + offsetVal = meta.getOffset(); } // Adjust strides according to the permutation map (e.g., for transpose) adjustStridesForPermutation(permMap, strides); - return strides; + return {strides, offsetVal}; } // This function compute the vectors of localOffsets for scattered load/stores. @@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp, // %23 = arith.add %20, %21 // %local_offsets = arith.add %22, %23 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map -// %offsets = orig_offset + local_offsets +// %offsets = memref_offset + orig_offset + local_offsets static Value computeOffsets(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter, - ArrayRef<Value> strides) { + PatternRewriter &rewriter, ArrayRef<Value> strides, + Value baseOffset) { Location loc = xferOp.getLoc(); VectorType vectorType = xferOp.getVectorType(); SmallVector<Value> indices(xferOp.getIndices().begin(), @@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp, arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]); // Compute base offset from transfer read indices - Value baseOffset = nullptr; - if (!indices.empty()) { - baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - for (size_t i = 0; i < indices.size(); ++i) { - Value strideVal = strides[i]; - Value offsetContrib = - arith::MulIOp::create(rewriter, loc, indices[i], strideVal); - baseOffset = - arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); - } - // Broadcast base offset to match vector shape - Value bcastBase = vector::BroadcastOp::create( - rewriter, loc, fullIndexVectorType, baseOffset); - localOffsets = - arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); + for (size_t i = 0; i < indices.size(); ++i) { + Value strideVal = strides[i]; + Value offsetContrib = + arith::MulIOp::create(rewriter, loc, indices[i], strideVal); + baseOffset = + arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); } + // Broadcast base offset to match vector shape + Value bcastBase = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, baseOffset); + localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); return localOffsets; } -// Collapse memref shape to 1D -static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter) { +// Convert memref to i64 base pointer +static Value memrefToIndexPtr(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter) { Location loc = xferOp.getLoc(); - - Value baseMemref = xferOp.getBase(); - MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType()); - Type elementType = memrefType.getElementType(); - - // Compute the total number of elements in the memref - MemRefType flatMemrefType; - if (memrefType.hasStaticShape()) { - auto totalElements = memrefType.getNumElements(); - flatMemrefType = MemRefType::get({totalElements}, elementType); - } else { - flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType); - } - - SmallVector<ReassociationIndices> reassociation; - ReassociationIndices allDims = - llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank())); - reassociation.push_back(allDims); - - auto collapseOp = memref::CollapseShapeOp::create( - rewriter, loc, flatMemrefType, baseMemref, reassociation); - return collapseOp; + auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, xferOp.getBase()) + .getResult(); + return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), + indexPtr) + .getResult(); } static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, @@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, if (!memrefType) return rewriter.notifyMatchFailure(readOp, "Expected memref source"); - SmallVector<Value> strides = computeStrides(readOp, rewriter); - if (strides.empty()) + auto meta = computeMemrefMeta(readOp, rewriter); + if (meta.first.empty()) return rewriter.notifyMatchFailure(readOp, "Failed to compute strides"); - Value localOffsets = computeOffsets(readOp, rewriter, strides); + Value localOffsets = + computeOffsets(readOp, rewriter, meta.first, meta.second); - Value flatMemref = collapseMemrefTo1D(readOp, rewriter); + Value flatMemref = memrefToIndexPtr(readOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), @@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, if (!memrefType) return rewriter.notifyMatchFailure(writeOp, "Expected memref source"); - SmallVector<Value> strides = computeStrides(writeOp, rewriter); + auto meta = computeMemrefMeta(writeOp, rewriter); + if (meta.first.empty()) + return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides"); - Value localOffsets = computeOffsets(writeOp, rewriter, strides); + Value localOffsets = + computeOffsets(writeOp, rewriter, meta.first, meta.second); - Value flatMemref = collapseMemrefTo1D(writeOp, rewriter); + Value flatMemref = memrefToIndexPtr(writeOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index b373bdab8056..c4ca79af1bd9 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32> } @@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32> -// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32> -// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> // LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32> } @@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex> -// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index +// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>, // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex> // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32> // LOAD-GATHER: return %[[VEC]] } @@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex> // LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32> } @@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16> -// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> +// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> } // ----- @@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>, // LOAD-GATHER: vector.transfer_read } +// ----- +gpu.module @xevm_module { +gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { + %c0 = arith.constant 0.0 : f16 + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> + %0 = vector.transfer_read %subview[%off2, %off2], %c0 + {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16> + gpu.return %0 : vector<8xf16> +} + +// LOAD-ND-LABEL: @load_from_subview( +// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] +// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, +// LOAD-ND-SAME: boundary_check = false +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16> +// LOAD-ND: return %[[VEC]] + +// LOAD-GATHER-LABEL: @load_from_subview( +// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> +// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-GATHER: arith.muli {{.*}} : index +// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index +// LOAD-GATHER: arith.addi {{.*}} : index +// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> +} diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir index b3f761a545ee..fcfc9414da4f 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir @@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>, // STORE-SCATTER-COUNT2: arith.addi {{.*}} : index // STORE-SCATTER-DAG: %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1> } // ----- @@ -64,8 +65,9 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -104,8 +106,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -155,8 +158,9 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -186,8 +190,9 @@ gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>, // STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex> // STORE-SCATTER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex> // STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex> -// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> +// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index +// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1> } // ----- @@ -275,4 +280,49 @@ gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>, // STORE-SCATTER-LABEL: @no_store_out_of_bounds_1D_vector( // STORE-SCATTER: vector.transfer_write -}
\ No newline at end of file +} + +// ----- +gpu.module @xevm_module { +gpu.func @store_to_subview(%vec: vector<8xf16>, + %source: memref<4096x4096xf16>, %off1: index, %off2: index) { + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] + : memref<4096x4096xf16> + to memref<256x256xf16, strided<[4096, 1], offset: ?>> + vector.transfer_write %vec, %subview[%off2, %off2] + {in_bounds = [true]} + : vector<8xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>> + gpu.return +} +// STORE-ND-LABEL: @store_to_subview( +// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>, +// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] +// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] +// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, +// STORE-ND-SAME: boundary_check = false +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16> + +// STORE-SCATTER-LABEL: @store_to_subview( +// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>, +// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// STORE-SCATTER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> +// STORE-SCATTER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] +// STORE-SCATTER-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-SCATTER: %[[BB:.+]], %[[OFFSET:.+]], {{.*}}, {{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] +// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// STORE-SCATTER: %[[STEP:.+]] = vector.step : vector<8xindex> +// STORE-SCATTER: arith.muli {{.*}} : index +// STORE-SCATTER: arith.addi %[[OFFSET]]{{.*}} : index +// STORE-SCATTER: arith.addi {{.*}} : index +// STORE-SCATTER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex> +// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] +// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1> +} |
