blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
// Compile-time selector for the MX MoE `nbs` blockwise GEMM pipeline.
//
// Dispatches with `if constexpr` on BlkGemmPipelineVer and GUFusion and returns
// a brace-initialized temporary of the chosen pipeline type:
//
//   v1, GUFusion == false : BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<...>{}
//   v1, GUFusion == true  : nullptr (no pipeline type is named for this case)
//   v3, GUFusion == false : BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<...>{}
//   v3, GUFusion == true  : a pipeline named on listing line 76, which is
//                           missing from this extract
//   any other version     : writes a message to std::cerr and returns nothing
//
// ComputeDataType and AccDataType are accepted as template parameters but are
// not forwarded to any of the pipeline instantiations visible below.
// NOTE(review): callers presumably consume only the returned type (e.g. through
// decltype) -- confirm against the call sites.
11template <BlockGemmPipelineVersion BlkGemmPipelineVer,
12 BlockGemmPipelineScheduler BlkGemmPipeSche,
13 index_t ThreadBlockSize,
14 index_t ScaleBlockSize,
15 typename ADataType,
16 typename AScaleDataType,
17 typename BDataType,
18 typename BScaleDataType,
19 typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
20 // must be used for compute
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPack,
36 bool GUFusion = false>
// NOTE(review): listing line 37 (the function signature) is missing from this
// extract; the symbol index at the end of the page gives it as
// `constexpr auto BlockGemmMXNBSPipeline_Selector()`.
38{
39
40 // Hardware MX GEMM pipeline
41 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
42 {
43 if constexpr(GUFusion)
44 {
// No fused v1 pipeline type is named here; with an `auto` return type this
// instantiation deduces std::nullptr_t instead of a pipeline type.
45 return nullptr;
46 }
47 else
48 {
49 return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlkGemmPipeSche,
50 ThreadBlockSize,
51 ScaleBlockSize,
52 ADataType,
53 AScaleDataType,
54 BDataType,
55 BScaleDataType,
56 ATileDesc,
57 BTileDesc,
58 AMmaTileDesc,
59 BMmaTileDesc,
60 ABlockTransferSrcScalarPerVector,
61 BBlockTransferSrcScalarPerVector,
62 MPerBlock,
63 NPerBlock,
64 KPerBlock,
65 MPerXDL,
66 NPerXDL,
67 MRepeat,
68 NRepeat,
69 KPack>{};
70 }
71 }
72 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
73 {
74 if constexpr(GUFusion)
75 {
// NOTE(review): listing line 76 is missing from this extract. It presumably
// holds `return <fused v3 pipeline template><`; only the template argument
// list survives below, and it matches the argument list of the non-fused branch.
77 BlkGemmPipeSche,
78 ThreadBlockSize,
79 ScaleBlockSize,
80 ADataType,
81 AScaleDataType,
82 BDataType,
83 BScaleDataType,
84 ATileDesc,
85 BTileDesc,
86 AMmaTileDesc,
87 BMmaTileDesc,
88 ABlockTransferSrcScalarPerVector,
89 BBlockTransferSrcScalarPerVector,
90 MPerBlock,
91 NPerBlock,
92 KPerBlock,
93 MPerXDL,
94 NPerXDL,
95 MRepeat,
96 NRepeat,
97 KPack>{};
98 }
99 else
100 {
101 return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
102 ThreadBlockSize,
103 ScaleBlockSize,
104 ADataType,
105 AScaleDataType,
106 BDataType,
107 BScaleDataType,
108 ATileDesc,
109 BTileDesc,
110 AMmaTileDesc,
111 BMmaTileDesc,
112 ABlockTransferSrcScalarPerVector,
113 BBlockTransferSrcScalarPerVector,
114 MPerBlock,
115 NPerBlock,
116 KPerBlock,
117 MPerXDL,
118 NPerXDL,
119 MRepeat,
120 NRepeat,
121 KPack>{};
122 }
123 }
124 else
125 {
// Fallback for every version other than v1 and v3: this branch has no return
// statement, so the instantiation yields no pipeline object and the only
// diagnostic is this message written when the function is called.
// NOTE(review): a dependent static_assert would reject unsupported versions
// at compile time instead -- consider.
126 std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
127 }
128}
129
130} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
constexpr auto BlockGemmMXNBSPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp:37
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:38
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp:38
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:38