23template <
typename ThreadGroup,
24 typename ElementwiseOperation,
26 typename BlockSliceLengths,
27 typename ThreadClusterLengths,
28 typename ThreadClusterArrangeOrder,
33 typename SrcDimAccessOrder,
34 typename DstDimAccessOrder,
37 typename SrcsScalarPerVector,
38 typename DstsScalarPerVector,
39 typename SrcsScalarStrideInVector,
40 typename DstsScalarStrideInVector,
41 typename ThreadTransferSrcsResetCoordinateAfterRun,
42 typename ThreadTransferDstsResetCoordinateAfterRun,
56 const SrcDescs& src_descs,
58 const DstDescs& dst_descs,
60 const ElementwiseOperation& element_op)
61 : threadwise_transfer_(src_descs,
68 static_assert(
nDim == ThreadClusterLengths::Size() &&
69 nDim == ThreadClusterArrangeOrder::Size() &&
70 nDim == SrcDimAccessOrder::Size() &&
nDim == SrcDimAccessOrder::Size(),
71 "wrong! nDim not consistent");
73 static_for<0, nSrc, 1>{}([&](
auto src_i) {
76 "wrong! nDim not consistent");
79 static_for<0, nDst, 1>{}([&](
auto dst_i) {
82 "wrong! nDim not consistent");
87 "wrong! threads should be mapped to cover entire slicing window");
89 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
90 "wrong! ThreadGroup::GetNumOfThread() too small");
92 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
93 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
95 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
101 [&](
auto i) {
return src_block_slice_origins[i] + thread_data_idx_begin; },
105 [&](
auto i) {
return dst_block_slice_origins[i] + thread_data_idx_begin; },
108 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
109 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
113 template <
typename SrcBuffers, index_t ThreadScratchId = 0>
114 __device__
void RunRead(
const SrcDescs& src_descs,
115 const SrcBuffers& src_bufs,
118 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
119 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
121 threadwise_transfer_.
RunRead(src_descs, src_bufs, thread_scratch_id);
125 template <
typename DstBuffers, index_t ThreadScratchId = 0>
126 __device__
void RunWrite(
const DstDescs& dst_descs,
127 DstBuffers& dst_bufs,
130 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
131 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
133 threadwise_transfer_.
RunWrite(dst_descs, dst_bufs, thread_scratch_id);
137 template <
typename SrcBuffer,
typename DstBuffer, index_t ThreadScratchId>
138 __device__
void Run(
const SrcDescs& src_descs,
139 const SrcBuffer& src_bufs,
140 const DstDescs& dst_descs,
144 RunRead(src_descs, src_bufs, thread_scratch_id);
145 RunWrite(dst_descs, dst_bufs, thread_scratch_id);
150 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
151 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
153 threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
159 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
160 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
162 threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
167 static constexpr auto thread_cluster_desc_ =
170 using ThreadwiseTransfer =
172 ElementwiseOperation,
184 SrcsScalarStrideInVector,
185 DstsScalarStrideInVector,
186 ThreadTransferSrcsResetCoordinateAfterRun,
187 ThreadTransferDstsResetCoordinateAfterRun,
190 ThreadwiseTransfer threadwise_transfer_;
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r2.hpp:53
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r2.hpp:126
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v4r2.hpp:48
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:148
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:157
__device__ void Run(const SrcDescs &src_descs, const SrcBuffer &src_bufs, const DstDescs &dst_descs, DstBuffer &dst_bufs, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:138
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r2.hpp:114
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:55
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r2.hpp:46
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v4r2.hpp:49
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r2.hpp:51
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:99
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:305