thread_group_tensor_slice_transfer_gather_direct_load.hpp Source File

thread_group_tensor_slice_transfer_gather_direct_load.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_gather_direct_load.hpp Source File
thread_group_tensor_slice_transfer_gather_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
// ---------------------------------------------------------------------------
// ThreadGroupTensorSliceTransfer_Gather_DirectLoad
//
// A thread group copies a BlockSliceLengths-sized slice of a source tensor in
// global memory into a destination buffer in LDS, issuing one
// SrcBuffer::DirectCopyToLds call per access. Source and destination element
// types must be identical: no conversion is done, which a static_assert in
// the constructor enforces.
//
// Along GatherDim the source position is not taken from the source tensor
// coordinate. It comes from the gather_offsets_ table, which Run() indexes by
// the access step along GatherDim. The source coordinate's GatherDim
// component is always 0.
//
// Template parameters, as used in this listing:
//   ThreadGroup               - provides GetNumOfThread() and GetThreadId().
//   BlockSliceLengths         - lengths of the slice copied by the whole group.
//   ThreadClusterLengths      - number of threads along each dimension.
//   ThreadClusterArrangeOrder - dimension order used to map a thread id to its
//                               cluster index (make_cluster_descriptor).
//   SrcData, DstData          - element types; must be the same type.
//   SrcDesc, DstDesc          - source / destination tensor descriptors.
//   SrcDimAccessOrder         - NOTE(review): not referenced by any visible
//                               line of this listing; confirm it is needed.
//   SrcVectorDim              - must equal DstVectorDim when ScalarPerVector > 1.
//   DstVectorDim              - must be the last dimension (nDim - 1).
//   ScalarPerVector           - elements copied per DirectCopyToLds call.
//   IndexType                 - type of the gather offsets and of the buffer
//                               offsets handed to DirectCopyToLds.
//   GatherDim                 - dimension whose source offset is taken from
//                               gather_offsets_ (default 1).
// ---------------------------------------------------------------------------
42template <typename ThreadGroup,
43 typename BlockSliceLengths,
44 typename ThreadClusterLengths,
45 typename ThreadClusterArrangeOrder,
46 typename SrcData,
47 typename DstData,
48 typename SrcDesc,
49 typename DstDesc,
50 typename SrcDimAccessOrder,
51 index_t SrcVectorDim,
52 index_t DstVectorDim,
53 index_t ScalarPerVector,
54 typename IndexType,
55 index_t GatherDim = 1>
57{
60
// Coordinate and coordinate-step types produced by the source / destination
// descriptors.
61 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
62 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
63
64 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
65 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
66
67 static constexpr auto I0 = Number<0>{};
68 static constexpr auto I1 = Number<1>{};
69
70 static constexpr auto block_slice_lengths = BlockSliceLengths{};
71 static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
72
75 // After a load, each thread moves by `thread_steps` instead of loading the next elements.
76 // This makes the whole wavefront load contiguous memory, which is required for direct loads.
80
// Compile-time check that ThreadClusterLengths let the threads of one
// wavefront write a contiguous run of DWORDs into LDS (see the examples
// below).
// Returns false if either of these holds:
//   - the number of contiguous DWORDs is not a multiple of the wavefront size;
//   - any thread slice length is non-positive.
// NOTE(review): the static_assert that used this check in the constructor is
// commented out, so the condition is currently not enforced.
81 static __device__ constexpr bool AreThreadClusterLengthsValid()
82 {
83 // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
84 // LDS by the threads from a single wavefront.
85 // Examples (assuming 64 threads in a wavefront, 128 in a thread block):
86 // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
87 // data type = fp32 -> ScalarPerVector = 1
88 // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
89 // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
90 // [0, 4, 0].
91 // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
92 // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
93 // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
94 // data type = fp16 -> ScalarPerVector = 2
95 // NOTE: ThreadClusterLengths must take into account that each thread writes two
96 // elements (single DWORD) along the contiguous dimension.
97 // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
98 // 8 * 2 elements of K1PerBlock and there are only 8;
99 // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
100 // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
101 // writes [1, 0, 0] instead of [0, 8, 0].
102 // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
103 // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
104 // elements = 64 consecutive DWORDs.
// NOTE(review): on gfx950 the count starts at 4 instead of 1, presumably
// because one direct load there can cover 4 DWORDs per thread; confirm.
105#if defined(__gfx950__)
106 int num_contiguous_dwords = 4;
107#else
108 int num_contiguous_dwords = 1;
109#endif
// Multiply in thread cluster lengths from the last dimension backwards.
// Stop after the first dimension whose thread slice length exceeds 1; that
// dimension itself is still included.
110 bool is_contiguous = true;
111 static_for<0, nDim, 1>{}([&](auto i) {
112 if(is_contiguous)
113 {
114 num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
115 }
116 if(thread_slice_lengths[nDim - i - 1] > 1)
117 {
118 is_contiguous = false;
119 }
120 });
121 constexpr index_t wavefront_size = get_warp_size();
122 const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
123
// Every per-thread slice length must be positive.
124 bool thread_slice_lengths_correct = true;
125 static_for<0, nDim, 1>{}([&](auto i) {
126 if(thread_slice_lengths[i] <= 0)
127 {
128 thread_slice_lengths_correct = false;
129 }
130 });
131
132 return wave_contiguous && thread_slice_lengths_correct;
133 }
134
// Constructor.
//   src_desc, src_block_slice_origin - source descriptor and the origin of
//       the block slice in the source tensor.
//   dst_desc, dst_block_slice_origin - destination (LDS) descriptor and the
//       origin of the block slice in the destination tensor.
//   gather_offsets - per-access source offsets along GatherDim. They are
//       stored in gather_offsets_ and added to the source offset in Run().
// Source origin: the block origin plus this thread's cluster index times
// thread_single_load_size. SetSrcSliceOrigin zeroes its GatherDim component.
// Destination origin: offset only per wavefront, because Run() passes a
// wave-uniform (readfirstlane) LDS offset.
136 const SrcDesc& src_desc,
137 const Index& src_block_slice_origin,
138 const DstDesc& dst_desc,
139 const Index& dst_block_slice_origin,
141 : gather_offsets_(gather_offsets)
142 {
144 "Direct load transfer does not support datatypes conversion. Source and "
145 "destination data types must be the same.");
146
147 static_assert(
148 DstVectorDim == nDim - 1,
149 "Direct load transfer requires the destination vector dimension to be the last one.");
150
151 static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
152 "When loading more than one element per thread at once, the contiguous "
153 "dimension must be the same between source and destination.");
154
155 // constexpr auto dword_bytes = 4;
156 // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
157 // static_assert(bytes_per_thread_load == dword_bytes,
158 // "Direct load transfer requires each thread to load exactly a single "
159 // "DWORD of data.");
160
163 nDim == ThreadClusterLengths::Size(),
164 "Inconsistent number of dimensions across lengths and descriptors.");
165
166 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
167 "The number of threads cannot be less than the number of elements in "
168 "thread cluster lengths.");
169
// NOTE(review): validity check on ThreadClusterLengths is disabled.
170 // static_assert(
171 // AreThreadClusterLengthsValid(),
172 // "Thread cluster lengths are incorrect. They must be set in a way that allows a single
173 // " "wavefront to write contiguous DWORDs into LDS memory. ");
174
// Position of this thread inside the thread cluster.
175 const auto thread_cluster_idx =
176 thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
177
// Split the thread group into 64-lane waves. All waves are laid out along the
// dimension i whose ThreadClusterArrangeOrder entry equals nDim - 3; every
// other dimension gets one wave.
// NOTE(review): 64 is hard-coded here and in the wave index below, while
// AreThreadClusterLengthsValid() uses get_warp_size(). Also, when nDim < 3 no
// dimension matches nDim - 3. Confirm both assumptions.
178 constexpr auto wave_cluster_lengths = generate_sequence_v2(
179 [&](auto i) {
180 if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
181 {
182 return Number<ThreadGroup::GetNumOfThread() / 64>{};
183 }
184 else
185 {
186 return I1;
187 }
188 },
189 Number<nDim>{});
190
// Per-wave thread cluster and the element extent one wave covers per access.
191 constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
192 constexpr auto wave_single_load_size =
193 wave_thread_cluster_lengths * thread_single_load_size;
194 constexpr auto wave_cluster_desc_ =
195 make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
196
197 const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
198 make_multi_index(ThreadGroup::GetThreadId() / 64));
199
// Per-thread source origin and per-wave destination origin, both relative to
// the block slice origin.
200 const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
201 const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
202
203 SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
204 // We don't need a threadwise offset for LDS since it is calculated by HW.
205 // We still need to pass the wavewise offset.
206 SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
207 }
208
// Set this thread's source origin. The GatherDim component of
// src_slice_origin_idx is replaced by 0, because Run() supplies the position
// along GatherDim through gather_offsets_. Updates src_coord_ and
// src_slice_origin_.
209 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
210 {
211 auto adjusted_src_origin_idx = [&]() {
212 Index idx;
213 static_for<0, nDim, 1>{}([&](auto i) {
214 idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
215 });
216 return idx;
217 }();
218
219 // CK_PRINT<decltype(adjusted_src_origin_idx)>();
220 // CK_PRINT<decltype(src_slice_origin_idx)>();
221
222 src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
223 src_slice_origin_ = adjusted_src_origin_idx;
224 }
225
// Set the destination origin and rebuild dst_coord_ from it.
226 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
227 {
228 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
229 dst_slice_origin_ = dst_slice_origin_idx;
230 }
231
// Rebuild dst_coord_ from the stored dst_slice_origin_. Called at the end of
// Run().
232 __device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
233 {
234 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
235 }
236
// Copy this thread's part of the block slice from src_buf (must be global
// memory) to dst_buf (must be LDS). Both requirements are static_asserted.
//
// Traversal: iterates over thread_slice_lengths in snake order. For dim > 0,
// the sweep direction flips with the parity of the loop counter of the slower
// dimensions. After each copy, the coordinates step by +/- thread_steps along
// one dimension; the source never steps along GatherDim.
//
// On return, dst_coord_ is reset to its origin. src_coord_ is not reset;
// MoveSrcSliceWindow() rebuilds it from src_slice_origin_.
// NOTE(review): calling Run() twice without an intervening
// MoveSrcSliceWindow() would start from the last visited source position.
237 template <typename SrcBuffer, typename DstBuffer>
238 __device__ void Run(const SrcDesc& src_desc,
239 const SrcBuffer& src_buf,
240 const DstDesc& dst_desc,
241 DstBuffer& dst_buf)
242 {
243 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
244 "Source data must come from a global memory buffer.");
245 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
246 "Destination data must be stored in an LDS memory buffer.");
247
248 static_assert(
250 "SrcBuffer and SrcData data types must be consistent.");
251 static_assert(
253 "DstBuffer and DstData data types must be consistent.");
254
255 constexpr auto dst_access_lengths = thread_slice_lengths;
256
// Per-dimension steps of +/- thread_steps for both descriptors.
257 const auto dst_forward_steps = generate_steps(dst_desc, 1);
258 const auto dst_backward_steps = generate_steps(dst_desc, -1);
259 const auto src_forward_steps = generate_steps(src_desc, 1);
260 const auto src_backward_steps = generate_steps(src_desc, -1);
261
262 // Loop over the destination block and copy data.
263 static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// Source offset along GatherDim comes from the gather table, indexed by the
// loop counter along GatherDim.
// NOTE(review): possible index/position mismatch. When forward_sweep[GatherDim]
// is false (snake order), the destination moves backwards along GatherDim,
// but this index still increases with the loop counter. If any dimension
// before GatherDim has a thread slice length > 1, gathered rows may land in
// reversed destination positions on odd sweeps. Verify, or index with
// (forward ? idx : len - 1 - idx).
264 IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number<GatherDim>{}]];
265 // src_coord_xor_ = src_coord_;
266 // src_coord_xor_.GetIndex().At(I0) =
267 // src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
// NOTE(review): hard-coded swizzle. The dimension-0 source index is XORed
// with (threadIdx.x % 64) / 8, a value in 0..7:
//   - it uses threadIdx.x rather than ThreadGroup::GetThreadId();
//   - it assumes a fixed 64-lane wave;
//   - it changes only the low 3 bits, so the result stays in range only if
//     the dimension-0 extent is a multiple of 8;
//   - a full tensor coordinate is rebuilt on every access.
// Presumably this matches an LDS layout expected by the consumer; confirm.
268 Index new_index = src_coord_.GetIndex();
269 new_index(I0) = src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
270 src_coord_xor_ = make_tensor_coordinate(src_desc, new_index);
271
// Destination offset is made wave-uniform via readfirstlane. Per-lane LDS
// placement is left to the hardware, per the constructor's comment.
272 const IndexType src_offset = src_coord_xor_.GetOffset() + gather_offset;
273 const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
274
275 // Check if src data is not in the logical padding area.
276 // Leave out-of-bounds checking to the HW (the validity flag below is always true).
277 // const bool is_src_valid =
278 // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
279 // src_coord_);
280
281 src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
282 dst_buf, src_offset, dst_offset, true);
283
// After this access, step along dimension i when both hold:
//   - its loop counter is not yet at its last value;
//   - every later dimension's counter is at its last value.
// The source additionally never steps along GatherDim.
284 constexpr auto move_src_on_dim = [&]() constexpr {
286
287 static_for<0, nDim, 1>{}([&](auto i) {
288 move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
289
290 static_for<i + 1, nDim, 1>{}([&](auto j) {
291 move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
292 });
293 move_on_dim_(i) &= i.value != GatherDim;
294 });
295
296 return move_on_dim_;
297 }();
298
// Same rule as above for the destination, including GatherDim.
299 constexpr auto move_dst_on_dim = [&]() constexpr {
301
302 static_for<0, nDim, 1>{}([&](auto i) {
303 move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
304
305 static_for<i + 1, nDim, 1>{}([&](auto j) {
306 move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
307 });
308 });
309
310 return move_on_dim_;
311 }();
312
313 // Decide whether to move forward or backward.
// Dimension 0 always sweeps forward. Dimension i > 0 sweeps forward when the
// linearised loop counter of dimensions 0..i-1 is even (snake order).
314 constexpr auto forward_sweep = [&]() {
316
317 forward_sweep_(I0) = true;
318
319 static_for<1, nDim, 1>{}([&](auto i) {
320 index_t tmp = ordered_dst_access_idx[I0];
321
322 static_for<1, i, 1>{}([&](auto j) {
323 tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
324 });
325
326 forward_sweep_(i) = tmp % 2 == 0;
327 });
328
329 return forward_sweep_;
330 }();
331
332 static_for<0, nDim, 1>{}([&](auto i) {
333 // Move the source coordinate.
334 if constexpr(move_src_on_dim[i])
335 {
336 if constexpr(forward_sweep[i])
337 {
338 move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
339 }
340 else
341 {
342 move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
343 }
344 }
345
346 // Move the destination coordinate.
347 if constexpr(move_dst_on_dim[i])
348 {
349 if constexpr(forward_sweep[i])
350 {
351 move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
352 }
353 else
354 {
355 move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
356 }
357 }
358 });
359 });
360
361 // Reset the destination slice since the entire buffer has already been filled.
362 ResetDstSliceWindow(dst_desc);
363 }
364
// Shift the source origin by step and rebuild src_coord_ from it.
// NOTE(review): unlike SetSrcSliceOrigin(), this does not zero the GatherDim
// component of step. A non-zero step along GatherDim therefore moves the
// source coordinate in addition to the gather offsets.
365 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
366 {
367 src_slice_origin_ = src_slice_origin_ + step;
368 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
369 }
370
// Build a tuple of nDim coordinate steps for desc. Step i moves by
// sign * thread_steps[i] along dimension i and by 0 along all others.
371 template <typename DescType>
372 __device__ auto generate_steps(const DescType& desc, int sign)
373 {
374 return generate_tuple(
375 [&](auto i) {
376 Index step_idx;
377
378 static_for<0, nDim, 1>{}([&](auto j) {
379 step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
380 });
381
382 return make_tensor_coordinate_step(desc, step_idx);
383 },
384 Number<nDim>{});
385 }
386
387 private:
// Maps a linear thread id to its index in the thread cluster.
388 static constexpr auto thread_cluster_desc_ =
389 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
390
391 SrcCoord src_coord_;
// Scratch coordinate rebuilt on every access in Run(): src_coord_ with its
// dimension-0 index XOR-swizzled.
392 SrcCoord src_coord_xor_;
393 DstCoord dst_coord_;
394 Index src_slice_origin_;
395 Index dst_slice_origin_;
// NOTE(review): leftover experiment. As written it would not compile, because
// threadIdx is used inside a Number<> template argument, which must be a
// constant expression.
397 // static constexpr auto a_grid_xor_desc = make_naive_tensor_descriptor_packed(
398 // make_tuple(Number<AK0 ^ ((threadIdx / AK0) % AK0)>{}, Number<M>{}, Number<AK1>{}));
399};
400
401} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__host__ __device__ constexpr const TData & At(index_t i) const
Definition utility/array.hpp:22
__device__ void ResetDstSliceWindow(const DstDesc &dst_desc)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:232
static constexpr auto I0
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:67
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:226
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:61
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:64
static constexpr index_t gather_num
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:79
__device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const StaticallyIndexedArray< IndexType, gather_num > &gather_offsets)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:135
static constexpr auto thread_cluster_lengths
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:71
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:62
static constexpr auto thread_steps
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:77
static constexpr auto thread_single_load_size
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:73
static __device__ constexpr bool AreThreadClusterLengthsValid()
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:81
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:209
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:65
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:238
__device__ auto generate_steps(const DescType &desc, int sign)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:372
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:59
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:58
static constexpr auto block_slice_lengths
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:70
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:365
static constexpr auto I1
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:68
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:78
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33
Definition functional3.hpp:97