32struct ComputePtrOffsetOfStridedBatch
35 : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
84template <
typename GridwiseGemm,
87 typename AGridDesc_K0_M0_M1_K1,
88 typename BGridDesc_K0_N0_N1_K1,
89 typename CGridDesc_M0_M10_M11_N0_N10_N11,
90 typename Block2CTileMap,
91 typename ComputePtrOffsetOfBatch,
92 bool HasMainKBlockLoop,
93 bool HasDoubleTailKBlockLoop>
95#if CK_USE_LAUNCH_BOUNDS
98 kernel_grouped_conv_fwd_dl(
99 const ABDataType* __restrict__ p_a_grid,
100 const ABDataType* __restrict__ p_b_grid,
101 CDataType* __restrict__ p_c_grid,
103 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
104 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
105 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
106 const Block2CTileMap block_2_ctile_map,
107 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
109#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
111 const index_t num_blocks_per_batch =
112 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
115 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
116 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
117 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
118 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
119 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
120 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
122 constexpr index_t shared_block_size =
123 GridwiseGemm::GetSharedMemoryNumberOfByte() /
sizeof(ABDataType);
125 __shared__ ABDataType p_shared[shared_block_size];
127 GridwiseGemm::Run(p_a_grid + a_batch_offset,
128 p_b_grid + b_batch_offset,
129 p_c_grid + c_batch_offset,
131 a_grid_desc_k0_m0_m1_k1,
132 b_grid_desc_k0_n0_n1_k1,
133 c_grid_desc_m0_m10_m11_n0_n10_n11,
135 integral_constant<bool, HasMainKBlockLoop>{},
136 integral_constant<bool, HasDoubleTailKBlockLoop>{});
142 ignore = a_grid_desc_k0_m0_m1_k1;
143 ignore = b_grid_desc_k0_n0_n1_k1;
144 ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
145 ignore = compute_ptr_offset_of_batch;
146 ignore = block_2_ctile_map;
148 compute_ptr_offset_of_batch.GetAPtrOffset(0);
149 compute_ptr_offset_of_batch.GetBPtrOffset(0);
150 compute_ptr_offset_of_batch.GetCPtrOffset(0);
177 typename AccDataType,
181 typename AElementwiseOperation,
182 typename BElementwiseOperation,
183 typename CElementwiseOperation,
194 typename M1N1ThreadClusterM1Xs,
195 typename M1N1ThreadClusterN1Xs,
196 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
197 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
198 typename ABlockTransferThreadClusterArrangeOrder,
199 typename ABlockTransferSrcAccessOrder,
200 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
201 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
202 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
203 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
204 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
205 typename BBlockTransferThreadClusterArrangeOrder,
206 typename BBlockTransferSrcAccessOrder,
207 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
208 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
209 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
210 typename CThreadTransferSrcDstAccessOrder,
211 index_t CThreadTransferSrcDstVectorDim,
212 index_t CThreadTransferDstScalarPerVector,
225 AElementwiseOperation,
226 BElementwiseOperation,
227 CElementwiseOperation>
241 template <
typename ALay>
245 const auto in_gemmmraw_gemmkraw_desc =
246 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
248 const auto in_gemmm_gemmk_desc =
251 const auto M = in_gemmm_gemmk_desc.GetLength(
I0);
252 const auto K = in_gemmm_gemmk_desc.GetLength(
I1);
254 const auto AK0 = K / K1;
263 template <
typename BLay>
267 const auto wei_gemmnraw_gemmkraw_desc =
268 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
270 const auto wei_gemmn_gemmk_desc =
271 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
273 const auto N = wei_gemmn_gemmk_desc.GetLength(
I0);
274 const auto K = wei_gemmn_gemmk_desc.GetLength(
I1);
276 const auto BK0 = K / K1;
279 wei_gemmn_gemmk_desc,
285 template <
typename CLay>
288 const auto out_gemmmraw_gemmnraw_desc =
289 conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
291 const auto out_gemmm_gemmn_desc =
292 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
294 return out_gemmm_gemmn_desc;
323 M1N1ThreadClusterM1Xs,
324 M1N1ThreadClusterN1Xs,
325 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
326 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
327 ABlockTransferThreadClusterArrangeOrder,
328 ABlockTransferSrcAccessOrder,
329 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
330 ABlockTransferSrcVectorTensorContiguousDimOrder,
331 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
332 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
333 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
334 BBlockTransferThreadClusterArrangeOrder,
335 BBlockTransferSrcAccessOrder,
336 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
337 BBlockTransferSrcVectorTensorContiguousDimOrder,
338 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
339 CThreadTransferSrcDstAccessOrder,
340 CThreadTransferSrcDstVectorDim,
341 CThreadTransferDstScalarPerVector>;
358 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
359 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
360 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
361 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
362 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
363 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
364 const std::array<index_t, NDimSpatial>& conv_filter_strides,
365 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
366 const std::array<index_t, NDimSpatial>& input_left_pads,
367 const std::array<index_t, NDimSpatial>& input_right_pads,
368 const AElementwiseOperation& a_element_op,
369 const BElementwiseOperation& b_element_op,
370 const CElementwiseOperation& c_element_op)
371 :
p_a_grid_{static_cast<const ADataType*>(p_a)},
372 p_b_grid_{static_cast<const BDataType*>(p_b)},
382 conv_filter_dilations,
396 a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]},
437 std::cout <<
"num_group: " <<
num_group_ << std::endl;
504 throw std::runtime_error(
505 "wrong! DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting");
513 auto launch_kernel = [&](
auto has_main_k_block_loop,
514 auto has_double_tail_k_block_loop) {
515 constexpr bool has_main_loop = has_main_k_block_loop.value;
516 constexpr bool has_double_loop = has_double_tail_k_block_loop;
519 kernel_grouped_conv_fwd_dl<GridwiseGemm,
526 ComputePtrOffsetOfStridedBatch,
548 const bool has_double_tail_k_block_loop =
551 if(has_main_k_block_loop && has_double_tail_k_block_loop)
554 integral_constant<bool, true>{});
556 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
559 integral_constant<bool, false>{});
561 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
564 integral_constant<bool, true>{});
569 integral_constant<bool, false>{});
576 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
592 if constexpr(ConvForwardSpecialization ==
596 for(
index_t i = 0; i < NDimSpatial; ++i)
605 std::cout <<
"Filter1x1Stride1Pad0 check: i = " << i <<
" X = " << X
606 <<
" ConvStride = " << ConvStride <<
" LeftPad = " <<
LeftPad
607 <<
" RightPad = " <<
RightPad << std::endl;
612 else if constexpr(ConvForwardSpecialization ==
616 for(
index_t i = 0; i < NDimSpatial; ++i)
624 std::cout <<
"Filter1x1Stride1Pad0 check: i = " << i <<
" X = " << X
640 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
641 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
645 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
652 if(C % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
671 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
672 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
676 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
683 if(C % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
702 if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
724 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
725 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
726 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
727 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
728 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
729 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
730 const std::array<index_t, NDimSpatial>& conv_filter_strides,
731 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
732 const std::array<index_t, NDimSpatial>& input_left_pads,
733 const std::array<index_t, NDimSpatial>& input_right_pads,
734 const AElementwiseOperation& a_element_op,
735 const BElementwiseOperation& b_element_op,
736 const CElementwiseOperation& c_element_op)
748 conv_filter_dilations,
758 std::unique_ptr<BaseArgument>
762 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
763 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
764 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
765 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
766 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
767 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
768 const std::array<index_t, NDimSpatial>& conv_filter_strides,
769 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
770 const std::array<index_t, NDimSpatial>& input_left_pads,
771 const std::array<index_t, NDimSpatial>& input_right_pads,
772 const AElementwiseOperation& a_element_op,
773 const BElementwiseOperation& b_element_op,
774 const CElementwiseOperation& c_element_op)
override
776 return std::make_unique<Argument>(p_a,
786 conv_filter_dilations,
796 return std::make_unique<Invoker>(
Invoker{});
801 auto str = std::stringstream();
804 str <<
"DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK"
809 << K0PerBlock <<
", "
814 << MXdlPerWave <<
", "
815 << NXdlPerWave <<
", "
816 << ABlockTransferSrcScalarPerVector <<
", "
817 << ABlockTransferDstScalarPerVector_K1 <<
", "
818 << BBlockTransferSrcScalarPerVector <<
", "
819 << BBlockTransferDstScalarPerVector_K1 <<
", "
820 << CShuffleMXdlPerWavePerShuffle <<
", "
821 << CShuffleNXdlPerWavePerShuffle <<
", "
822 << CBlockTransferScalarPerVector_NWaveNPerXdl
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:208
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:153
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateGridSize __host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_v1r3.hpp:146
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:129
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:160
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:241
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_AK0_M_AK1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:168
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_BK0_N_BK1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:188
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:354
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:457
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:461
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:485
void Print() const
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:432
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:447
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:480
CDataType * p_c_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:449
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:462
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:478
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:456
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:466
CElementwiseOperation c_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:474
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:469
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:472
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:477
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:448
std::array< index_t, NDimSpatial+3 > c_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:481
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:473
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:479
index_t num_group_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:452
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:458
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:463
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:454
Argument(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:355
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:486
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:483
std::array< index_t, NDimSpatial+3 > c_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:482
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:484
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:491
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:494
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:492
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:573
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:228
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:265
static constexpr auto I0
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:231
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:794
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:301
static constexpr auto I2
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:233
static constexpr auto I1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:232
remove_cvref_t< decltype(MakeCGridDescriptor_M_N< CLayout >(dummy_conv_to_gemm_transformer))> CGridDesc_M_N
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:303
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK DeviceOp
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:229
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:236
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:307
static auto MakeInvoker()
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:756
static constexpr auto I3
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:234
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:286
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:716
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:799
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:759
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:298
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:243
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})) BGridDesc_K0_N0_N1_K1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:345
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:347
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:349
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:299
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:580
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:238
static auto MakeArgument(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:721
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})) AGridDesc_K0_M0_M1_K1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:343
Definition device_grouped_conv_fwd.hpp:31
Definition matrix_padder.hpp:180