epilogue_cshuffle_v3_wmma_base.hpp Source File

epilogue_cshuffle_v3_wmma_base.hpp Source File

Composable Kernel: epilogue_cshuffle_v3_wmma_base.hpp Source File
epilogue_cshuffle_v3_wmma_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10
// Shared base for the WMMA C-shuffle epilogue. The epilogue runs in two stages:
//   1. GetVgprToLDSEpilogueDescriptor(): each thread copies its accumulator
//      fragment (AccDataType) into an LDS staging tile (CShuffleDataType)
//      described by GetCShuffleLDSDescriptor().
//   2. GetLDSToVmemEpilogueDescriptor(): the thread block (ThisThreadBlock) reads
//      C back from LDS and the D tensors from global memory, applies
//      CDEElementwiseOperation, and writes E to global memory.
// Both stages move CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle WMMA
// repeats per slice; SpaceFillingCurveVgpr / SpaceFillingCurveVmem describe the
// matching traversals (presumably stepped by the derived epilogue - confirm).
//
// Template parameters:
//   DsDataType          - type tuple of the auxiliary D tensors (see NumDTensor).
//   EDataType           - element type of the output E (its use is on an
//                         omitted line of this listing).
//   AccDataType         - accumulator element type (source of stage 1).
//   CShuffleDataType    - element type of the LDS staging tile.
//   MPerBlock/NPerBlock - block tile extents.
//   MPerWmma/NPerWmma   - extents of one WMMA tile.
//   MRepeat/NRepeat     - WMMA tiles per wave along M/N; MWaves is derived as
//                         MPerBlock / (MRepeat * MPerWmma) (likewise for N).
//   CShuffleMRepeatPerShuffle/CShuffleNRepeatPerShuffle - repeats moved per
//                         shuffle step.
//   CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock -
//                         thread-cluster shape of the stage-2 copy.
//   CDEShuffleBlockTransferScalarPerVectors - per-source vector widths of the
//                         stage-2 copy (sources ordered C first, then Ds).
//   CDEElementwiseOperation - functor combining C and Ds into E.
//   ThisThreadBlock     - thread group that cooperates in the stage-2 copy.
//   BlockwiseGemmPipe   - blockwise GEMM that owns the accumulator layout.
//
// NOTE(review): this Doxygen listing has dropped source lines; each gap is
// marked "[omitted ...]" below. Where the page's symbol index names what was
// declared on a dropped line, that is quoted; anything else is a hedged guess.
11template <typename DsDataType,
12 typename EDataType,
13 typename AccDataType,
14 typename CShuffleDataType,
15 index_t MPerBlock,
16 index_t NPerBlock,
17 index_t MPerWmma,
18 index_t NPerWmma,
19 index_t MRepeat,
20 index_t NRepeat,
21 index_t CShuffleMRepeatPerShuffle,
22 index_t CShuffleNRepeatPerShuffle,
23 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
24 typename CDEShuffleBlockTransferScalarPerVectors,
25 typename CDEElementwiseOperation,
26 typename ThisThreadBlock,
27 typename BlockwiseGemmPipe>
// [omitted line 28] the line naming the struct.
29{
30 static constexpr auto I0 = Number<0>{};
31 static constexpr auto I1 = Number<1>{};
32 static constexpr auto I2 = Number<2>{};
33 static constexpr auto I3 = Number<3>{};
34 static constexpr auto I4 = Number<4>{};
35 static constexpr auto I5 = Number<5>{};
36 static constexpr auto I6 = Number<6>{};
37
// Number of auxiliary D tensors (size of the DsDataType tuple).
38 static constexpr index_t NumDTensor = DsDataType::Size();
// [omitted line 39] per the symbol index:
//   static constexpr auto EShuffleBlockTransferScalarPerVector =
// The E store width is the first entry of the per-source vector widths, i.e.
// the entry that belongs to the C source.
40 CDEShuffleBlockTransferScalarPerVectors{}[I0];
41
// [omitted lines 42-44] per the symbol index:
//   using SpaceFillingCurveVgpr = SpaceFillingCurve<
//       Sequence<MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs>,
//       Sequence<0, 1, 2, 3, 4, 5, 6>,
// i.e. a walk over the per-thread accumulator layout. The Sequence below is the
// per-access slice: one shuffle step's M/N repeats and all MAccVgprs values.
45 Sequence<CShuffleMRepeatPerShuffle,
46 1,
47 1,
48 CShuffleNRepeatPerShuffle,
49 1,
50 1,
51 BlockwiseGemmPipe::MAccVgprs>>;
52
// [omitted lines 53-55] per the symbol index:
//   using SpaceFillingCurveVmem = SpaceFillingCurve<
//       Sequence<1, MPerBlock, 1, NPerBlock>,
//       Sequence<0, 2, 1, 3>,
// i.e. a walk over the block tile in (MBlock, MPerBlock, NBlock, NPerBlock)
// form. The per-access slice below is the same expression as the stage-2
// BlockSliceLengths in GetLDSToVmemEpilogueDescriptor().
56 Sequence<1,
57 CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
58 1,
59 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
60
// Compile-time descriptor whose dimensions are, per its name,
// (MShRepeat, MPerShRepeat, NShRepeat, NPerShRepeat) - presumably the tile
// handled by one shuffle step.
61 // Caution: "repeat" here means the shuffle repeat, not the GEMM MRepeat/NRepeat.
62 __device__ static constexpr auto
// [omitted line 63] per the symbol index:
//   GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
64 {
// Waves per block along M/N, recomputed from the tile sizes (integer division).
// NOTE(review): the other members use BlockwiseGemmPipe::MWaves/NWaves instead;
// presumably both agree - confirm, or reuse the pipeline's values.
65 constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
66 constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
67
68 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
// [omitted lines 69-71 and 73] descriptor construction; only the `I1,`
// argument on line 72 survives in this listing.
72 I1,
74
75 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
76 }
77
// Compile-time descriptor of the LDS staging tile. The visible fragments split
// its M extent as (CShuffleMRepeatPerShuffle, MWave, MSubGroup, MAccVgprs) and
// its N extent as (CShuffleNRepeatPerShuffle, NWave, NThreadPerSubGroup), the
// same 7-D order as the slice written by GetVgprToLDSEpilogueDescriptor(),
// which uses this descriptor as its destination.
78 __device__ static constexpr auto GetCShuffleLDSDescriptor()
79 {
80 // Block-level C layout of the blockwise GEMM; only its dimension lengths are
// read here. MRepeat/NRepeat (dims 0 and 3) are not read: the staging tile uses
// CShuffleMRepeatPerShuffle/CShuffleNRepeatPerShuffle repeats instead.
81 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
82 BlockwiseGemmPipe::
83 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
84
85 constexpr auto MWave =
86 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
87 .GetLength(I1);
88 constexpr auto MSubGroup =
89 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
90 .GetLength(I2);
91 constexpr auto NWave =
92 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
93 .GetLength(I4);
94 constexpr auto NThreadPerSubGroup =
95 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
96 .GetLength(I5);
97 constexpr auto MAccVgprs =
98 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
99 .GetLength(I6);
100
// [omitted lines 101-104] head of the returned descriptor expression
// (presumably transform_tensor_descriptor over the shuffle block descriptor
// with freeze/unmerge transforms, which the symbol index lists - confirm).
105 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
106 MWave, // MWave
107 MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
108 MAccVgprs)),
// [omitted lines 109-110]
111 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
112 NWave, // NWave
113 NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
// [omitted lines 114-115]
116 }
117
// Builds the stage-1 copy object: this thread's accumulator fragment
// (AccDataType, laid out per BlockwiseGemmPipe's C thread descriptor) into the
// LDS staging tile (CShuffleDataType, GetCShuffleLDSDescriptor()).
// Unlike the descriptor getters above, this is not constexpr. The destination
// origin comes from CalculateCThreadOriginDataIndex(), presumably a per-thread
// runtime value.
118 __device__ static auto GetVgprToLDSEpilogueDescriptor()
119 {
120 // Block-level C layout of the blockwise GEMM; only its dimension lengths are read.
121 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
122 BlockwiseGemmPipe::
123 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
124
125 constexpr auto MWave =
126 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
127 .GetLength(I1);
128 constexpr auto MSubGroup =
129 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
130 .GetLength(I2);
131 constexpr auto NWave =
132 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
133 .GetLength(I4);
134 constexpr auto NThreadPerSubGroup =
135 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
136 .GetLength(I5);
137 constexpr auto MAccVgprs =
138 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
139 .GetLength(I6);
140
141 // Origin of this thread's C data inside the block tile, for repeat (0, 0).
142 // It is decomposed below into LDS-tile coordinates (not global memory).
143 const auto c_thread_mtx_on_block =
144 BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
145
146 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
147 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
148
// Split the M origin into (MRepeat, MWave, MSubGroup, MAccVgprs) coordinates.
// The adaptor merges those four dims into one index, so CalculateBottomIndex()
// recovers them from m_thread_data_on_block.
149 const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
// [omitted line 150] presumably make_single_stage_tensor_adaptor(
151 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
// [omitted lines 152-153] remaining adaptor arguments.
154
155 const auto m_thread_data_on_block_idx =
156 m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
157 .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
158
// Same for N: split into (NRepeat, NWave, NThreadPerSubGroup) coordinates.
159 const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
// [omitted line 160]
161 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
// [omitted lines 162-163]
164
165 const auto n_thread_data_on_block_idx =
166 n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
167 make_multi_index(n_thread_data_on_block));
168
// [omitted line 169] head naming the copy type (presumably a
// ThreadwiseTensorSliceTransfer variant, cf. the symbol index - confirm).
// Its template arguments follow.
170 AccDataType, // source element type (accumulators)
171 CShuffleDataType, // destination element type (LDS staging tile)
172 decltype(BlockwiseGemmPipe::
173 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()), // source descriptor
174 decltype(GetCShuffleLDSDescriptor()), // destination descriptor
// [omitted line 175]
// Slice lengths: one shuffle step of M/N repeats, with all MAccVgprs values.
176 Sequence<CShuffleMRepeatPerShuffle,
177 I1,
178 I1,
179 CShuffleNRepeatPerShuffle,
180 I1,
181 I1,
182 MAccVgprs>,
// [omitted line 183]
184 6, // dim 6 of the slice is MAccVgprs; presumably the vector dim - confirm
185 1, // presumably scalars per vector - confirm
// [omitted line 186]
187 1,
// [omitted lines 188-189] remaining template arguments and the start of the
// (presumably) destination origin, given in LDS-tile coordinates.
190 m_thread_data_on_block_idx[I1], // MWave
191 m_thread_data_on_block_idx[I2], // MSubGroup
192 0, // NRepeat slot (dim 3) starts at 0
193 n_thread_data_on_block_idx[I1], // NWave
194 n_thread_data_on_block_idx[I2], // NThreadPerSubGroup
195 m_thread_data_on_block_idx[I3]), // MAccVgprs
// [omitted line 196] remaining constructor argument(s).
197 }
198
// Builds the stage-2 block-wide copy: read C from the LDS staging tile and the
// D tensors from global memory, apply cde_element_op, and write E to global
// memory using EGlobalMemoryDataOperation.
//   EGlobalMemoryDataOperation - in-memory op for the E store (DstInMemOps).
//   InterDataType  - passed as the final template argument, Tuple<InterDataType>;
//                    presumably the intermediate compute type - confirm.
//   c_ds_desc_refs - source descriptors, C first and then the NumDTensor D
//                    descriptors (same order as idx_c_ds_block_begin).
//   e_grid_desc_mblock_mperblock_nblock_nperblock - destination descriptor of E.
//   cde_element_op - elementwise functor applied to (C, Ds).
//   block_m_id / block_n_id - this block's coordinates in the MBlock / NBlock
//                    dimensions, used as the D source and E destination start.
199 template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
200 typename InterDataType,
201 typename CDsDescRefs,
202 typename EGridDesc>
203 __device__ static auto
204 GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
205 EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
206 CDEElementwiseOperation& cde_element_op,
207 const index_t& block_m_id,
208 const index_t& block_n_id)
209 {
210 // Start index per source: C is read from the LDS tile at (0, 0, 0, 0); every
// D tensor starts at this block's tile, (block_m_id, 0, block_n_id, 0).
211 const auto idx_c_ds_block_begin = container_concat(
212 make_tuple(make_multi_index(0, 0, 0, 0)),
213 generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
// [omitted line 214] presumably `Number<NumDTensor>{}));` - one start index per D tensor.
215
216 // blockwise copy which loads C from LDS, D from global, applies elementwise
217 // operation and stores result E to global
// [omitted line 218] statement head naming the transfer type (presumably
// ThreadGroupTensorSliceTransfer_v7r3, cf. the symbol index). Given the `auto`
// return type, the constructed object is presumably returned - confirm.
219 ThisThreadBlock, // ThreadGroup
220 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), // presumably SrcDatas: C, then Ds
// [omitted line 221] presumably DstDatas.
222 CDsDescRefs, // presumably SrcDescs
223 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), // presumably DstDescs
224 CDEElementwiseOperation, // ElementwiseOperation,
225 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
226 Sequence<1,
227 CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
228 1,
229 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
230 NPerWmma>, // BlockSliceLengths,
231 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, // presumably ThreadClusterLengths
232 Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
233 Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
234 Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
235 3, // SrcVectorDim,
236 3, // DstVectorDim,
237 CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
238 EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
// [omitted lines 239-241] head of the source reset-coordinate flags. The symbol
// index lists sequence_merge_t and uniform_sequence_gen_t; presumably one flag
// for C merged with NumDTensor `false` flags for the Ds - confirm.
242 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
243 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
244 1, // presumably NumThreadScratch - confirm
245 Tuple<InterDataType>>{c_ds_desc_refs, // constructor: source descriptors (C, Ds)
246 idx_c_ds_block_begin, // source start indices
247 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), // destination descriptor (E)
248 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), // destination start: this block's E tile
249 cde_element_op}; // elementwise operation
250 }
251};
252
253} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr auto I6
Definition epilogue_cshuffle_v3_wmma_base.hpp:36
static constexpr auto I2
Definition epilogue_cshuffle_v3_wmma_base.hpp:32
static constexpr auto I4
Definition epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr auto I5
Definition epilogue_cshuffle_v3_wmma_base.hpp:35
static constexpr index_t NumDTensor
Definition epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr auto I0
Definition epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition epilogue_cshuffle_v3_wmma_base.hpp:33
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_wmma_base.hpp:204
SpaceFillingCurve< Sequence< MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, Sequence< CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, BlockwiseGemmPipe::MAccVgprs > > SpaceFillingCurveVgpr
Definition epilogue_cshuffle_v3_wmma_base.hpp:42
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr auto EShuffleBlockTransferScalarPerVector
Definition epilogue_cshuffle_v3_wmma_base.hpp:39
SpaceFillingCurve< Sequence< 1, MPerBlock, 1, NPerBlock >, Sequence< 0, 2, 1, 3 >, Sequence< 1, CShuffleMRepeatPerShuffle *BlockwiseGemmPipe::MWaves *MPerWmma, 1, CShuffleNRepeatPerShuffle *BlockwiseGemmPipe::NWaves *NPerWmma > > SpaceFillingCurveVmem
Definition epilogue_cshuffle_v3_wmma_base.hpp:53
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group.hpp:12
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340