device_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File

device_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File#

Composable Kernel: device_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File
device_gemm_multiple_d_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename DsDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t BlockSize,
39 index_t MPerBlock,
40 index_t NPerBlock,
41 index_t KPerBlock,
42 index_t AK1,
43 index_t BK1,
44 index_t MPerXDL,
45 index_t NPerXDL,
46 index_t MXdlPerWave,
47 index_t NXdlPerWave,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
54 bool ABlockLdsExtraM,
55 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56 typename BBlockTransferThreadClusterArrangeOrder,
57 typename BBlockTransferSrcAccessOrder,
58 index_t BBlockTransferSrcVectorDim,
59 index_t BBlockTransferSrcScalarPerVector,
60 index_t BBlockTransferDstScalarPerVector_BK1,
61 bool BBlockLdsExtraN,
62 index_t CShuffleMXdlPerWavePerShuffle,
63 index_t CShuffleNXdlPerWavePerShuffle,
64 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65 typename CDEShuffleBlockTransferScalarPerVectors,
68 typename ComputeTypeA = CDataType,
69 typename ComputeTypeB = ComputeTypeA,
70 typename LDSTypeA = ComputeTypeA,
71 typename LDSTypeB = ComputeTypeB>
73 BLayout,
74 DsLayout,
75 CLayout,
76 ADataType,
77 BDataType,
78 DsDataType,
79 CDataType,
80 AElementwiseOperation,
81 BElementwiseOperation,
82 CElementwiseOperation>
83{
84 static constexpr index_t NumDTensor = DsDataType::Size();
86 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
87 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
88
89 // GridwiseGemm
90 template <index_t NXdlPerWave_>
92 ALayout,
93 BLayout,
94 DsLayout,
95 CLayout,
96 ADataType,
97 BDataType,
98 GemmAccDataType,
99 CShuffleDataType,
100 DsDataType,
101 CDataType,
102 AElementwiseOperation,
103 BElementwiseOperation,
104 CElementwiseOperation,
105 GemmSpec,
106 BlockSize,
107 MPerBlock,
108 NPerBlock,
109 KPerBlock,
110 AK1,
111 BK1,
112 MPerXDL,
113 NPerXDL,
114 MXdlPerWave,
115 NXdlPerWave_,
116 ABlockTransferThreadClusterLengths_AK0_M_AK1,
117 ABlockTransferThreadClusterArrangeOrder,
118 ABlockTransferSrcAccessOrder,
119 ABlockTransferSrcVectorDim,
120 ABlockTransferSrcScalarPerVector,
121 ABlockTransferDstScalarPerVector_AK1,
122 false,
123 ABlockLdsExtraM,
124 BBlockTransferThreadClusterLengths_BK0_N_BK1,
125 BBlockTransferThreadClusterArrangeOrder,
126 BBlockTransferSrcAccessOrder,
127 BBlockTransferSrcVectorDim,
128 BBlockTransferSrcScalarPerVector,
129 BBlockTransferDstScalarPerVector_BK1,
130 false,
131 BBlockLdsExtraN,
132 CShuffleMXdlPerWavePerShuffle,
133 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
134 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
135 CDEShuffleBlockTransferScalarPerVectors,
136 BlkGemmPipeSched,
137 BlkGemmPipelineVer,
138 ComputeTypeA,
139 ComputeTypeB,
140 LDSTypeA,
141 LDSTypeB>;
144
145 using Argument = typename GridwiseGemm64::Argument;
146 // Invoker
147 struct Invoker : public BaseInvoker
148 {
// Validates `arg`, derives the launch grid and the per-split K extent, and
// then launches the kernel instantiation selected by the pipeline version,
// by whether KBatch > 1, and by the K-loop tail number.
//
// GridwiseGemm   : gridwise implementation whose Argument type, validity
//                  check and descriptor helpers are used below.
// arg            : problem description; printed when log_level_ > 0.
// stream_config  : supplies log level, stream, flush_cache and
//                  rotating_count.
// Returns ave_time, which starts at 0 and is assigned from
// launch_and_time_kernel in the non-flush path.
// Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) is false.
149 template <typename GridwiseGemm>
150 float RunImp(const typename GridwiseGemm::Argument& arg,
151 const StreamConfig& stream_config = StreamConfig{})
152 {
153 if(stream_config.log_level_ > 0)
154 {
155 arg.Print();
156 }
157
158 if(!GridwiseGemm::CheckValidity(arg))
159 {
160 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
161 }
162
// Grid dimensions are computed from M, N and the split-K factor.
163 index_t gdx, gdy, gdz;
164 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
165
166 float ave_time = 0;
167
// K_split = ceil(K / (KBatch * KPerBlock)) * KPerBlock; it feeds the
// main-loop and tail-number queries below.
168 index_t k_grain = arg.KBatch * KPerBlock;
169 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
170
171 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
172
// Launch helper shared by every dispatch branch. With flush_cache set it
// works on a copy of `arg`, sizes the A/B/D buffers from their grid
// descriptors and hands them to a RotatingMemWrapperMultiD; otherwise it
// launches directly through launch_and_time_kernel.
173 const auto Run = [&](const auto& kernel) {
174 if(stream_config.flush_cache)
175 {
176
177 std::array<std::size_t, NumDTensor> DsSize;
178
179 auto arg_ = arg;
180
181 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
182 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
183 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
184 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
185
// Buffer sizes in bytes: element space of the descriptor times element size.
186 auto size_a_buffer =
187 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
188 auto size_b_buffer =
189 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
190
191 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
192 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
193
194 static_for<0, NumDTensor, 1>{}([&](auto i) {
195 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
196 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
197 });
198 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
199 DsDataType>
200 rotating_mem(arg_,
201 stream_config.rotating_count,
202 size_a_buffer,
203 size_b_buffer,
204 DsSize);
205 rotating_mem.Print();
206
// Pre-launch step passed to the launcher below: advances the rotating
// buffers and, for split-K, zeroes the C buffer.
207 auto run_flush_cache = [&]() {
208 // flush icache
210 // rotating mem
211 rotating_mem.Next();
212 // clear c mem
// NOTE(review): the string returned by hipGetErrorString is discarded, so
// a failing hipMemsetAsync is not reported here.
// NOTE(review): the byte count does not involve StrideC (presumably C is
// densely packed), and if M and N are 32-bit index_t the product M * N is
// formed before widening to size_t — confirm for large problems.
213 if(arg_.KBatch > 1)
214 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
215 0,
216 arg_.M * arg_.N * sizeof(CDataType),
217 stream_config.stream_id_));
218 };
219
221 stream_config,
222 run_flush_cache,
223 kernel,
224 dim3(gdx, gdy, gdz),
225 dim3(BlockSize),
226 0,
227 arg_);
228 }
229 else
230 {
// Split-K: C is zeroed before the launch. The same NOTE(review) items as
// in the flush-cache path apply (discarded error string, M * N byte count).
231 if(arg.KBatch > 1)
232 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
233 0,
234 arg.M * arg.N * sizeof(CDataType),
235 stream_config.stream_id_));
236
237 ave_time = launch_and_time_kernel(
238 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
239 }
240 };
241
// Compile-time occupancy value passed to the kernel template: 2 for the
// Interwave scheduler; for pipeline v3, 2 when
// MPerBlock * NPerBlock / BlockSize <= 128 and 1 otherwise; 1 in all
// remaining cases.
242 constexpr index_t minimum_occupancy = []() {
243 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
244 {
245 return 2;
246 }
247 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
248 {
249 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
250 }
251 else
252 {
253 return 1;
254 }
255 }();
256
// Dispatch. Each pipeline version has a KBatch > 1 branch and a
// KBatch == 1 branch that instantiate the kernel separately (presumably
// differing in the C memory operation, AtomicAdd vs Set — TODO confirm).
257 if(has_main_k_block_loop)
258 {
259 // Tail number always full
260 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
261 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
262 {
263 if(arg.KBatch > 1)
264 {
265 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
266 GridwiseGemm,
267 true,
269 minimum_occupancy>;
270 Run(kernel);
271 }
272 else
273 {
274 const auto kernel =
276 true,
278 minimum_occupancy>;
279 Run(kernel);
280 }
281 }
282 // Tail number could be One to Seven
// Tail numbers above the pipeline's PrefetchStages are compiled out by
// the `if constexpr` guards; when no comparison matches, no kernel is
// launched and ave_time keeps its initial 0.
283 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
284 {
285 if(arg.KBatch > 1)
286 {
287 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
288 {
289 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
290 GridwiseGemm,
291 true,
293 minimum_occupancy,
295 Run(kernel);
296 }
297 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
299 {
300 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
301 GridwiseGemm,
302 true,
304 minimum_occupancy,
306 Run(kernel);
307 }
308
309 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
310 {
311 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
312 {
313 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
314 GridwiseGemm,
315 true,
317 minimum_occupancy,
319 Run(kernel);
320 }
321 }
322
323 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
324 {
325 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
327 {
328 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
329 GridwiseGemm,
330 true,
332 minimum_occupancy,
334 Run(kernel);
335 }
336 }
337
338 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
339 {
340 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
342 {
343 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
344 GridwiseGemm,
345 true,
347 minimum_occupancy,
349 Run(kernel);
350 }
351 }
352
353 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
354 {
355 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
357 {
358 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
359 GridwiseGemm,
360 true,
362 minimum_occupancy,
364 Run(kernel);
365 }
366 }
367
368 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
369 {
370 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
371 {
372 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
373 GridwiseGemm,
374 true,
376 minimum_occupancy,
378 Run(kernel);
379 }
380 }
381
382 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
383 {
384 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
386 {
387 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
388 GridwiseGemm,
389 true,
391 minimum_occupancy,
393 Run(kernel);
394 }
395 }
396 }
397 else
398 {
399 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
400 {
401 const auto kernel =
403 true,
405 minimum_occupancy,
407 Run(kernel);
408 }
409 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
411 {
412 const auto kernel =
414 true,
416 minimum_occupancy,
418 Run(kernel);
419 }
420
421 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
422 {
423 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
424 {
425 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
426 GridwiseGemm,
427 true,
429 minimum_occupancy,
431 Run(kernel);
432 }
433 }
434
435 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
436 {
437 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
439 {
440 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
441 GridwiseGemm,
442 true,
444 minimum_occupancy,
446 Run(kernel);
447 }
448 }
449
450 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
451 {
452 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
454 {
455 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
456 GridwiseGemm,
457 true,
459 minimum_occupancy,
461 Run(kernel);
462 }
463 }
464
465 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
466 {
467 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
469 {
470 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
471 GridwiseGemm,
472 true,
474 minimum_occupancy,
476 Run(kernel);
477 }
478 }
479
480 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
481 {
482 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
483 {
484 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
485 GridwiseGemm,
486 true,
488 minimum_occupancy,
490 Run(kernel);
491 }
492 }
493
494 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
495 {
496 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
498 {
499 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
500 GridwiseGemm,
501 true,
503 minimum_occupancy,
505 Run(kernel);
506 }
507 }
508 }
509 }
510 // Tail number could be Odd or Even
511 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
512 {
513 if(arg.KBatch > 1)
514 {
515 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
516 {
518 GridwiseGemm,
519 true,
521 minimum_occupancy,
523 Run(kernel);
524 }
525 else
526 {
528 GridwiseGemm,
529 true,
531 minimum_occupancy,
533 Run(kernel);
534 }
535 }
536 else
537 {
538 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
539 {
541 GridwiseGemm,
542 true,
544 minimum_occupancy,
546 Run(kernel);
547 }
548 else
549 {
551 GridwiseGemm,
552 true,
554 minimum_occupancy,
556 Run(kernel);
557 }
558 }
559 }
// Remaining pipeline versions: Odd tail versus everything else.
560 else
561 {
562 if(arg.KBatch > 1)
563 {
564 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
565 {
566 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
567 GridwiseGemm,
568 true,
570 minimum_occupancy,
572 Run(kernel);
573 }
574 else
575 {
576 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
577 GridwiseGemm,
578 true,
580 minimum_occupancy,
582 Run(kernel);
583 }
584 }
585 else
586 {
587 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
588 {
589 const auto kernel =
591 true,
593 minimum_occupancy,
595 Run(kernel);
596 }
597 else
598 {
599 const auto kernel =
601 true,
603 minimum_occupancy,
605 Run(kernel);
606 }
607 }
608 }
609 }
// NOTE(review): without a main K loop only pipeline v1 launches a kernel;
// for every other version this branch does nothing and the function
// returns 0. Presumably CheckValidity rejects those cases earlier —
// verify.
610 else
611 {
612 // Tail number always 1
613 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
614 {
615 if(arg.KBatch > 1)
616 {
617 const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
618 GridwiseGemm,
619 false,
621 minimum_occupancy>;
622 Run(kernel);
623 }
624 else
625 {
626 const auto kernel =
628 false,
630 minimum_occupancy>;
631 Run(kernel);
632 }
633 }
634 }
635
636 return ave_time;
637 }
638
640
641 // polymorphic
// Type-erased entry point: downcasts `p_arg` to the concrete Argument and
// forwards it, together with `stream_config`, to an overload of Run that
// takes the concrete Argument (not visible in this block).
// NOTE(review): the dynamic_cast result is dereferenced without a null
// check, so passing a BaseArgument of a different dynamic type (or nullptr)
// is undefined behavior.
642 float Run(const BaseArgument* p_arg,
643 const StreamConfig& stream_config = StreamConfig{}) override
644 {
645 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
646 }
647 };
648
// Compile-time query for whether this template instantiation is valid.
// Currently a placeholder: it returns true unconditionally, so it filters
// nothing.
649 static constexpr bool IsValidCompilationParameter()
650 {
651 // TODO: properly implement this check
652 return true;
653 }
654
// Host-side check of whether this instance can run `arg`. Each visible
// test returns false on failure; the function's final statement also
// returns false, so a true result can only come from the per-wave-size
// branches near the end.
655 static bool IsSupportedArgument(const Argument& arg)
656 {
658 {
659 return false;
660 }
// Split-K (KBatch > 1) is rejected when is_gfx11_supported() is true.
661 if(is_gfx11_supported() && arg.KBatch > 1)
662 {
663 return false;
664 }
// Split-K with a bhalf_t C type is rejected when
// is_bf16_atomic_supported() is false.
665 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
666 {
667 return false;
668 }
669
// K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the
// K-padding specializations listed here.
670 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
671 GemmSpec == GemmSpecialization::NKPadding ||
672 GemmSpec == GemmSpecialization::MNKPadding ||
673 GemmSpec == GemmSpecialization::KPadding))
674 {
675 return false;
676 }
677
678 if(get_warp_size() == 64)
679 {
680 if constexpr(NXdlPerWave64 > 0)
681 {
683 }
684 }
// NOTE(review): the `else` below binds to this scalar-per-vector test,
// not to the `get_warp_size() == 64` test above. As written:
//  - the warp-size-64 branch is evaluated before this test is reached, so
//    if that branch returns (presumably it does — TODO confirm) this test
//    is skipped on that path;
//  - the NXdlPerWave32 branch is entered whenever this test passes,
//    whatever get_warp_size() returned.
// Confirm whether this test was meant to sit before the warp-size
// dispatch, with the `else` paired to the warp-size test.
685 if(CDEShuffleBlockTransferScalarPerVectors{}[Number<0>{}] <= 1 && (arg.KBatch > 1))
686 {
687 return false;
688 }
689 else
690 {
691 if constexpr(NXdlPerWave32 > 0)
692 {
// `arg` is reinterpreted as the GridwiseGemm32 argument type here, while
// the parameter is declared with the file's Argument alias.
694 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
695 }
696 }
697 return false;
698 }
699
700 // polymorphic
// Type-erased overload: downcasts `p_arg` to the concrete Argument and
// forwards to the static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null
// check, so a BaseArgument of a different dynamic type (or nullptr) is
// undefined behavior rather than a `false` result.
701 bool IsSupportedArgument(const BaseArgument* p_arg) override
702 {
703 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
704 }
705
// Builds an Argument by value from untyped buffer pointers and problem
// sizes.
//   p_a, p_b : input buffers, static_cast to const ADataType* / const BDataType*.
//   p_ds     : one untyped pointer per D tensor (NumDTensor entries), passed through.
//   p_c      : output buffer, static_cast to CDataType*.
//   M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch : forwarded unchanged.
//   a_element_op, b_element_op, c_element_op : forwarded unchanged.
// No validation is done here; the pointer casts trust the caller to pass
// buffers of the matching element types.
706 static auto MakeArgument(const void* p_a,
707 const void* p_b,
708 std::array<const void*, NumDTensor> p_ds,
709 void* p_c,
710 index_t M,
711 index_t N,
712 index_t K,
713 index_t StrideA,
714 index_t StrideB,
715 std::array<index_t, NumDTensor> StrideDs,
716 index_t StrideC,
717 index_t KBatch,
718 AElementwiseOperation a_element_op,
719 BElementwiseOperation b_element_op,
720 CElementwiseOperation c_element_op)
721 {
722 return Argument{static_cast<const ADataType*>(p_a),
723 static_cast<const BDataType*>(p_b),
724 p_ds,
725 static_cast<CDataType*>(p_c),
726 M,
727 N,
728 K,
729 StrideA,
730 StrideB,
731 StrideDs,
732 StrideC,
733 KBatch,
734 a_element_op,
735 b_element_op,
736 c_element_op};
737 }
738
// Returns a value-initialized Invoker by value.
739 static auto MakeInvoker() { return Invoker{}; }
740
741 // polymorphic
// Heap-allocating counterpart of MakeArgument: constructs an Argument from
// the same values in the same order and returns it as a
// std::unique_ptr<BaseArgument>.
//   p_a, p_b : static_cast to const ADataType* / const BDataType*.
//   p_c      : static_cast to CDataType*.
//   all other parameters are forwarded unchanged.
// No validation is done here.
742 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
743 const void* p_b,
744 std::array<const void*, NumDTensor> p_ds,
745 void* p_c,
746 index_t M,
747 index_t N,
748 index_t K,
749 index_t StrideA,
750 index_t StrideB,
751 std::array<ck::index_t, NumDTensor> StrideDs,
752 index_t StrideC,
753 index_t KBatch,
754 AElementwiseOperation a_element_op,
755 BElementwiseOperation b_element_op,
756 CElementwiseOperation c_element_op) override
757 {
758 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
759 static_cast<const BDataType*>(p_b),
760 p_ds,
761 static_cast<CDataType*>(p_c),
762 M,
763 N,
764 K,
765 StrideA,
766 StrideB,
767 StrideDs,
768 StrideC,
769 KBatch,
770 a_element_op,
771 b_element_op,
772 c_element_op);
773 }
774
775 // polymorphic
// Returns a heap-allocated Invoker as a std::unique_ptr<BaseInvoker>.
// The Invoker is constructed from a temporary `Invoker{}` passed to
// make_unique.
776 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
777 {
778 return std::make_unique<Invoker>(Invoker{});
779 }
780
781 // polymorphic
// Builds a human-readable description of this instance from its template
// parameters and returns it as a std::string. The text starts with the
// literal "DeviceGemmXdlUniversal" followed by the GEMM specialization,
// the first character of each layout name, and the tile / pipeline
// settings streamed below.
782 std::string GetTypeString() const override
783 {
784 auto str = std::stringstream();
785
// Two local enum-to-name lookup tables; both are constructed anew on each
// call and queried with operator[] further down.
786 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
789
790 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
796
// "WaveMap" prints the NXdlPerWave template parameter as given, while
// "BlkGemmPipelinePrefetchStages" is read from GridwiseGemm64.
797 // clang-format off
798 str << "DeviceGemmXdlUniversal"
799 << "<"
800 << getGemmSpecializationString(GemmSpec) << ", "
801 << std::string(ALayout::name)[0]
802 << std::string(BLayout::name)[0]
803 << std::string(CLayout::name)[0]
804 << ">"
805 << " BlkSize: "
806 << BlockSize << ", "
807 << "BlkTile: "
808 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
809 << "WaveTile: "
810 << MPerXDL<<"x"<<NPerXDL << ", "
811 << "WaveMap: "
812 << MXdlPerWave<<"x" << NXdlPerWave<<", "
813 << "VmemReadVec: "
814 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
815 << "BlkGemmPipelineScheduler: "
816 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
817 << "BlkGemmPipelineVersion: "
818 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
819 << "BlkGemmPipelinePrefetchStages: "
820 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
821 // clang-format on
822
823 return str.str();
824 }
825};
826
827} // namespace device
828} // namespace tensor_operation
829} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:40
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:75
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:148
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:642
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:150
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:83
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:776
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:87
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:701
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:649
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:782
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:84
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:739
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:742
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:86
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:142
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:145
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:655
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:143
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:706
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3.hpp:91
Definition device_gemm_multiple_d.hpp:80
Definition flush_cache.hpp:174