blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 1
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
136 using Base::GetWaveIdx;
139
142
143 using Base::AMmaKStride;
144 using Base::APackedSize;
145 using Base::BMmaKStride;
146 using Base::BPackedSize;
147 using Base::KThreadChunk;
148
149 using Base::KXdlPack;
150 using Base::MXdlPack;
151 using Base::NXdlPack;
152
153 using AccType = typename Base::AccType;
154 using Tuple5 = typename Base::Tuple5;
157
// Pipeline depth, matching the header comment for this pipeline: one global
// prefetch stage, one LDS prefill stage and a single global buffer.
// BlockHasHotloop() compares num_loop against PrefetchStages.
158 static constexpr index_t PrefetchStages = 1;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 1;
161
162 static constexpr auto ScalesPerKBlockSize =
163 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164
165 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
166 static constexpr auto ScalesPerXdlopsRun =
167 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168
169 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
// Run() static_asserts that this is at least 1.
170 static constexpr auto ScalesPerXdlopsRunPerThread =
171 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172
// Number of mx_scale_t values packed into one AScaleDataType / BScaleDataType
// element (integer division of the two type sizes).
174 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// '*' and '%' share precedence and associate left-to-right, so these check
// (KXdlPack * MXdlPack) % scale_pack_size_a == 0 and
// (KXdlPack * NXdlPack) % scale_pack_size_b == 0: the scales of one xdl pack
// must split evenly into packed scale elements.
176 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177 "A scale pack data type too large!");
178 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179 "B scale pack data type too large!");
182
183 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184 {
185 return num_loop > PrefetchStages;
186 }
187
188 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189 {
190 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191 }
192
// Run: v1 ("naive", single LDS buffer) MX block-GEMM main loop.
//   Prologue:
//     - Read the first A/B K tile from the grid buffers (RunRead) and advance
//       the source windows.
//     - Prefetch the A/B scales for that tile into the scale thread buffers.
//     - Write the tile to LDS (RunWrite) and clear c_thread_buf.
//   Hot loop (only if HasMainLoop):
//     - Issue the read of the next K tile.
//     - Compute the current tile from LDS with scaled xdlops MFMAs.
//     - Prefetch the next tile's scales.
//     - Write the next tile to LDS.
//     The loop runs num_loop - 1 times. Being a do/while, its body runs at
//     least once.
//   Tail (only if TailNum == TailNumber::Full): computes the last K tile that
//   is already resident in LDS.
// Results accumulate into c_thread_buf. num_loop is the number of K blocks to
// process (presumably K / KPerBlock - TODO confirm against the caller).
// NOTE(review): this rendered listing drops some source lines (gaps in the
// embedded line numbers, e.g. 239->241, 321->323). The comments below describe
// only the visible statements.
193 template <bool HasMainLoop,
194 TailNumber TailNum,
195 typename AGridDesc,
196 typename ABlockDesc,
197 typename ABlockTransfer,
198 typename AGridBuffer,
199 typename ABlockBuffer,
200 typename ABlockTransferStep,
201 typename BGridDesc,
202 typename BBlockDesc,
203 typename BBlockTransfer,
204 typename BGridBuffer,
205 typename BBlockBuffer,
206 typename BBlockTransferStep,
207 typename CThreadBuffer,
208 typename AScaleGridBuffer,
209 typename AScaleGridDesc,
210 typename AScaleThreadTransfer,
211 typename BScaleGridBuffer,
212 typename BScaleGridDesc,
213 typename BScaleThreadTransfer>
214 __device__ void Run(
215 // ABlockCopy
216 const AGridDesc& a_grid_desc,
217 const ABlockDesc& a_block_desc,
218 ABlockTransfer& a_blockwise_copy,
219 const AGridBuffer& a_grid_buf,
220 ABlockBuffer& a_block_buf,
221 const ABlockTransferStep& a_block_copy_step,
222 // BBlockCopy
223 const BGridDesc& b_grid_desc,
224 const BBlockDesc& b_block_desc,
225 BBlockTransfer& b_blockwise_copy,
226 const BGridBuffer& b_grid_buf,
227 BBlockBuffer& b_block_buf,
228 const BBlockTransferStep& b_block_copy_step,
229 // CThread
230 CThreadBuffer& c_thread_buf,
231 // A and B scales
232 const AScaleGridDesc& a_scale_grid_desc,
233 AScaleThreadTransfer& a_scale_thread_copy,
234 const AScaleGridBuffer& a_scale_grid_buf,
235 const BScaleGridDesc& b_scale_grid_desc,
236 BScaleThreadTransfer& b_scale_thread_copy,
237 const BScaleGridBuffer& b_scale_grid_buf,
238 index_t num_loop) const
239 {
// Working buffers for the A/B fragments and the A/B scales. Their declaration
// heads (listing lines 240, 242, 245, 248) are omitted from this rendering;
// only the element-count arguments are visible. They are presumably static
// per-thread buffers (make_static_buffer is referenced by this page).
241 a_thread_desc_.GetElementSpaceSize());
243 b_thread_desc_.GetElementSpaceSize());
244
246 a_scale_thread_desc.GetElementSpaceSize());
247
249 b_scale_thread_desc.GetElementSpaceSize());
250
251 // Global prefetch 1
252 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
253 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
254
255 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
256 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
257
// Scale window walk (the same pattern is used for B with n0/NWaves):
//   - For each m0, copy KRepeat / KXdlPack scale slices into slot (m0, k0, 0),
//     stepping the source window +1 along its second index per k0.
//   - Then step +MWaves along the first index and rewind the second.
//   - The move after the loops rewinds the first index and advances the second
//     by KRepeat / KXdlPack, leaving the window on the next K block's scales.
258 // Prefetch a_scales
259 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
260 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
261 a_scale_thread_copy.Run(a_scale_grid_desc,
262 a_scale_grid_buf,
264 make_tuple(m0, k0, I0),
265 a_scale_thread_buf);
266
267 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
268 make_multi_index(0, I1, 0));
269 });
270 a_scale_thread_copy.MoveSrcSliceWindow(
271 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
272 });
273
274 // restore row id and advance to the next set of scales
275 a_scale_thread_copy.MoveSrcSliceWindow(
276 a_scale_grid_desc,
277 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
278
279 // Prefetch b_scales
280 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
281 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
282 b_scale_thread_copy.Run(b_scale_grid_desc,
283 b_scale_grid_buf,
285 make_tuple(n0, k0, I0),
286 b_scale_thread_buf);
287
288 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
289 make_multi_index(0, I1, 0));
290 });
291 b_scale_thread_copy.MoveSrcSliceWindow(
292 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
293 });
294
295 // restore col id and advance to the next set of scales
296 // NWaves * NPerXDL * NRepeat == NPerBlock
297 b_scale_thread_copy.MoveSrcSliceWindow(
298 b_scale_grid_desc,
299 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
300
301 // Local prefill 1
302 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
303 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
304
305 // Initialize C
306 c_thread_buf.Clear();
307
// Hot loop. HasMainLoop is expected to equal BlockHasHotloop(num_loop)
// (num_loop > 1). The do/while body always executes at least once, then
// repeats while i < num_loop - 1.
308 // main body
309 if constexpr(HasMainLoop)
310 {
311 // loop over k with the step KPerBlock
312 index_t i = 0;
313 do
314 {
315 // -------------------------------------------------------------------------------------------
316 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
317 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
318
319 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
320 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
321
// NOTE(review): listing line 322 is omitted here. It is presumably an LDS
// barrier (block_sync_lds() is referenced by this page) separating the
// preceding RunWrite from the LDS reads below. Confirm against the source.
323
// Copy the current K tile from LDS (a_block_buf / b_block_buf) into
// a_thread_buf / b_thread_buf chunk by chunk. The copy-call heads are omitted
// from this listing (presumably a_thread_copy_ / b_thread_copy_).
324 static_for<0, KRepeat, 1>{}([&](auto k) {
325 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
326 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
327 static_for<0, MRepeat, 1>{}([&](auto m0) {
328 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
329 [&](auto chunk) {
330 constexpr auto a_k_step_chunk =
331 k_step +
332 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
335 I0,
337 I0,
339 a_block_buf,
342 I0,
344 k,
346 a_thread_buf);
347 });
348 });
349 static_for<0, NRepeat, 1>{}([&](auto n0) {
350 // read block data in chunks to assemble correct thread vectors
351 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
352 [&](auto chunk) {
353 constexpr auto b_k_step_chunk =
354 k_step +
355 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
358 I0,
360 I0,
362 b_block_buf,
365 I0,
367 k,
369 b_thread_buf);
370 });
371 });
372 });
373
// For each (m0, n0, k0):
//   - Pack the prefetched scales into the scale vectors.
//   - For each (ikxdl, imxdl, inxdl), gather KPack A/B operands from the thread
//     buffers.
//   - Issue one scaled xdlops MFMA that accumulates into c_thread_buf at
//     c_offset.
374 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
375 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
376 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
377 constexpr index_t a_scale_offset =
378 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
379 constexpr index_t b_scale_offset =
380 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
381
382 static_assert(0 < ScalesPerXdlopsRunPerThread,
383 "Must have at least one scale per Xdlops "
384 "per Thread.");
385
388
389 // Pack scale_thread_buf into scale_thread_vec
391 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
392 a_scale_thread_buf[Number<a_scale_offset + s>{}];
393 });
394
396 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
397 b_scale_thread_buf[Number<b_scale_offset + s>{}];
398 });
399
400 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
401 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
402 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
403 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
404
407
408 static_for<0, KPack, 1>{}([&](auto ik) {
409 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
410 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
411 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
412 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
413 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
414 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
415 });
416
417 using mfma_input_type_a =
418 typename vector_type<ComputeTypeA,
419 xdlops_gemm.K1PerXdlops /
420 APackedSize>::type;
421
422 using mfma_input_type_b =
423 typename vector_type<ComputeTypeB,
424 xdlops_gemm.K1PerXdlops /
425 BPackedSize>::type;
426
427 using mfma_scale_input_type_a =
428 typename vector_type<AScaleDataType,
430 using mfma_scale_input_type_b =
431 typename vector_type<BScaleDataType,
433
434 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
435 make_tuple(m0, n0, imxdl, inxdl, 0));
436
437 // MFMA accumulation
438 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
439 ikxdl * NXdlPack + inxdl>(
440 a_thread_vec.template AsType<mfma_input_type_a>(),
441 a_scale_thread_vec
442 .template AsType<mfma_scale_input_type_a>(),
443 b_thread_vec.template AsType<mfma_input_type_b>(),
444 b_scale_thread_vec
445 .template AsType<mfma_scale_input_type_b>(),
446 c_thread_buf.GetVectorTypeReference(
448 });
449 });
450 });
451 });
452 });
453 });
454
// Scales for the next K block. The previous prefetch left the source windows
// positioned there, and this one repeats the same window walk.
455 // Prefetch a_scales
456 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
457 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
458 a_scale_thread_copy.Run(a_scale_grid_desc,
459 a_scale_grid_buf,
461 make_tuple(m0, k0, I0),
462 a_scale_thread_buf);
463
464 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
465 make_multi_index(0, I1, 0));
466 });
467 a_scale_thread_copy.MoveSrcSliceWindow(
468 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
469 });
470
471 // restore row id and advance to the next set of scales
472 a_scale_thread_copy.MoveSrcSliceWindow(
473 a_scale_grid_desc,
474 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
475
476 // Prefetch b_scales
477 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
478 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
479 b_scale_thread_copy.Run(b_scale_grid_desc,
480 b_scale_grid_buf,
482 make_tuple(n0, k0, I0),
483 b_scale_thread_buf);
484
485 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
486 make_multi_index(0, I1, 0));
487 });
488 b_scale_thread_copy.MoveSrcSliceWindow(
489 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
490 });
491
492 // restore col id and advance to the next set of scales
493 // NWaves * NPerXDL * NRepeat == NPerBlock
494 b_scale_thread_copy.MoveSrcSliceWindow(
495 b_scale_grid_desc,
496 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
497
// The LDS tile read earlier in this iteration is overwritten here. Listing
// line 498 is omitted; it is presumably the LDS barrier that orders those
// reads before these writes. TODO confirm.
499 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
500 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
501
502 i += 1;
503 } while(i < (num_loop - 1));
504 }
505
506 // tail
// NOTE(review): this tail is compiled only for TailNum == TailNumber::Full.
// After the prologue and hot loop, exactly one K tile is always left in LDS.
// If Run is instantiated with any other TailNum, that tile is never
// accumulated into c_thread_buf. Check that the dispatching code passes Full
// (compare with what BlockLoopTailNum() reports).
507 if constexpr(TailNum == TailNumber::Full)
508 {
// Same LDS->register copy and MFMA sequence as the hot-loop body, applied to
// the final K tile. Listing line 509 is omitted (presumably an LDS barrier).
510 static_for<0, KRepeat, 1>{}([&](auto k) {
511 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
512 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
513 static_for<0, MRepeat, 1>{}([&](auto m0) {
514 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
515 [&](auto chunk) {
516 constexpr auto a_k_step_chunk =
517 k_step +
518 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
521 I0,
523 I0,
525 a_block_buf,
528 I0,
530 k,
532 a_thread_buf);
533 });
534 });
535 static_for<0, NRepeat, 1>{}([&](auto n0) {
536 // read block data in chunks to assemble correct thread vectors
537 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
538 [&](auto chunk) {
539 constexpr auto b_k_step_chunk =
540 k_step +
541 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
544 I0,
546 I0,
548 b_block_buf,
551 I0,
553 k,
555 b_thread_buf);
556 });
557 });
558 });
559
560 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
561 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
562 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
563 constexpr index_t a_scale_offset =
564 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
565 constexpr index_t b_scale_offset =
566 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
567
568 static_assert(0 < ScalesPerXdlopsRunPerThread,
569 "Must have at least one scale per Xdlops "
570 "per Thread.");
571
574
575 // Pack scale_thread_buf into scale_thread_vec
577 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
578 a_scale_thread_buf[Number<a_scale_offset + s>{}];
579 });
580
582 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
583 b_scale_thread_buf[Number<b_scale_offset + s>{}];
584 });
585
586 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
587 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
588 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
589 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
590
593
594 static_for<0, KPack, 1>{}([&](auto ik) {
595 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
596 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
597 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
598 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
599 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
600 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
601 });
602
603 using mfma_input_type_a =
604 typename vector_type<ComputeTypeA,
605 xdlops_gemm.K1PerXdlops /
606 APackedSize>::type;
607
608 using mfma_input_type_b =
609 typename vector_type<ComputeTypeB,
610 xdlops_gemm.K1PerXdlops /
611 BPackedSize>::type;
612
613 using mfma_scale_input_type_a =
614 typename vector_type<AScaleDataType,
616 using mfma_scale_input_type_b =
617 typename vector_type<BScaleDataType,
619
620 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
621 make_tuple(m0, n0, imxdl, inxdl, 0));
622
623 // MFMA accumulation
624 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
625 ikxdl * NXdlPack + inxdl>(
626 a_thread_vec.template AsType<mfma_input_type_a>(),
627 a_scale_thread_vec
628 .template AsType<mfma_scale_input_type_a>(),
629 b_thread_vec.template AsType<mfma_input_type_b>(),
630 b_scale_thread_vec
631 .template AsType<mfma_scale_input_type_b>(),
632 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
633 });
634 });
635 });
636 });
637 });
638 });
639 }
640 }
641
642 // TODO: make this field protected when a_scale_thread_copy_ is moved
643 // here
646 Number<KRepeat / KXdlPack>{},
648
649 // TODO: make this field protected when b_scale_thread_copy_ is moved
650 // here
653 Number<KRepeat / KXdlPack>{},
655
656 protected:
657 using Base::a_thread_copy_;
658 using Base::a_thread_desc_;
659 using Base::b_thread_copy_;
660 using Base::b_thread_desc_;
661 using Base::c_thread_desc_;
662};
663
664} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp:214
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp:102
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10