blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
136 using Base::GetWaveIdx;
139
142
143 using Base::AMmaKStride;
144 using Base::APackedSize;
145 using Base::BMmaKStride;
146 using Base::BPackedSize;
147 using Base::KThreadChunk;
148
149 using Base::KXdlPack;
150 using Base::MXdlPack;
151 using Base::NXdlPack;
152
153 using AccType = typename Base::AccType;
154 using Tuple5 = typename Base::Tuple5;
157
// Pipeline depth constants. PrefetchStages is the value BlockHasHotloop() compares
// num_loop against.
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 1;
161
// Ratio of the K extent of one block tile to the K extent covered by a single scale.
162 static constexpr auto ScalesPerKBlockSize =
163 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164
165 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
166 static constexpr auto ScalesPerXdlopsRun =
167 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168
169 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
170 static constexpr auto ScalesPerXdlopsRunPerThread =
171 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172
// Number of mx_scale_t-sized elements held by one A / B scale data type
// (plain sizeof ratio).
174 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// The K*M (resp. K*N) XDL pack product must be a whole multiple of the scale pack size;
// otherwise compilation fails with the messages below.
176 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177 "A scale pack data type too large!");
178 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179 "B scale pack data type too large!");
182
183 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184 {
185 return num_loop > PrefetchStages;
186 }
187
188 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189 {
190 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191 }
192
// Emits compiler scheduling hints that interleave MFMA instructions with memory reads.
// All quantities are compile-time constants derived from HotLoopInstList and the tile
// parameters; the only visible effect is a sequence of
// __builtin_amdgcn_sched_group_barrier calls:
//   stage 1 - groups of MFMAs, each followed by one VMEM read
//   stage 2 - single MFMAs, each followed by a group of DS reads
// NOTE(review): presumably invoked once per hot-loop iteration from Run() -- the call
// site is not visible in this listing; confirm.
// NOTE(review): this listing skips several original lines (the embedded numbering jumps,
// e.g. 198 -> 201 and 242 -> 244), so the ternary arms of num_ds_read_inst_a/b and the
// loop headers that introduce `i` are not shown. Comments below cover visible code only.
193 __device__ static constexpr auto HotLoopScheduler()
194 {
195 // A/B split schedule
196 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
197 constexpr auto num_ds_read_inst_a =
198 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
201 constexpr auto num_ds_read_inst_b =
202 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
205
// Buffer-load counts. The B-side counts are doubled.
// NOTE(review): presumably the factor 2 covers the separate gate and up B operands that
// Run() receives -- confirm.
206 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
207 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
208
// One scale load per (repeat / XdlPack) pair along M (or N) and K; B side doubled again.
209 constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
210 constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
211
212 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2;
213
// Issue-cycle estimate per DS read: 8 when the read is 16 bytes wide, 4 otherwise.
214 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
215 constexpr auto ds_read_a_issue_cycle =
216 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
217 constexpr auto ds_read_b_issue_cycle =
218 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
219
// DS reads paired with one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
220 constexpr auto ds_read_a_mfma_rate =
221 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
222 constexpr auto ds_read_b_mfma_rate =
223 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
224
// MFMAs needed to cover all DS reads at that rate: ceil(num_ds_read_inst / rate).
225 constexpr auto num_dsread_a_mfma =
226 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
227 constexpr auto num_dsread_b_mfma =
228 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
229
230 // stage 1
// MFMAs not reserved for stage 2 are spread over all buffer loads (A, B, A scale, B scale).
231 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
232 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
233 num_buffer_load_a_scale + num_buffer_load_b_scale;
234
// Ceil and floor of MFMAs per buffer load.
235 constexpr auto mfma_perstage_more =
236 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
237 constexpr auto mfma_perstage_less =
238 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
239
// Remainder of the division: how many buffer-load slots receive the larger MFMA count.
240 constexpr auto mfma_stages_more =
241 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
242
// First group of buffer-load slots. The loop header and the inner static_for headers are
// missing from this listing; each slot ends with one VMEM-read barrier.
244 if constexpr(i < mfma_stages_more)
245 {
247 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
248 });
249 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
250 }
251 else
252 {
254 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
255 });
256 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
257 }
258 });
259
// Second group: the slot index is offset by the A buffer-load count before it is compared
// with mfma_stages_more.
261 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
262 {
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 });
266 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
267 }
268 else
269 {
271 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
272 });
273 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
274 }
275 });
276
// Third group: offset by the A and B buffer-load counts. Slots below mfma_stages_more
// issue mfma_perstage_more MFMA barriers, the rest mfma_perstage_less, then one VMEM read.
278 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
279 {
280 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
281 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
282 });
283 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
284 }
285 else
286 {
287 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
288 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
289 });
290 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
291 }
292 });
293
// Fourth group: offset by the A, B and A-scale buffer-load counts; same more/less split.
295 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
296 num_buffer_load_a_scale) < mfma_stages_more)
297 {
298 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
299 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
300 });
301 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
302 }
303 else
304 {
305 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
306 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
307 });
308 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
309 }
310 });
311
312 // stage 2
// A side: one MFMA, then a full group of ds_read_a_mfma_rate DS reads while at least one
// more full group remains; otherwise the leftover count
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate.
// The loop header that introduces `i` is missing from this listing.
314 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
315 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
316 ds_read_a_mfma_rate)
317 {
318 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
319 }
320 else
321 {
322 __builtin_amdgcn_sched_group_barrier(0x100,
323 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
324 ds_read_a_mfma_rate,
325 0); // DS read
326 }
327 });
328
// B side: same pattern as the A side, using the B read counts and rate.
330 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
331 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
332 ds_read_b_mfma_rate)
333 {
334 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
335 }
336 else
337 {
338 __builtin_amdgcn_sched_group_barrier(0x100,
339 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
340 ds_read_b_mfma_rate,
341 0); // DS read
342 }
343 });
344 }
345
346 template <bool HasMainLoop,
347 TailNumber TailNum,
348 typename AGridDesc,
349 typename ABlockDesc,
350 typename ABlockTransfer,
351 typename AGridBuffer,
352 typename ABlockBuffer,
353 typename ABlockTransferStep,
354 typename BGridDesc,
355 typename BBlockDesc,
356 typename BBlockTransfer,
357 typename BGridBuffer,
358 typename BBlockBuffer,
359 typename BBlockTransferStep,
360 typename CThreadBuffer,
361 typename AScaleGridBuffer,
362 typename AScaleGridDesc,
363 typename AScaleThreadTransfer,
364 typename BScaleGridBuffer,
365 typename BScaleGridDesc,
366 typename BScaleThreadTransfer>
367 __device__ void Run(
368 // A
369 const AGridDesc& a_grid_desc,
370 const ABlockDesc& a_block_desc,
371 ABlockTransfer& a_blockwise_copy,
372 const AGridBuffer& a_grid_buf,
373 ABlockBuffer& a_block_bufs,
374 const ABlockTransferStep& a_block_copy_step,
375 // Gate and Up
376 const BGridDesc& b_grid_desc,
377 const BBlockDesc& b_block_desc,
378 BBlockTransfer& b_blockwise_copy,
379 BBlockTransfer& b_blockwise_copy_up,
380 const BGridBuffer& b_grid_buf,
381 const BGridBuffer& b_grid_buf_up,
382 BBlockBuffer& b_block_bufs,
383 BBlockBuffer& b_block_bufs_up,
384 const BBlockTransferStep& b_block_copy_step,
385 // C
386 CThreadBuffer& c_thread_buf,
387 CThreadBuffer& c_thread_buf_up,
388 // A scale
389 const AScaleGridDesc& a_scale_grid_desc,
390 AScaleThreadTransfer& a_scale_thread_copy,
391 const AScaleGridBuffer& a_scale_grid_buf,
392 // Gate and Up scale
393 const BScaleGridDesc& b_scale_grid_desc,
394 BScaleThreadTransfer& b_scale_thread_copy,
395 BScaleThreadTransfer& b_scale_thread_copy_up,
396 const BScaleGridBuffer& b_scale_grid_buf,
397 const BScaleGridBuffer& b_scale_grid_buf_up,
398 index_t num_loop) const
399 {
401 a_thread_desc_.GetElementSpaceSize());
403 b_thread_desc_.GetElementSpaceSize());
405 b_thread_desc_.GetElementSpaceSize());
406
408 a_scale_thread_desc.GetElementSpaceSize());
410 b_scale_thread_desc.GetElementSpaceSize());
412 b_scale_thread_desc.GetElementSpaceSize());
413
414 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
415 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
416 StaticallyIndexedArray<decltype(b_scale_thread_buf_up), Number<2>{}> b_scale_thread_bufs_up;
417
418 // Global prefetch 1
419 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
420 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
421 b_blockwise_copy_up.Run(b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(I0));
422
423 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
424 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
425 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
426
427 // Prefetch a_scales
428 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
429 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
430 a_scale_thread_copy.Run(a_scale_grid_desc,
431 a_scale_grid_buf,
433 make_tuple(m0, k0, I0),
434 a_scale_thread_bufs(I0));
435
436 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
437 make_multi_index(0, I1, 0));
438 });
439 a_scale_thread_copy.MoveSrcSliceWindow(
440 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
441 });
442
443 // restore row id and advance to the next set of scales
444 a_scale_thread_copy.MoveSrcSliceWindow(
445 a_scale_grid_desc,
446 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
447
448 // Prefetch b_scales_gate
449 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
450 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
451 b_scale_thread_copy.Run(b_scale_grid_desc,
452 b_scale_grid_buf,
454 make_tuple(n0, k0, I0),
455 b_scale_thread_bufs(I0));
456
457 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
458 make_multi_index(0, I1, 0));
459 });
460 b_scale_thread_copy.MoveSrcSliceWindow(
461 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
462 });
463
464 // restore col id and advance to the next set of scales
465 // NWaves * NPerXDL * NRepeat == NPerBlock
466 b_scale_thread_copy.MoveSrcSliceWindow(
467 b_scale_grid_desc,
468 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
469
470 // Prefetch b_scales_up
471 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
472 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
473 b_scale_thread_copy_up.Run(b_scale_grid_desc,
474 b_scale_grid_buf_up,
476 make_tuple(n0, k0, I0),
477 b_scale_thread_bufs_up(I0));
478
479 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
480 make_multi_index(0, I1, 0));
481 });
482 b_scale_thread_copy_up.MoveSrcSliceWindow(
483 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
484 });
485
486 // restore col id and advance to the next set of scales
487 // NWaves * NPerXDL * NRepeat == NPerBlock
488 b_scale_thread_copy_up.MoveSrcSliceWindow(
489 b_scale_grid_desc,
490 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
491
492 // Local prefetch 1, sync the async load
493 __builtin_amdgcn_s_waitcnt(3952);
494
495 // Local prefetch 1
497 static_for<0, KRepeat, 1>{}([&](auto k) {
498 constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
499 static_for<0, MRepeat, 1>{}([&](auto m0) {
500 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
501 [&](auto chunk) {
502 constexpr auto a_k_step_chunk =
503 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
506 I0,
508 I0,
510 a_block_bufs(I0),
513 I0,
515 k,
517 a_thread_buf);
518 });
519 });
520 static_for<0, NRepeat, 1>{}([&](auto n0) {
521 // read block data in chunks to assemble correct thread vectors
522 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
523 [&](auto chunk) {
524 constexpr auto b_k_step_chunk =
525 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
528 I0,
530 I0,
532 b_block_bufs(I0),
535 I0,
537 k,
539 b_thread_buf);
540 });
541 });
542 static_for<0, NRepeat, 1>{}([&](auto n0) {
543 // read block data in chunks to assemble correct thread vectors
544 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
545 [&](auto chunk) {
546 constexpr auto b_k_step_chunk =
547 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
550 I0,
552 I0,
554 b_block_bufs_up(I0),
557 I0,
559 k,
561 b_thread_buf_up);
562 });
563 });
564 });
565
566 // Global prefetch 2
567 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
568 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
569 b_blockwise_copy_up.Run(b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(I1));
570
571 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
572 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
573 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
574
575 // Initialize C
576 c_thread_buf.Clear();
577 c_thread_buf_up.Clear();
578 __builtin_amdgcn_sched_barrier(0);
579
580 // main body
581 if constexpr(HasMainLoop)
582 {
583 // loop over k with the step KPerBlock
584 index_t i = 0;
585 do
586 {
587 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
588 __builtin_amdgcn_s_waitcnt(3952);
590
591 a_blockwise_copy.Run(
592 a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
593 b_blockwise_copy.Run(
594 b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
595 b_blockwise_copy_up.Run(
596 b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(scale_comp_buf));
597
598 // Prefetch a_scales
599 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
600 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
601 a_scale_thread_copy.Run(a_scale_grid_desc,
602 a_scale_grid_buf,
604 make_tuple(m0, k0, I0),
605 a_scale_thread_bufs(scale_mem_buf));
606
607 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
608 make_multi_index(0, I1, 0));
609 });
610 a_scale_thread_copy.MoveSrcSliceWindow(
611 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
612 });
613
614 // restore row id and advance to the next set of scales
615 a_scale_thread_copy.MoveSrcSliceWindow(
616 a_scale_grid_desc,
617 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
618
619 // Prefetch b_scales
620 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
621 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
622 b_scale_thread_copy.Run(b_scale_grid_desc,
623 b_scale_grid_buf,
625 make_tuple(n0, k0, I0),
626 b_scale_thread_bufs(scale_mem_buf));
627
628 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
629 make_multi_index(0, I1, 0));
630 });
631 b_scale_thread_copy.MoveSrcSliceWindow(
632 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
633 });
634
635 // restore col id and advance to the next set of scales
636 // NWaves * NPerXDL * NRepeat == NPerBlock
637 b_scale_thread_copy.MoveSrcSliceWindow(
638 b_scale_grid_desc,
639 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
640
641 // Prefetch b_scales_up
642 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
643 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
644 b_scale_thread_copy_up.Run(b_scale_grid_desc,
645 b_scale_grid_buf_up,
647 make_tuple(n0, k0, I0),
648 b_scale_thread_bufs_up(scale_mem_buf));
649
650 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
651 make_multi_index(0, I1, 0));
652 });
653 b_scale_thread_copy_up.MoveSrcSliceWindow(
654 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
655 });
656
657 // restore col id and advance to the next set of scales
658 // NWaves * NPerXDL * NRepeat == NPerBlock
659 b_scale_thread_copy_up.MoveSrcSliceWindow(
660 b_scale_grid_desc,
661 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
662
663 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
664 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
665 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
666
667 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
668 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
669 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
670 constexpr index_t a_scale_offset =
671 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
672 constexpr index_t b_scale_offset =
673 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
674
675 static_assert(0 < ScalesPerXdlopsRunPerThread,
676 "Must have at least one scale per Xdlops "
677 "per Thread.");
678
680 a_scale_thread_vec;
682 b_scale_thread_vec;
684 b_scale_thread_vec_up;
685
686 // Pack scale_thread_buf into scale_thread_vec
688 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
689 a_scale_thread_bufs(
690 scale_comp_buf)[Number<a_scale_offset + s>{}];
691 });
692
694 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
695 b_scale_thread_bufs(
696 scale_comp_buf)[Number<b_scale_offset + s>{}];
697 });
698
700 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
701 b_scale_thread_bufs_up(
702 scale_comp_buf)[Number<b_scale_offset + s>{}];
703 });
704
705 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
706 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
707 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
708 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
709
712 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
713
714 static_for<0, KPack, 1>{}([&](auto ik) {
715 a_thread_vec.template AsType<ComputeTypeA>()(
716 ik) = a_thread_buf
717 [Number<a_thread_desc_.CalculateOffset(
718 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
719 b_thread_vec.template AsType<ComputeTypeB>()(
720 ik) = b_thread_buf
721 [Number<b_thread_desc_.CalculateOffset(
722 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
723 b_thread_vec_up.template AsType<ComputeTypeB>()(
724 ik) = b_thread_buf_up
725 [Number<b_thread_desc_.CalculateOffset(
726 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
727 });
728
729 using mfma_input_type_a =
730 typename vector_type<ComputeTypeA,
731 xdlops_gemm.K1PerXdlops /
732 APackedSize>::type;
733
734 using mfma_input_type_b =
735 typename vector_type<ComputeTypeB,
736 xdlops_gemm.K1PerXdlops /
737 BPackedSize>::type;
738
739 using mfma_scale_input_type_a =
740 typename vector_type<AScaleDataType,
742 using mfma_scale_input_type_b =
743 typename vector_type<BScaleDataType,
745
746 constexpr index_t c_offset =
747 c_thread_desc_.CalculateOffset(
748 make_tuple(m0, n0, imxdl, inxdl, 0));
749
750 // MFMA accumulation
751 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
752 ikxdl * NXdlPack + inxdl>(
753 a_thread_vec.template AsType<mfma_input_type_a>(),
754 a_scale_thread_vec
755 .template AsType<mfma_scale_input_type_a>(),
756 b_thread_vec.template AsType<mfma_input_type_b>(),
757 b_scale_thread_vec
758 .template AsType<mfma_scale_input_type_b>(),
759 c_thread_buf.GetVectorTypeReference(
761
762 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
763 ikxdl * NXdlPack + inxdl>(
764 a_thread_vec.template AsType<mfma_input_type_a>(),
765 a_scale_thread_vec
766 .template AsType<mfma_scale_input_type_a>(),
767 b_thread_vec_up
768 .template AsType<mfma_input_type_b>(),
769 b_scale_thread_vec_up
770 .template AsType<mfma_scale_input_type_b>(),
771 c_thread_buf_up.GetVectorTypeReference(
773 });
774 });
775 });
776 });
777 });
778 });
779
780 // k indexes mapping to threads for 32x32x64:
781 // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
782 // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
783 // k = 0 k = 1
784
785 // k indexes mapping to threads for 16x16x128:
786 // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
787 // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
788 // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
789 // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
790 // k = 0 k = 1
791 // block_sync_lds();
792 static_for<0, KRepeat, 1>{}([&](auto k) {
793 constexpr auto k_step =
794 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
795 static_for<0, MRepeat, 1>{}([&](auto m0) {
796 static_for<0,
797 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
798 1>{}([&](auto chunk) {
799 constexpr auto a_k_step_chunk =
800 k_step +
801 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
804 I0,
806 I0,
808 a_block_bufs(scale_mem_buf),
811 I0,
813 k,
815 a_thread_buf);
816 });
817 });
818 static_for<0, NRepeat, 1>{}([&](auto n0) {
819 // read block data in chunks to assemble correct thread vectors
820 static_for<0,
821 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
822 1>{}([&](auto chunk) {
823 constexpr auto b_k_step_chunk =
824 k_step +
825 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
828 I0,
830 I0,
832 b_block_bufs(scale_mem_buf),
835 I0,
837 k,
839 b_thread_buf);
840 });
841 });
842 static_for<0, NRepeat, 1>{}([&](auto n0) {
843 // read block data in chunks to assemble correct thread vectors
844 static_for<0,
845 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
846 1>{}([&](auto chunk) {
847 constexpr auto b_k_step_chunk =
848 k_step +
849 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
852 I0,
854 I0,
856 b_block_bufs_up(scale_mem_buf),
859 I0,
861 k,
863 b_thread_buf_up);
864 });
865 });
866 });
867
869 __builtin_amdgcn_sched_barrier(0);
870 };
871
872 LoopFunc(I0, I1);
873 LoopFunc(I1, I0);
874
875 i += 2;
876 } while(i < (num_loop - 2));
877 }
878
879 // tail
880 if constexpr(TailNum == TailNumber::Even)
881 {
882 // Prefetch a_scales
883 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
884 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
885 a_scale_thread_copy.Run(a_scale_grid_desc,
886 a_scale_grid_buf,
888 make_tuple(m0, k0, I0),
889 a_scale_thread_bufs(I1));
890
891 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
892 make_multi_index(0, I1, 0));
893 });
894 a_scale_thread_copy.MoveSrcSliceWindow(
895 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
896 });
897
898 // Prefetch b_scales
899 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
900 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
901 b_scale_thread_copy.Run(b_scale_grid_desc,
902 b_scale_grid_buf,
904 make_tuple(n0, k0, I0),
905 b_scale_thread_bufs(I1));
906
907 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
908 make_multi_index(0, I1, 0));
909 });
910 b_scale_thread_copy.MoveSrcSliceWindow(
911 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
912 });
913
914 // Prefetch b_scales_up
915 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
916 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
917 b_scale_thread_copy_up.Run(b_scale_grid_desc,
918 b_scale_grid_buf_up,
920 make_tuple(n0, k0, I0),
921 b_scale_thread_bufs_up(I1));
922
923 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
924 make_multi_index(0, I1, 0));
925 });
926 b_scale_thread_copy_up.MoveSrcSliceWindow(
927 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
928 });
929
930 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
931 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
932 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
933 constexpr index_t a_scale_offset =
934 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
935 constexpr index_t b_scale_offset =
936 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
937
938 static_assert(0 < ScalesPerXdlopsRunPerThread,
939 "Must have at least one scale per Xdlops "
940 "per Thread.");
941
945
946 // Pack scale_thread_buf into scale_thread_vec
948 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
949 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
950 });
951
953 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
954 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
955 });
956
958 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
959 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
960 });
961
962 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
963 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
964 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
965 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
966
969 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
970
971 static_for<0, KPack, 1>{}([&](auto ik) {
972 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
973 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
974 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
975 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
976 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
977 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
978 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
979 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
980 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
981 });
982
983 using mfma_input_type_a =
984 typename vector_type<ComputeTypeA,
985 xdlops_gemm.K1PerXdlops /
986 APackedSize>::type;
987
988 using mfma_input_type_b =
989 typename vector_type<ComputeTypeB,
990 xdlops_gemm.K1PerXdlops /
991 BPackedSize>::type;
992
993 using mfma_scale_input_type_a =
994 typename vector_type<AScaleDataType,
996 using mfma_scale_input_type_b =
997 typename vector_type<BScaleDataType,
999
1000 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1001 make_tuple(m0, n0, imxdl, inxdl, 0));
1002
1003 // MFMA accumulation
1004 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1005 ikxdl * NXdlPack + inxdl>(
1006 a_thread_vec.template AsType<mfma_input_type_a>(),
1007 a_scale_thread_vec
1008 .template AsType<mfma_scale_input_type_a>(),
1009 b_thread_vec.template AsType<mfma_input_type_b>(),
1010 b_scale_thread_vec
1011 .template AsType<mfma_scale_input_type_b>(),
1012 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1013
1014 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1015 ikxdl * NXdlPack + inxdl>(
1016 a_thread_vec.template AsType<mfma_input_type_a>(),
1017 a_scale_thread_vec
1018 .template AsType<mfma_scale_input_type_a>(),
1019 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1020 b_scale_thread_vec_up
1021 .template AsType<mfma_scale_input_type_b>(),
1022 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1023 });
1024 });
1025 });
1026 });
1027 });
1028 });
1029
1030 __builtin_amdgcn_s_waitcnt(3952);
1032
1033 static_for<0, KRepeat, 1>{}([&](auto k) {
1034 constexpr auto k_step =
1035 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
1036 static_for<0, MRepeat, 1>{}([&](auto m0) {
1037 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1038 [&](auto chunk) {
1039 constexpr auto a_k_step_chunk =
1040 k_step +
1041 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1044 I0,
1046 I0,
1048 a_block_bufs(I1),
1051 I0,
1053 k,
1055 a_thread_buf);
1056 });
1057 });
1058 static_for<0, NRepeat, 1>{}([&](auto n0) {
1059 // read block data in chunks to assemble correct thread vectors
1060 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
1061 [&](auto chunk) {
1062 constexpr auto b_k_step_chunk =
1063 k_step +
1064 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1067 I0,
1069 I0,
1071 b_block_bufs(I1),
1074 I0,
1076 k,
1078 b_thread_buf);
1079 });
1080 });
1081 static_for<0, NRepeat, 1>{}([&](auto n0) {
1082 // read block data in chunks to assemble correct thread vectors
1083 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
1084 [&](auto chunk) {
1085 constexpr auto b_k_step_chunk =
1086 k_step +
1087 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1090 I0,
1092 I0,
1094 b_block_bufs_up(I1),
1097 I0,
1099 k,
1101 b_thread_buf_up);
1102 });
1103 });
1104 });
1105
1106 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1107 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1108 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
1109 constexpr index_t a_scale_offset =
1110 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1111 constexpr index_t b_scale_offset =
1112 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1113
1114 static_assert(0 < ScalesPerXdlopsRunPerThread,
1115 "Must have at least one scale per Xdlops "
1116 "per Thread.");
1117
1121
1122 // Pack scale_thread_buf into scale_thread_vec
1124 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1125 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
1126 });
1127
1129 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1130 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
1131 });
1132
1134 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
1135 b_scale_thread_bufs_up(I1)[Number<b_scale_offset + s>{}];
1136 });
1137
1138 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1139 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1140 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1141 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1142
1145 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1146
1147 static_for<0, KPack, 1>{}([&](auto ik) {
1148 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1149 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1150 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1151 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1152 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1153 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1154 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1155 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
1156 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1157 });
1158
1159 using mfma_input_type_a =
1160 typename vector_type<ComputeTypeA,
1161 xdlops_gemm.K1PerXdlops /
1162 APackedSize>::type;
1163
1164 using mfma_input_type_b =
1165 typename vector_type<ComputeTypeB,
1166 xdlops_gemm.K1PerXdlops /
1167 BPackedSize>::type;
1168
1169 using mfma_scale_input_type_a =
1170 typename vector_type<AScaleDataType,
1172 using mfma_scale_input_type_b =
1173 typename vector_type<BScaleDataType,
1175
1176 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1177 make_tuple(m0, n0, imxdl, inxdl, 0));
1178
1179 // MFMA accumulation
1180 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1181 ikxdl * NXdlPack + inxdl>(
1182 a_thread_vec.template AsType<mfma_input_type_a>(),
1183 a_scale_thread_vec
1184 .template AsType<mfma_scale_input_type_a>(),
1185 b_thread_vec.template AsType<mfma_input_type_b>(),
1186 b_scale_thread_vec
1187 .template AsType<mfma_scale_input_type_b>(),
1188 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1189
1190 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1191 ikxdl * NXdlPack + inxdl>(
1192 a_thread_vec.template AsType<mfma_input_type_a>(),
1193 a_scale_thread_vec
1194 .template AsType<mfma_scale_input_type_a>(),
1195 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1196 b_scale_thread_vec_up
1197 .template AsType<mfma_scale_input_type_b>(),
1198 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1199 });
1200 });
1201 });
1202 });
1203 });
1204 });
1205 }
1206 else if constexpr(TailNum == TailNumber::Odd)
1207 {
1208 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1209 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1210 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
1211 constexpr index_t a_scale_offset =
1212 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1213 constexpr index_t b_scale_offset =
1214 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1215
1216 static_assert(0 < ScalesPerXdlopsRunPerThread,
1217 "Must have at least one scale per Xdlops "
1218 "per Thread.");
1219
1223
1224 // Pack scale_thread_buf into scale_thread_vec
1226 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1227 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1228 });
1229
1231 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1232 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1233 });
1234
1236 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
1237 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
1238 });
1239
1240 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1241 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1242 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1243 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1244
1247 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1248
1249 static_for<0, KPack, 1>{}([&](auto ik) {
1250 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1251 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1252 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1253 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1254 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1255 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1256 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1257 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
1258 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1259 });
1260
1261 using mfma_input_type_a =
1262 typename vector_type<ComputeTypeA,
1263 xdlops_gemm.K1PerXdlops /
1264 APackedSize>::type;
1265
1266 using mfma_input_type_b =
1267 typename vector_type<ComputeTypeB,
1268 xdlops_gemm.K1PerXdlops /
1269 BPackedSize>::type;
1270
1271 using mfma_scale_input_type_a =
1272 typename vector_type<AScaleDataType,
1274 using mfma_scale_input_type_b =
1275 typename vector_type<BScaleDataType,
1277
1278 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1279 make_tuple(m0, n0, imxdl, inxdl, 0));
1280
1281 // MFMA accumulation
1282 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1283 ikxdl * NXdlPack + inxdl>(
1284 a_thread_vec.template AsType<mfma_input_type_a>(),
1285 a_scale_thread_vec
1286 .template AsType<mfma_scale_input_type_a>(),
1287 b_thread_vec.template AsType<mfma_input_type_b>(),
1288 b_scale_thread_vec
1289 .template AsType<mfma_scale_input_type_b>(),
1290 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1291
1292 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1293 ikxdl * NXdlPack + inxdl>(
1294 a_thread_vec.template AsType<mfma_input_type_a>(),
1295 a_scale_thread_vec
1296 .template AsType<mfma_scale_input_type_a>(),
1297 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1298 b_scale_thread_vec_up
1299 .template AsType<mfma_scale_input_type_b>(),
1300 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1301 });
1302 });
1303 });
1304 });
1305 });
1306 });
1307 }
1308 }
1309
1310 // TODO: make this field protected when a_scale_thread_copy_ is moved
1311 // here
1314 Number<KRepeat / KXdlPack>{},
1316
1317 // TODO: make this field protected when b_scale_thread_copy_ is moved
1318 // here
1321 Number<KRepeat / KXdlPack>{},
1323
1324 protected:
 // Bring the inherited members used by this pipeline into scope. Base is a
 // class template specialization that depends on this class's template
 // parameters, so its members are not found by unqualified lookup; the
 // using-declarations let Run refer to a_thread_desc_, b_thread_desc_ and
 // c_thread_desc_ without a `Base::` / `this->` qualifier.
1325 using Base::a_thread_copy_;
1326 using Base::a_thread_desc_;
1327 using Base::b_thread_copy_;
1328 using Base::b_thread_desc_;
1329 using Base::c_thread_desc_;
1330};
1331
1332} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_bufs, BBlockBuffer &b_block_bufs_up, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:367
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10