tile_window_linear.hpp Source File

tile_window_linear.hpp Source File#

Composable Kernel: tile_window_linear.hpp Source File
tile_window_linear.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
19
20namespace ck_tile {
21
// Dispatch helper expanded inside the load/store members of tile_window_linear.
// The expansion site must have these names in scope:
//   - `i_access`  : compile-time access index (a negative value means "all accesses"),
//   - `NumAccess` : total number of accesses of the window,
//   - `issue`     : a generic callable invoked with a number<> access index.
// For i_access >= 0 only that single access is issued, and it must be < NumAccess.
// Otherwise every access 0 .. NumAccess-1 is issued by static_for in increasing order.
#define WINDOW_DISPATCH_ISSUE()                                                    \
    if constexpr(i_access >= 0)                                                    \
    {                                                                              \
        static_assert(i_access < NumAccess);                                       \
        issue(number<i_access>{});                                                 \
    }                                                                              \
    else                                                                           \
    {                                                                              \
        static_for<0, NumAccess, 1>{}([&](auto access_id) { issue(access_id); }); \
    }
32
33//
34// This version of tile window pre-caches offsets/flags as needed.
35//
36// LinearBottomDims_, e.g. seq<0, 1> for a 2D tensor: the last dim is the linear dim,
37// so the last dim can use an immediate offset for indexing, which saves registers.
38// TODO: when using this struct, prefer load_raw()/store_raw(), which can control
39// the immediate offset on the fly.
40// The space-filling curve is non-snaked here!
41// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class
42// with the ultimate parent class being tile_window_base.
43template <typename BottomTensorView_,
44 typename WindowLengths_,
45 typename StaticTileDistribution_,
46 typename LinearBottomDims_>
48 : public tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
49 WindowLengths_,
50 StaticTileDistribution_,
51 LinearBottomDims_>,
52 BottomTensorView_,
53 WindowLengths_,
54 StaticTileDistribution_>
55{
57 WindowLengths_,
58 StaticTileDistribution_,
59 LinearBottomDims_>,
60 BottomTensorView_,
61 WindowLengths_,
62 StaticTileDistribution_>;
63
65
66 static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());
67
68 static constexpr auto I0 = number<0>{};
69 static constexpr auto I1 = number<1>{};
70
71 struct traits
72 {
73 private:
74 static constexpr auto get_num_non_linear_access()
75 {
76 constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
77 using ys_to_rhs_major =
78 typename decltype(typename Base::TileDstr{}
79 .get_static_tile_distribution_encoding())::Ys2RHsMajor;
80
81 constexpr auto non_linear = [&]() {
82 index_t cnt = 1;
83 static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) {
84 constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
85 constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
86 if constexpr(LinearBottomDims{}[target_h_dim] == 0)
87 {
88 cnt *= sfc_access_lens[i_dim_y];
89 }
90 });
91 return cnt;
92 }();
93
94 return non_linear;
95 }
96
97 // example:
98 // non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses, 2 registers
99 // used in total
100 // -> histogram : sequence<4, 4>
101 // -> prefixsum : sequence<0, 4, 8>
102 // non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses, 8 registers
103 // used in total, will pre-cache 8
104 // -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
105 // -> prefixsum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
106 // non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses, 4 registers
107 // used in total, will pre-cache 4
108 // -> histogram : sequence<2, 2, 2, 2>
109 // -> prefixsum : sequence<0, 2, 4, 6, 8>
110 static constexpr auto get_non_linear_access_map()
111 {
112 constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
113 using ys_to_rhs_major =
114 typename decltype(typename Base::TileDstr{}
115 .get_static_tile_distribution_encoding())::Ys2RHsMajor;
116 constexpr auto non_linear_map = [&]() {
118 index_t cumulative_len_ = 1;
119 index_t cumulative_non_linear_len_ = 1;
120 static_for<0, Base::NDimY, 1>{}([&](auto i_y) {
121 constexpr auto i_dim_y = number<Base::NDimY - i_y - 1>{}; // from right to left
122 constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
123 constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
124 constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
125
127 constexpr auto current_len_ = sfc_access_lens[i_dim_y];
128
129 // copy cumulative length as current pattern
130 for(auto i_ = 0; i_ < cumulative_len_; i_++)
131 {
132 current_m_(i_) = m_[i_];
133 }
134 for(auto j_ = 0; j_ < current_len_; j_++)
135 {
136 auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
137 for(auto i_ = 0; i_ < cumulative_len_; i_++)
138 {
139 m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
140 }
141 }
142 cumulative_len_ *= current_len_;
143 if(!is_linear_dim)
144 cumulative_non_linear_len_ *= current_len_;
145 });
146 return m_;
147 }();
148
149 return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
150 }
151
152 static constexpr auto get_non_linear_access_histogram()
153 {
154 constexpr auto m_ = get_non_linear_access_map();
155
156 constexpr auto r_ =
157 typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
158
159 constexpr auto h_ = histogram_sorted_sequence(m_, r_);
160
161 return h_;
162 }
163
164 static constexpr auto get_non_linear_access_histogram_prefix_sum()
165 {
166 constexpr auto h_ = get_non_linear_access_histogram();
167 constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
168 return h_prefix_sum_;
169 }
170
171 public:
172 static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
173 using AccessMap_NonLinear = decltype(get_non_linear_access_map()); // sequence
174 using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
175 using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
176 };
177
178 static constexpr index_t NumAccess = Base::Traits::NumAccess;
183
184 CK_TILE_DEVICE constexpr tile_window_linear() = default;
185
187 const typename Base::BottomTensorView& bottom_tensor_view,
188 const typename Base::WindowLengths& window_lengths,
189 const typename Base::BottomTensorIndex& window_origin,
190 const typename Base::TileDstr& tile_distribution)
192 {
193 this->bottom_tensor_view_ = bottom_tensor_view;
194 this->window_lengths_ = window_lengths;
195 this->window_origin_ = window_origin;
197 auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
201 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
202
203 typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
204 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
205
206 auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
207 this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
208
209 // future load/store() calls (might allocate more registers)
210 using SFC_Ys = typename Base::Traits::SFC_Ys;
211
212 static_for<0, NumAccess, 1>{}([&](auto i_access) {
213 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
214 constexpr auto need_save_non_linear_coord =
215 bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
216
217 if constexpr(need_save_non_linear_coord)
218 {
219 cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
220 cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
221 }
222
223 // TODO: need pad_tensor_view to check which dims need to use the flag check.
224 // The cached flag is independent of the non-linear coord,
225 // but needs to be updated in move_tile, with the proper dims.
227 this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
228
229 if constexpr(i_access != (NumAccess - 1))
230 {
231 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
232 constexpr auto idx_diff_ps_ys = container_concat(
233 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
234 idx_diff_ys);
235
237 window_adaptor_thread_coord_tmp,
238 bottom_tensor_thread_coord_tmp,
239 idx_diff_ps_ys);
240 }
241 });
242 }
243
244 template <index_t i_access>
246 {
247 using SFC_Ys = typename Base::Traits::SFC_Ys;
248 constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
249 using ys_to_rhs_major =
250 typename decltype(typename Base::TileDstr{}
251 .get_static_tile_distribution_encoding())::Ys2RHsMajor;
252
253 constexpr auto modified_idx_ys = generate_tuple(
254 [&](auto i_dim_y) {
255 constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
256 constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
257 if constexpr(LinearBottomDims{}[target_h_dim] == 0)
258 {
259 return number<0>{};
260 }
261 else
262 {
264 }
265 },
267
268 constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
269 constexpr auto idx_ =
270 container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
271
272 return adaptor_.calculate_bottom_index(idx_);
273 }
274
275 template <index_t i_access>
277 {
278 constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
279 constexpr auto is_pure_linear_tensor =
281 if constexpr(is_pure_linear_tensor)
282 {
283 // this case is usually an LDS window; everything is known at compile time.
284 // we directly use the BottomTensorView transform to compute the offset, in case of padding
285 auto bottom_tensor_coord = make_tensor_coordinate(
286 typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord);
287 return bottom_tensor_coord.get_offset();
288 }
289 else
290 {
291 // this case is usually a global window, where the last dim can be linear.
292 // we hack here by using the original TileDstr to compute the linear offset
293 // ... hoping that there is no extra padding between the other dims, which makes sense
294 // since that would introduce a runtime length (so a linear offset couldn't be used)
295 constexpr index_t linear_offset = [&]() {
296 constexpr auto x_idx_ = linear_coord;
297 constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths();
298 static_assert(x_idx_.size() == x_len_.size());
299 constexpr index_t x_dims_ = x_idx_.size();
300 index_t cu_stride_ = 1;
301 index_t cu_offset_ = 0;
302 static_for<0, x_dims_, 1>{}([&](auto i_) {
303 auto r_i_ = number<x_dims_ - i_ - 1>{};
304 cu_offset_ += x_idx_[r_i_] * cu_stride_;
305 cu_stride_ *= x_len_[r_i_];
306 });
307 return cu_offset_;
308 }();
309 return linear_offset;
310 }
311 }
312
313 template <index_t i_access = -1, bool oob_conditional_check = true>
315 {
316 using vector_t = typename Base::Traits::vector_t;
317 using SFC_Ys = typename Base::Traits::SFC_Ys;
318
319 constexpr auto tile_dstr = typename Base::TileDstr{};
320
322
323 auto issue = [&](auto i_access_) {
324 constexpr auto IAccess = number<i_access_>{};
325
326 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
327 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
328 auto bottom_tensor_flag = cached_flags_[IAccess];
329
330 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
331
332 // read from bottom tensor
333 const vector_t vec_value =
334 this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
335 bottom_tensor_thread_coord,
336 linear_offset,
337 bottom_tensor_flag,
339
340 // data index [y0, y1, ...]
341 constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
342 // write into distributed tensor
343 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
344 constexpr auto idx_ys = generate_tuple(
345 [&](auto jj) {
346 return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
347 : idx_diff_ys[jj];
348 },
350
351 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
352 Base::Traits::PackedSize;
353
354 dst_tensor.get_thread_buffer().template at<d>() =
355 vec_value
356 .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
357 });
358 };
359
361
362 return dst_tensor;
363 }
364
365 template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
366 CK_TILE_DEVICE auto load(DstTile& dst_tensor,
367 number<i_access> = {},
369 {
370 using vector_t = typename Base::Traits::vector_t;
371 using SFC_Ys = typename Base::Traits::SFC_Ys;
372
373 constexpr auto tile_dstr = typename Base::TileDstr{};
374
375 // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
376
377 auto issue = [&](auto i_access_) {
378 constexpr auto IAccess = number<i_access_>{};
379
380 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
381 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
382 auto bottom_tensor_flag = cached_flags_[IAccess];
383
384 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
385
386 // read from bottom tensor
387 const vector_t vec_value =
388 this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
389 bottom_tensor_thread_coord,
390 linear_offset,
391 bottom_tensor_flag,
393 // data index [y0, y1, ...]
394 constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
395 // write into distributed tensor
396 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
397 constexpr auto idx_ys = generate_tuple(
398 [&](auto jj) {
399 return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
400 : idx_diff_ys[jj];
401 },
403
404 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
405 Base::Traits::PackedSize;
406
407 dst_tensor.get_thread_buffer().template at<d>() =
408 vec_value
409 .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
410 });
411 };
412
414
415 return dst_tensor;
416 }
417
418 template <typename DstTile,
419 index_t i_access = -1,
420 bool oob_conditional_check = true,
421 bool pre_nop = false>
422 CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
423 number<i_access> = {}, // negative means loop over all num_access
425 bool_constant<pre_nop> = {}) const
426 {
427 using vector_t = typename Base::Traits::vector_t;
428 using SFC_Ys = typename Base::Traits::SFC_Ys;
429 static constexpr index_t YElementSize =
430 typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
431 static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
432 0);
433 using vectorized_tbuf =
434 array<vector_t,
435 YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;
436
437 constexpr auto tile_dstr = typename Base::TileDstr{};
438
439 auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
440
441 auto issue = [&](auto i_access_) {
442 constexpr auto IAccess = number<i_access_>{};
443 constexpr auto pre_nop_ = [&]() {
444 if constexpr(pre_nop && i_access_ == 0 &&
445 Base::BottomTensorView::buffer_view::get_address_space() ==
447 return bool_constant<true>{};
448 else
449 return bool_constant<false>{};
450 }();
451
452 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
453 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
454 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
455 auto bottom_tensor_flag = cached_flags_[IAccess];
456
457 // data index [y0, y1, ...]
458 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
459 constexpr index_t d =
460 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
461 Base::Traits::PackedSize;
462 static_assert(d % Base::Traits::ScalarPerVector == 0);
463
464 this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
465 dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
466 bottom_tensor_thread_coord,
467 linear_offset /**/,
468 bottom_tensor_flag,
470 pre_nop_);
471#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
472 CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
473 asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
474#endif
475 };
476
478 }
479
480 // TODO: currently async load only implemented in inline asm
481 template <typename LdsTileWindow_,
482 index_t i_access = -1,
483 bool oob_conditional_check = true,
484 bool pre_nop = false>
485 CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
486 number<i_access> = {},
488 bool_constant<pre_nop> = {}) const
489 {
490 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
491 using LdsDataType = typename LdsTileWindow::DataType;
492
493 // currently we only support the case where all dims are non-linear;
494 // having a linear dim (e.g. a fast-changing one) is actually not performant here
495 static_assert(NumAccess_NonLinear == NumAccess);
496 static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
498
499 // issues * warps * lanes
500 static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
501
502 const index_t size_per_buf =
503 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
505 sizeof(LdsDataType);
506
507 const index_t size_per_wave =
508 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
510 sizeof(LdsDataType) -
511 size_per_buf;
512
513 const index_t size_per_issue =
514 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
516 sizeof(LdsDataType) -
517 size_per_buf;
518
519 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
520 m0_set_with_memory(m0_init_value); // This should be wave independent
521
522 using vector_t = typename Base::Traits::vector_t;
523
524 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
525
526 // loop over thread tensor space [y0, y1, ...]
527 auto issue = [&](auto i_access_) {
528 constexpr auto IAccess = number<i_access_>{};
529 constexpr auto pre_nop_ = [&]() {
530 if constexpr(pre_nop && i_access_ == 0)
531 return bool_constant<true>{};
532 else
533 return bool_constant<false>{};
534 }();
535
536 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
537 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
538 auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
539
540 // read from bottom tensor
541 this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
542 smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
543
544 // move thread coordinate
545 if constexpr(i_access_ != (NumAccess - 1))
546 {
547 m0_inc_with_memory(size_per_issue);
548 }
549 };
550
552 }
553
554 template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
555 CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
556 number<i_access> = {},
558 {
559 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
560 using LdsDataType = typename LdsTileWindow::DataType;
561 using vector_t = typename traits::vector_t;
562
563 static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
564 static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
566 "Requires global memory");
567
568 // Precompute invariant values outside the lambda
569 const auto window_origin = lds_tile.get_window_origin();
570 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
571 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
572 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
573
574 auto issue = [&](auto i_access_) {
575 constexpr auto IAccess = number<i_access_>{};
576 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
577
578 // Use precomputed values
579 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
580 auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
581 auto bottom_tensor_flag = cached_flags_[IAccess];
582
583 auto lds_bottom_tensor_thread_idx =
584 window_origin + window_adaptor_coord.get_bottom_index();
585 const auto lds_coord =
586 make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
587
588 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
589
590 // Read from bottom tensor
591 this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
592 smem,
593 bottom_tensor_thread_coord,
594 0,
595 bottom_tensor_flag,
597 };
598
600 }
601
602 template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
604 {
605 constexpr auto tile_dstr = typename Base::TileDstr{};
607 this->template load_transpose_linear<Policy>(
609 return dst_tensor;
610 }
611
612 template <typename Policy,
613 typename DistributedTensor,
614 index_t i_access = -1,
615 bool oob_conditional_check = true>
616 CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
617 number<i_access> = {},
619 {
620 using vector_t = typename traits::vector_t;
621 using SFC_Ys = typename traits::SFC_Ys;
622
623 constexpr auto tile_dstr = typename Base::TileDstr{};
624
625 constexpr auto group_func = Policy::group_func;
626
627 auto issue = [&](auto i_access_) {
628 constexpr auto IAccess = number<i_access_>{};
629 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
630 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
631 auto bottom_tensor_flag = cached_flags_[IAccess];
632
633 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
634
635 // read from bottom tensor
636 const vector_t vec_value =
637 this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
638 bottom_tensor_thread_coord, 0);
639 // write into distributed tensor
640 static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
641 constexpr auto idx_ys = generate_tuple(
642 [&](auto jj) {
643 return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
644 },
646
647 constexpr index_t linear_distributed_index =
648 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
649 dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
650 vec_value.template get_as<typename Base::DataType>()[j];
651 });
652 };
654 }
655
656 template <index_t i_access = -1, bool oob_conditional_check = true>
658 typename Base::TileDstr>& dstr_tensor,
659 number<i_access> = {},
661 {
662
663 using vector_t = typename Base::Traits::vector_t;
664 using SFC_Ys = typename Base::Traits::SFC_Ys;
665
666 constexpr auto tile_dstr = typename Base::TileDstr{};
667
668 // loop over thread tensor space [y0, y1, ...]
669 auto issue = [&](auto i_access_) {
670 constexpr auto IAccess = number<i_access_>{};
671 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
672 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
673 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
674 auto bottom_tensor_flag = cached_flags_[IAccess];
675 // data index [y0, y1, ...]
676 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
677
678 // read from distributed tensor
679 vector_t vec_value;
680
681 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
682 constexpr auto idx_ys = generate_tuple(
683 [&](auto jj) {
684 return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
685 : idx_ys_start[jj];
686 },
688
689 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
690 Base::Traits::PackedSize;
691
692 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
693 dstr_tensor.get_thread_buffer().template at<d>();
694 });
695
696 // write into bottom tensor
697 this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
698 bottom_tensor_thread_coord,
699 linear_offset,
700 bottom_tensor_flag,
701 vec_value,
703 };
704
706 }
707
708 template <index_t i_access = -1>
709 CK_TILE_DEVICE void
711 dstr_tensor,
712 number<i_access> = {}) const
713 {
714 using vector_t = typename Base::Traits::vector_t;
715 using SFC_Ys = typename Base::Traits::SFC_Ys;
716
717 constexpr auto tile_dstr = typename Base::TileDstr{};
718 static constexpr bool oob_conditional_check = true;
719
720 // loop over thread tensor space [y0, y1, ...]
721 auto issue = [&](auto i_access_) {
722 constexpr auto IAccess = number<i_access_>{};
723 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
724 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
725 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
726 auto bottom_tensor_flag = cached_flags_[IAccess];
727
728 // data index [y0, y1, ...]
729 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
730
731 // read from distributed tensor
732 vector_t vec_value;
733 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
734 constexpr auto idx_ys = generate_tuple(
735 [&](auto jj) {
736 return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
737 : idx_ys_start[jj];
738 },
740 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
741 Base::Traits::PackedSize;
742 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
743 dstr_tensor.get_thread_buffer().template at<d>();
744 });
745
746 // write into bottom tensor
748 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
749 bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
750 };
751
753 }
754
755 template <index_t i_access = -1, bool oob_conditional_check = true>
756 CK_TILE_DEVICE void
758 dstr_tensor,
759 number<i_access> = {},
761 {
762
763 using vector_t = typename Base::Traits::vector_t;
764 using SFC_Ys = typename Base::Traits::SFC_Ys;
765
766 constexpr auto tile_dstr = typename Base::TileDstr{};
767
768 // loop over thread tensor space [y0, y1, ...]
769 auto issue = [&](auto i_access_) {
770 constexpr auto IAccess = number<i_access_>{};
771 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
772 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
773 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
774 auto bottom_tensor_flag = cached_flags_[IAccess];
775
776 // data index [y0, y1, ...]
777 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
778
779 // read from distributed tensor
780 vector_t vec_value;
781
782 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
783 constexpr auto idx_ys = generate_tuple(
784 [&](auto jj) {
785 return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
786 : idx_ys_start[jj];
787 },
789
790 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
791 Base::Traits::PackedSize;
792
793 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
794 dstr_tensor.get_thread_buffer().template at<d>();
795 });
796
797 // write into bottom tensor
798 this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
799 bottom_tensor_thread_coord,
800 linear_offset,
801 bottom_tensor_flag,
802 vec_value,
804 };
805
807 }
808
809 template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
810 CK_TILE_DEVICE void
812 dstr_tensor,
813 number<i_access> = {},
815 bool_constant<pre_nop> = {}) const
816 {
817
818 using vector_t = typename Base::Traits::vector_t;
819 using SFC_Ys = typename Base::Traits::SFC_Ys;
820
821 constexpr auto tile_dstr = typename Base::TileDstr{};
822
823 // loop over thread tensor space [y0, y1, ...]
824 auto issue = [&](auto i_access_) {
825 constexpr auto IAccess = number<i_access_>{};
826 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
827 auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
828 constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
829 auto bottom_tensor_flag = cached_flags_[IAccess];
830
831 // data index [y0, y1, ...]
832 constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
833
834 // read from distributed tensor
835 vector_t vec_value;
836
837 static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
838 constexpr auto idx_ys = generate_tuple(
839 [&](auto jj) {
840 return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
841 : idx_ys_start[jj];
842 },
844
845 constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
846 Base::Traits::PackedSize;
847
848 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
849 dstr_tensor.get_thread_buffer().template at<d>();
850 });
851
852 // write into bottom tensor
853 this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
854 bottom_tensor_thread_coord,
855 linear_offset,
856 bottom_tensor_flag,
857 vec_value,
860 };
861
863 }
864 // *_extended() functions act like a virtual function with a default implementation existing
865 // in the base class
867 {
868 static_for<0, NumAccess, 1>{}([&](auto i_access) {
869 constexpr auto IAccess = number<i_access>{};
870 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
871 constexpr auto need_update_non_linear_coord =
872 bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
873
874 if constexpr(need_update_non_linear_coord)
875 {
876 move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
877 cached_coords_(non_linear_id),
878 step);
879 }
880
881 // move the current coord with linear_coords
882 auto tmp_coords = cached_coords_[non_linear_id];
883 constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
885 this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
886
888 this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
889 });
890 }
891
893 {
894 auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
895 typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
898 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
899
900 typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
901 this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
902
903 auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
904 this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
905
906 // future load/store() calls (might allocate more registers)
907 using SFC_Ys = typename Base::Traits::SFC_Ys;
908
909 static_for<0, NumAccess, 1>{}([&](auto i_access) {
910 constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
911 constexpr auto need_save_non_linear_coord =
912 bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
913
914 if constexpr(need_save_non_linear_coord)
915 {
916 cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
917 cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
918 }
919
920 if constexpr(i_access != (NumAccess - 1))
921 {
922 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
923 constexpr auto idx_diff_ps_ys = container_concat(
924 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
925 idx_diff_ys);
926
928 window_adaptor_thread_coord_tmp,
929 bottom_tensor_thread_coord_tmp,
930 idx_diff_ps_ys);
931 }
932 });
933 }
934
935 // this contains:
940};
941
942#undef WINDOW_DISPATCH_ISSUE
943
944namespace impl {
945template <address_space_enum, index_t len_>
950
951template <index_t len_>
953{
954 // global default to seq<0,0,....1>
955 using type = typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type,
957};
958
959template <index_t len_>
961{
962 // lds default to seq<1,1.....1>
964};
965} // namespace impl
966
967template <typename TensorView_>
969 typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
970 TensorView_::get_num_of_dimension()>::type;
971
972// This API creates a tile_window_linear.
973// This structure can use immediate offsets, which saves registers.
974// LinearBottomDims_ must be passed properly to control which dims are linear,
975// so that a constexpr offset (linear_offset) is generated for those dims
976// (and finally passed as the immediate offset of the buffer/LDS instruction).
977//
978// Note: there is no internal check of which dims are OK to use a linear offset;
979// users must ensure this themselves.
980//
981// e.g.
982// 2D global matrix: set LinearBottomDims_=seq<0, 1>; the last dim will generate an
983// immediate offset if each thread issues multiple accesses along the last dim.
984//
985// 2D LDS buffer: set LinearBottomDims_=seq<1, 1>; then only one VGPR is used as the offset
986// and everything else uses immediate offsets.
987//
988template <typename TensorView_,
989 typename WindowLengths_,
990 typename StaticTileDistribution_,
991 typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
992CK_TILE_DEVICE constexpr auto
994 const WindowLengths_& window_lengths,
995 const multi_index<TensorView_::get_num_of_dimension()>& origin,
996 const StaticTileDistribution_& tile_distribution,
997 LinearBottomDims_ = {})
998{
999 static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
1000 return tile_window_linear<remove_cvref_t<TensorView_>,
1004 tensor_view, window_lengths, origin, tile_distribution};
1005}
1006
1007template <
1008 typename TileWindow_,
1009 typename StaticTileDistribution_,
1011CK_TILE_DEVICE constexpr auto
1012make_tile_window_linear(const TileWindow_& tile_window,
1013 const StaticTileDistribution_& tile_distribution,
1014 LinearBottomDims_ = {})
1015{
1016 return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
1017 tile_window.get_window_lengths(),
1018 tile_window.get_window_origin(),
1019 tile_distribution,
1020 LinearBottomDims_{});
1021}
1022
1023// this version must not be called under a constexpr context
1024template <typename TensorView_,
1025 typename WindowLengths_,
1026 typename StaticTileDistribution_,
1027 typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
1028CK_TILE_DEVICE auto
1030 const WindowLengths_& window_lengths,
1031 const multi_index<TensorView_::get_num_of_dimension()>& origin,
1032 const StaticTileDistribution_& tile_distribution,
1033 LinearBottomDims_ = {})
1034{
1035 static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
1036 auto w = tile_window_linear<remove_cvref_t<TensorView_>,
1040 tensor_view, window_lengths, origin, tile_distribution};
1041 w.init_raw();
1042 return w;
1043}
1044
1045template <
1046 typename TileWindow_,
1047 typename StaticTileDistribution_,
1049CK_TILE_DEVICE constexpr auto
1050make_tile_window_linear_raw(const TileWindow_& tile_window,
1051 const StaticTileDistribution_& tile_distribution,
1052 LinearBottomDims_ = {})
1053{
1054 return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
1055 tile_window.get_window_lengths(),
1056 tile_window.get_window_origin(),
1057 tile_distribution,
1058 LinearBottomDims_{});
1059}
1060
1061template <typename TensorView_,
1062 typename WindowLengths_,
1063 typename StaticTileDistribution_,
1064 typename LinearBottomDims_>
1067 window,
1068 const typename tile_window_linear<TensorView_,
1069 WindowLengths_,
1070 StaticTileDistribution_,
1071 LinearBottomDims_>::BottomTensorIndex& step)
1072{
1073 window.move(step);
1074}
1075
/// @brief Trait that detects tile_window_linear instantiations.
///
/// This primary template yields value == false for every type. A partial
/// specialization for tile_window_linear<...> yields true.
/// @tparam T type to inspect.
template <typename T>
struct is_tile_window_linear : public std::false_type {};
1088
1100template <typename BottomTensorView_,
1101 typename WindowLengths_,
1102 typename StaticTileDistribution_,
1103 typename LinearBottomDims_>
1105 WindowLengths_,
1106 StaticTileDistribution_,
1107 LinearBottomDims_>> : std::true_type
1108{
1109};
1110
1118template <typename T>
1120
1121} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
Definition tile/core/arch/amd_buffer_addressing.hpp:110
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
typename impl::default_linear_bottom_dims_impl< TensorView_::buffer_view::get_address_space(), TensorView_::get_num_of_dimension()>::type default_linear_bottom_dims
Definition tile_window_linear.hpp:968
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_coordinate.hpp:79
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:1029
CK_TILE_HOST_DEVICE constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition tensor_coordinate.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
constexpr bool is_tile_window_linear_v
Helper variable template to check if a type is a linear tile window.
Definition tile_window_linear.hpp:1119
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition utility.hpp:19
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition tile/core/container/sequence.hpp:1102
address_space_enum
Definition arch.hpp:46
@ global
Definition arch.hpp:48
@ lds
Definition arch.hpp:49
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition utility.hpp:25
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition tensor_coordinate.hpp:60
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
constexpr auto prefix_sum_sequence(Seq)
Definition tile/core/container/sequence.hpp:908
Definition tile/core/container/sequence.hpp:287
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
typename sequence_merge< typename uniform_sequence_gen< len_ - 1, 0 >::type, sequence< 1 > >::type type
Definition tile_window_linear.hpp:955
typename uniform_sequence_gen< len_, 1 >::type type
Definition tile_window_linear.hpp:963
Definition tile_window_linear.hpp:947
typename uniform_sequence_gen< len_, 0 >::type type
Definition tile_window_linear.hpp:948
Type trait to determine if a type is a linear tile window.
Definition tile_window_linear.hpp:1086
Definition tile/core/numeric/math.hpp:98
Definition tile/core/container/sequence.hpp:236
Definition tile/core/container/sequence.hpp:49
Definition static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE constexpr const auto & get_thread_buffer() const
Definition static_distributed_tensor.hpp:58
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
Definition tile_distribution.hpp:72
CK_TILE_HOST_DEVICE constexpr const auto & get_ps_ys_to_xs_adaptor() const
Definition tile_distribution.hpp:126
BottomTensorView bottom_tensor_view_
Definition tile_window_base.hpp:85
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition tile_window_base.hpp:36
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_window_base.hpp:47
BottomTensorIndex window_origin_
Definition tile_window_base.hpp:79
static CK_TILE_DEVICE constexpr index_t get_num_of_dimension()
Definition tile_window_base.hpp:48
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition tile_window_base.hpp:67
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition tile_window_base.hpp:33
remove_cvref_t< WindowLengths_ > WindowLengths
Definition tile_window_base.hpp:34
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition tile_window_base.hpp:43
WindowLengths window_lengths_
Definition tile_window_base.hpp:81
Definition tile_window_linear.hpp:72
decltype(get_non_linear_access_histogram_prefix_sum()) AccessPrefixSum_NonLinear
Definition tile_window_linear.hpp:175
decltype(get_non_linear_access_map()) AccessMap_NonLinear
Definition tile_window_linear.hpp:173
static constexpr index_t NumAccess_NonLinear
Definition tile_window_linear.hpp:172
decltype(get_non_linear_access_histogram()) AccessHistogram_NonLinear
Definition tile_window_linear.hpp:174
Definition tile_window_linear.hpp:55
static constexpr auto I0
Definition tile_window_linear.hpp:68
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition tile_window_linear.hpp:892
CK_TILE_DEVICE auto load(number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:314
array< typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear > cached_window_adaptor_coords_
Definition tile_window_linear.hpp:938
CK_TILE_DEVICE constexpr tile_window_linear()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:555
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window_linear.hpp:422
static CK_TILE_DEVICE constexpr auto get_bottom_linear_coordinate(number< i_access >)
Definition tile_window_linear.hpp:245
CK_TILE_DEVICE auto load_transpose() const
Definition tile_window_linear.hpp:603
typename traits::AccessHistogram_NonLinear AccessHistogram_NonLinear
Definition tile_window_linear.hpp:181
typename traits::AccessMap_NonLinear AccessMap_NonLinear
Definition tile_window_linear.hpp:180
static constexpr index_t NumAccess
Definition tile_window_linear.hpp:178
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}) const
Definition tile_window_linear.hpp:710
array< bool, Base::Traits::NumAccess > cached_flags_
Definition tile_window_linear.hpp:939
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:757
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:657
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window_linear.hpp:811
tile_window_with_tile_dstr_base< tile_window_linear< BottomTensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_ >, BottomTensorView_, WindowLengths_, StaticTileDistribution_ > Base
Definition tile_window_linear.hpp:56
typename traits::AccessPrefixSum_NonLinear AccessPrefixSum_NonLinear
Definition tile_window_linear.hpp:182
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:616
static constexpr index_t NumAccess_NonLinear
Definition tile_window_linear.hpp:179
CK_TILE_DEVICE auto load(DstTile &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window_linear.hpp:366
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition tile_window_linear.hpp:866
array< typename Base::BottomTensorCoord, traits::NumAccess_NonLinear > cached_coords_
Definition tile_window_linear.hpp:936
static CK_TILE_DEVICE constexpr index_t get_bottom_linear_offset(number< i_access >)
Definition tile_window_linear.hpp:276
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window_linear.hpp:485
CK_TILE_DEVICE constexpr tile_window_linear(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition tile_window_linear.hpp:186
remove_cvref_t< LinearBottomDims_ > LinearBottomDims
Definition tile_window_linear.hpp:64
static constexpr auto I1
Definition tile_window_linear.hpp:69
Definition tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition tile_window_base.hpp:129
Definition tile/core/container/sequence.hpp:314
typename sequence_gen< NSize, F >::type type
Definition tile/core/container/sequence.hpp:320
#define WINDOW_DISPATCH_ISSUE()
Definition tile_window_linear.hpp:22
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10