block_norm_reduce.hpp Source File

block_norm_reduce.hpp Source File

Composable Kernel: block_norm_reduce.hpp Source File
block_norm_reduce.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename Problem_, typename Policy_ = void>
13{
15 using XDataType = typename Problem::XDataType;
16 using ComputeDataType = typename Problem::ComputeDataType;
17 static constexpr bool kFastFDiv = Problem::kFastFDiv;
18 static constexpr bool kWelford = Problem::kWelford;
19
21
22 // [CAUSION] - max_count_ is to deal with the padding problem
23 // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
24 // calculation of max_count_
25 // -> use block_welford_calculate_max_count to compute
26 template <typename XDistributedTensor_,
27 typename MeanDistributedTensor_,
28 typename VarDistributedTensor_>
29 CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
30 MeanDistributedTensor_& mean_tensor,
31 VarDistributedTensor_& var_tensor,
32 int& cur_count_, // -> prefer init as zero
33 const int& max_count_)
34 {
35 constexpr auto I0 = number<0>{};
36 constexpr auto I1 = number<1>{};
37
38 constexpr auto spans = XDistributedTensor_::get_distributed_spans();
39
40 sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
41 if(cur_count_ < max_count_)
42 {
43 ++cur_count_;
44 sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
45 constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
46 constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
47
48 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
49 if(kWelford)
50 {
51 welford_update(mean_tensor(out_dstr_idx),
52 var_tensor(out_dstr_idx),
53 x,
54 cur_count_,
56 }
57 else
58 {
59 mean_tensor(out_dstr_idx) += x;
60 var_tensor(out_dstr_idx) += x * x;
61 }
62 });
63 }
64 });
65 }
66
67 template <typename XDistributedTensor_>
69 {
70 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
71
72 constexpr auto reduce_dims = sequence<1>{};
73
74 constexpr auto dstr =
76 XDistributedTensor_::get_tile_distribution()
77 .get_static_tile_distribution_encoding(),
78 reduce_dims));
79
81
82 return tensor;
83 }
84
85 template <typename XDistributedTensor_>
87 operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
88 {
91 clear_tile(mean_tensor);
92 clear_tile(var_tensor);
93
94 (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
95
96 return ck_tile::make_tuple(mean_tensor, var_tensor);
97 }
98};
99
100template <typename Problem_, typename Policy_ = void>
102{
104 static constexpr bool kFastFDiv = Problem::kFastFDiv;
105 static constexpr bool kWelford = Problem::kWelford;
106
107 template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
108 CK_TILE_DEVICE void
109 operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
110 {
111 using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
112 using DstrEncode = typename Dstr::DstrEncode;
113 using DstrEncodeDetail = typename DstrEncode::detail;
114
115 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
116 "wrong!");
117
118 constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
119 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
120
121 constexpr index_t idim_p_lane = NDimP - 1;
122
123 // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
124 // const auto rs_idx =
125 // mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
126
127 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
128 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
129
130 const int original_count = count;
131
132 // loop over thread data
134 auto v_local_mean = mean_tensor.get_thread_buffer()[i];
135 auto v_local_var = var_tensor.get_thread_buffer()[i];
136 auto v_local_count = original_count;
137
138 // cross-lane reduce for replication
139 // only reduce on R dimension correspond to lane
140 // (lane id maps to this R dimension)
141 static_for<0, NDimR, 1>{}([&](auto idim_r) {
142 // FIXME: nasty to use does_p_own_r_
143 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
144 {
145 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
146
147 constexpr index_t lid_over_rid_derivative =
148 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
149
150 static_assert(is_power_of_two_integer(r_length),
151 "wrong! only support power of 2 reduction");
152
153 constexpr index_t nstage = integer_log2_floor(r_length);
154
155 // reduction sweep forward
156 static_for<0, nstage, 1>{}([&](auto istage) {
157 // xor
158 index_t src_lane =
159 (__lane_id()) ^
160 (number<lid_over_rid_derivative << istage.value>{}.value);
161
162 // pull data from remote lane
163 const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
164 const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
165 if(kWelford)
166 {
167 const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
168
169 // norm_reduce merge
170 welford_merge(v_local_mean,
171 v_local_var,
172 v_local_count,
173 v_remote_mean,
174 v_remote_var,
175 v_remote_count,
177 }
178 else
179 {
180 v_local_mean += v_remote_mean;
181 v_local_var += v_remote_var;
182 }
183 });
184 }
185 });
186
187 mean_tensor.get_thread_buffer()(i) = v_local_mean;
188 var_tensor.get_thread_buffer()(i) = v_local_var;
189 if(kWelford)
190 {
191 count = v_local_count;
192 }
193 });
194 }
195};
196
197template <typename Problem_, typename Policy_ = void>
199{
201 using BlockShape = typename Problem::BlockShape;
202 static constexpr bool kFastFDiv = Problem::kFastFDiv;
203 static constexpr bool kWelford = Problem::kWelford;
204 using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
205
206 template <typename MeanDistributedTensor_>
208 {
209 constexpr index_t num_reduce_warps = [&]() {
210 using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
211 using DstrEncode = typename Dstr::DstrEncode;
212 using DstrEncodeDetail = typename DstrEncode::detail;
213
214 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
215
216 constexpr index_t idim_p_warp = 0;
217
218 index_t len_ = 1;
219 static_for<0, NDimR, 1>{}([&](auto idim_r) {
220 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
221 {
222 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
223 len_ *= r_length;
224 }
225 });
226 return len_;
227 }();
228 return num_reduce_warps;
229 }
230
231 // return in byte
232 template <typename MeanDistributedTensor_>
234 {
235 // constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
236
237 // data need to exchange is very small, we just pack mean+var+count -> 4dword
238 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
239
240 // we need to store all data from every wave into smem
241 // e.g. 2x2 reduce along N
242 // -------------> reduce N
243 // | w0 | w1 | ___> | w01 |
244 // | w2 | w3 | | w23 |
245 //
246 // -> store data from every wave into LDS
247 //
248 //
249 // -------------> reduce N
250 // | w0 | w1 | w2 | w3 | -----> | w0123 |
251 //
252 // -> also store data from every wave into LDS
253 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
254 return num_warps * 4 * thread_buf_size * sizeof(float);
255 }
256
257 template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
258 CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
259 VarDistributedTensor_& var_tensor,
260 int& count,
261 void* smem)
262 {
263 using DataType = typename MeanDistributedTensor_::DataType;
264 using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
265 // using DstrEncode = typename Dstr::DstrEncode;
266 // using DstrEncodeDetail = typename DstrEncode::detail;
267
268 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
269 "wrong!");
270
271 constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
272 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
273
274 // Note: we always pack everything into fp32x4
275 smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
276 const index_t lane_id = get_lane_id();
277 const index_t warp_id = get_warp_id();
278 constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
279 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
280 const index_t smem_offset = warp_id;
281
282 // skip if nonthing to do
283 if constexpr(num_reduce_warps == 1)
284 return;
285
286 // store into smem only for lane-0 within one warp
287 if(lane_id == 0)
288 {
290 smem_dtype local_scratch_;
291 local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
292 local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
293 if(kWelford)
294 {
295 local_scratch_[2] = bit_cast<float>(count);
296 }
297 smem_ptr[smem_offset + i * num_warps] = local_scratch_;
298 });
299 }
301
302 // load from smem. here we let everythread to do compute :)
303 index_t local_warp_id = warp_id / num_reduce_warps;
304 index_t local_smem_os = local_warp_id * num_reduce_warps;
305 smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
306 static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
307 static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
308 all_scratch[i_0 * num_reduce_warps + i_1] =
309 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
310 });
311 });
312 block_sync_lds(); // TODO: we don't need sync here
313
314 // const int original_count = count;
315
316 static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
317 // TODO: use descriptor for this
318 auto v_local = all_scratch[i_0 * num_reduce_warps];
319 auto v_local_mean = bit_cast<DataType>(v_local[0]);
320 auto v_local_var = bit_cast<DataType>(v_local[1]);
321 int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
322
323 // further reduce mean/var
324 static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
325 constexpr auto i_1 = number<i_1_n1 + 1>{};
326 const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
327 const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
328 const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
329 if(kWelford)
330 {
331 const auto v_remote_count = bit_cast<int>(v_remote[2]);
332
333 welford_merge(v_local_mean,
334 v_local_var,
335 v_local_count,
336 v_remote_mean,
337 v_remote_var,
338 v_remote_count,
340 }
341 else
342 {
343 v_local_mean += v_remote_mean;
344 v_local_var += v_remote_var;
345 }
346 });
347
348 mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
349 var_tensor.get_thread_buffer()(i_0) = v_local_var;
350 if(kWelford)
351 count = v_local_count;
352 });
353 }
354};
355
356// compute the max count for a last dim reduce
357// everything may have vector/repeat, so the max count could be uneven
358// TODO: specify which dim to compute and proper set the problem
359// TODO: BlockShape we reuse layernorm_fwd_shape :)
360template <typename BlockShape>
362{
363#if 0
364 using S = BlockShape;
365 index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
366 constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
367 index_t iNLane = get_thread_id() % NThread;
368 index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
369 index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
370 index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
371 index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
372 return iN0 * S::Vector_N + iN3;
373#endif
374 using S_ = BlockShape;
375 constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
376
377 // TODO: we always check vector size, need be evenly devidable by vector-n
378 const index_t element_per_row = row_size / S_::Vector_N;
379 index_t lane_id_n = get_thread_id() % ThreadsPerBlock_N;
380
381 index_t cnt = 0;
382 // TODO: Repeat_N can not be too long, otherwise this is not good
384 index_t _a = lane_id_n < element_per_row ? 1 : 0;
385 cnt += _a;
386 lane_id_n += ThreadsPerBlock_N;
387 });
388 return cnt * S_::Vector_N;
389}
390
391// Note: this function must be called after all the computation
392template <typename VarDistributedTensor_, bool FastFdiv_ = false>
393CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
394 int count,
396{
397 using DataType = typename VarDistributedTensor_::DataType;
399 [&count](auto& x) {
400 if(FastFdiv_ && std::is_same_v<DataType, float>)
401 {
402 x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
403 }
404 else
405 {
406 x = x / type_convert<DataType>(count);
407 }
408 },
409 var_tensor);
410}
411} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
Definition block_norm_reduce.hpp:361
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition block_norm_reduce.hpp:393
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition utility.hpp:78
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
CK_TILE_DEVICE index_t get_thread_id()
Definition arch.hpp:117
CK_TILE_DEVICE void welford_update(T &mean, T &var, T x, int count, bool_constant< kFastFDiv >={})
Definition thread_welford.hpp:11
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_norm_reduce.hpp:199
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:202
static constexpr bool kWelford
Definition block_norm_reduce.hpp:203
typename Problem::BlockShape BlockShape
Definition block_norm_reduce.hpp:201
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:200
static CK_TILE_DEVICE constexpr index_t GetReduceWarps()
Definition block_norm_reduce.hpp:207
std::conditional_t< kWelford, fp32x4_t, fp32x2_t > smem_dtype
Definition block_norm_reduce.hpp:204
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count, void *smem)
Definition block_norm_reduce.hpp:258
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_norm_reduce.hpp:233
typename Problem::ComputeDataType ComputeDataType
Definition block_norm_reduce.hpp:16
static CK_TILE_DEVICE auto MakeMeanVarBlockTile()
Definition block_norm_reduce.hpp:68
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:14
static constexpr bool kWelford
Definition block_norm_reduce.hpp:18
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &cur_count_, const int &max_count_)
Definition block_norm_reduce.hpp:29
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, int &cur_count_, const int &max_count_)
Definition block_norm_reduce.hpp:87
CK_TILE_DEVICE constexpr BlockNormReduce()
Definition block_norm_reduce.hpp:20
typename Problem::XDataType XDataType
Definition block_norm_reduce.hpp:15
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:17
Definition block_norm_reduce.hpp:102
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count)
Definition block_norm_reduce.hpp:109
static constexpr bool kWelford
Definition block_norm_reduce.hpp:105
remove_cvref_t< Problem_ > Problem
Definition block_norm_reduce.hpp:103
static constexpr bool kFastFDiv
Definition block_norm_reduce.hpp:104
Definition tile/core/numeric/integral_constant.hpp:13
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43