31#ifdef CK_EXPERIMENTAL_BUILDER
32#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
66template <
typename GridwiseGemm,
67 typename ComputePtrOffset,
68 typename AGridDesc_AK0_M_K1,
69 typename BGridDesc_BK0_N_K1,
70 typename DsGridDesc_M_N,
71 typename EGridDesc_M_N,
72 bool HasMainKBlockLoop,
77#if CK_USE_LAUNCH_BOUNDS
80 kernel_grouped_conv_fwd_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
81 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
82 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
83 const DsGridDesc_M_N ds_grid_desc_m_n,
84 const EGridDesc_M_N c_grid_desc_m_n,
85 const ComputePtrOffset compute_ptr_offset_of_groups,
86 const ComputePtrOffset compute_ptr_offset_of_n)
88#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
// blockIdx.y supplies the index that is handed to compute_ptr_offset_of_groups.
// readfirstlane takes the value from the first active lane, so it is uniform across the wavefront.
92 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
// blockIdx.z supplies the index that is handed to compute_ptr_offset_of_n.
93 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Per-D-tensor pointer offsets for g_idx; indexed per D tensor when building p_ds_grid_grp.
95 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
// Per-D-tensor pointer offsets for n_idx; added together with the group offset per D tensor.
96 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
98 static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
99 using DsGridPointer =
typename GridwiseGemm::DsGridPointer;
100 DsGridPointer p_ds_grid_grp{};
102 static_for<0, NumDTensor, 1>{}([&](
auto i) {
103 p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
118 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
120 using Block2CTileMap =
typename GridwiseGemm::Block2CTileMapDefault;
121 const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
123 if constexpr(GridwiseGemm::DirectLoadEnabled)
125#if defined(__gfx950__)
126 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
127 karg.p_a_grid + a_group_offset + a_n_offset,
128 karg.p_b_grid + b_group_offset,
130 karg.p_c_grid + e_group_offset + e_n_offset,
137 GridwiseGemm::template TransformGrid<
decltype(a_grid_desc_ak0_m_ak1),
138 GridwiseGemm::AK0Number,
139 GridwiseGemm::AK1Number>(
140 a_grid_desc_ak0_m_ak1),
141 GridwiseGemm::template TransformGrid<
decltype(b_grid_desc_bk0_n_bk1),
142 GridwiseGemm::BK0Number,
143 GridwiseGemm::BK1Number>(
144 b_grid_desc_bk0_n_bk1),
151 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
152 karg.p_a_grid + a_group_offset + a_n_offset,
153 karg.p_b_grid + b_group_offset,
155 karg.p_c_grid + e_group_offset + e_n_offset,
162 GridwiseGemm::template TransformGrid<
decltype(a_grid_desc_ak0_m_ak1),
163 GridwiseGemm::AK0Number,
164 GridwiseGemm::AK1Number>(
165 a_grid_desc_ak0_m_ak1),
166 GridwiseGemm::template TransformGrid<
decltype(b_grid_desc_bk0_n_bk1),
167 GridwiseGemm::BK0Number,
168 GridwiseGemm::BK1Number>(
169 b_grid_desc_bk0_n_bk1),
176 ignore = a_grid_desc_ak0_m_ak1;
177 ignore = b_grid_desc_bk0_n_bk1;
178 ignore = ds_grid_desc_m_n;
180 ignore = compute_ptr_offset_of_groups;
181 ignore = compute_ptr_offset_of_n;
185template <
typename GridwiseGemm,
186 typename ComputePtrOffset,
187 typename AGridDesc_AK0_M_K1,
188 typename BGridDesc_BK0_N_K1,
189 typename DsGridDesc_M_N,
190 typename EGridDesc_M_N,
191 bool HasMainKBlockLoop,
196#if CK_USE_LAUNCH_BOUNDS
199 kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
200 typename GridwiseGemm::Argument karg,
201 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
202 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
203 const DsGridDesc_M_N ds_grid_desc_m_n,
204 const EGridDesc_M_N c_grid_desc_m_n,
205 const ComputePtrOffset compute_ptr_offset_of_groups,
206 const ComputePtrOffset compute_ptr_offset_of_n)
208#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
209 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
// blockIdx.y supplies the index that is handed to compute_ptr_offset_of_groups.
// readfirstlane takes the value from the first active lane, so it is uniform across the wavefront.
212 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
// blockIdx.z supplies the index that is handed to compute_ptr_offset_of_n.
213 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Per-D-tensor pointer offsets for g_idx; indexed per D tensor when building p_ds_grid_grp.
215 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
// Per-D-tensor pointer offsets for n_idx; added together with the group offset per D tensor.
216 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
218 static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
219 using DsGridPointer =
typename GridwiseGemm::DsGridPointer;
220 DsGridPointer p_ds_grid_grp{};
222 static_for<0, NumDTensor, 1>{}([&](
auto i) {
223 p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
240 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
241 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
243 using Block2CTileMap =
typename GridwiseGemm::Block2CTileMapDefault;
244 const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
246 if constexpr(GridwiseGemm::DirectLoadEnabled)
248#if defined(__gfx950__)
249 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
250 karg.p_a_grid + a_group_offset + a_n_offset,
251 karg.p_b_grid + b_group_offset,
253 karg.p_c_grid + e_group_offset + e_n_offset,
261 GridwiseGemm::template TransformGrid<
decltype(a_grid_desc_ak0_m_ak1),
262 GridwiseGemm::AK0Number,
263 GridwiseGemm::AK1Number>(
264 a_grid_desc_ak0_m_ak1),
265 GridwiseGemm::template TransformGrid<
decltype(b_grid_desc_bk0_n_bk1),
266 GridwiseGemm::BK0Number,
267 GridwiseGemm::BK1Number>(
268 b_grid_desc_bk0_n_bk1),
275 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
276 karg.p_a_grid + a_group_offset + a_n_offset,
277 karg.p_b_grid + b_group_offset,
279 karg.p_c_grid + e_group_offset + e_n_offset,
287 GridwiseGemm::template TransformGrid<
decltype(a_grid_desc_ak0_m_ak1),
288 GridwiseGemm::AK0Number,
289 GridwiseGemm::AK1Number>(
290 a_grid_desc_ak0_m_ak1),
291 GridwiseGemm::template TransformGrid<
decltype(b_grid_desc_bk0_n_bk1),
292 GridwiseGemm::BK0Number,
293 GridwiseGemm::BK1Number>(
294 b_grid_desc_bk0_n_bk1),
301 ignore = a_grid_desc_ak0_m_ak1;
302 ignore = b_grid_desc_bk0_n_bk1;
303 ignore = ds_grid_desc_m_n;
305 ignore = compute_ptr_offset_of_groups;
306 ignore = compute_ptr_offset_of_n;
313using is_tuple =
decltype(std::declval<T&>().IsTuple());
338 typename AccDataType,
339 typename CShuffleDataType,
342 typename AElementwiseOperation,
343 typename BElementwiseOperation,
344 typename CDEElementwiseOperation,
357 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
358 typename ABlockTransferThreadClusterArrangeOrder,
359 typename ABlockTransferSrcAccessOrder,
360 index_t ABlockTransferSrcVectorDim,
361 index_t ABlockTransferSrcScalarPerVector,
362 index_t ABlockTransferDstScalarPerVector_AK1,
364 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
365 typename BBlockTransferThreadClusterArrangeOrder,
366 typename BBlockTransferSrcAccessOrder,
367 index_t BBlockTransferSrcVectorDim,
368 index_t BBlockTransferSrcScalarPerVector,
369 index_t BBlockTransferDstScalarPerVector_BK1,
371 index_t CShuffleMXdlPerWavePerShuffle,
372 index_t CShuffleNXdlPerWavePerShuffle,
373 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
374 index_t CDEBlockTransferScalarPerVector_NPerBlock,
377 typename AComputeDataType =
378 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
383 typename BComputeDataType = AComputeDataType,
384 bool DirectLoad =
false>
395 AElementwiseOperation,
396 BElementwiseOperation,
397 CDEElementwiseOperation,
408 static constexpr bool isMultiD = DsDataType::Size() > 0;
429 CDEBlockTransferScalarPerVector_NPerBlock>::type;
432 ConvForwardSpecialization,
443 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
453 template <
typename ALay>
459 using Layout = std::conditional_t<
462 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
466 const auto in_gemmmraw_gemmkraw_desc =
467 conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
469 const auto in_gemmm_gemmk_desc =
472 const auto M = in_gemmm_gemmk_desc.GetLength(
I0);
473 const auto K = in_gemmm_gemmk_desc.GetLength(
I1);
475 const auto AK0 = K / AK1;
484 template <
typename BLay>
489 using Layout = std::conditional_t<
492 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
496 const auto wei_gemmnraw_gemmkraw_desc =
497 conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
499 const auto wei_gemmn_gemmk_desc =
500 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
502 const auto N = wei_gemmn_gemmk_desc.GetLength(
I0);
503 const auto K = wei_gemmn_gemmk_desc.GetLength(
I1);
505 const auto BK0 = K / BK1;
514 template <
typename ELay>
519 using Layout = std::conditional_t<
522 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
526 const auto out_gemmmraw_gemmnraw_desc =
527 conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
529 const auto out_gemmm_gemmn_desc =
530 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
532 return out_gemmm_gemmn_desc;
556 ABlockTransferSrcScalarPerVector *
sizeof(ADataType) == 8
557 ? 4 /
sizeof(ADataType)
558 : ABlockTransferSrcScalarPerVector;
560 BBlockTransferSrcScalarPerVector *
sizeof(BDataType) == 8
561 ? 4 /
sizeof(BDataType)
562 : BBlockTransferSrcScalarPerVector;
565 template <index_t NXdlPerWave_>
577 AElementwiseOperation,
578 BElementwiseOperation,
579 CDEElementwiseOperation,
591 ABlockTransferThreadClusterLengths_AK0_M_AK1,
592 ABlockTransferThreadClusterArrangeOrder,
593 ABlockTransferSrcAccessOrder,
594 ABlockTransferSrcVectorDim,
596 ABlockTransferDstScalarPerVector_AK1,
599 BBlockTransferThreadClusterLengths_BK0_N_BK1,
600 BBlockTransferThreadClusterArrangeOrder,
601 BBlockTransferSrcAccessOrder,
602 BBlockTransferSrcVectorDim,
604 BBlockTransferDstScalarPerVector_BK1,
607 CShuffleMXdlPerWavePerShuffle,
608 CShuffleNXdlPerWavePerShuffle,
609 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
628 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
631 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
635 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
638 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
707 const std::array<const void*, NumDTensor>& p_ds,
709 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
710 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
711 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
712 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
713 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
714 ds_g_n_k_wos_lengths,
715 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
716 ds_g_n_k_wos_strides,
717 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
718 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
719 const std::array<index_t, NDimSpatial>& conv_filter_strides,
720 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
721 const std::array<index_t, NDimSpatial>& input_left_pads,
722 const std::array<index_t, NDimSpatial>& input_right_pads,
723 const AElementwiseOperation& a_element_op,
724 const BElementwiseOperation& b_element_op,
725 const CDEElementwiseOperation& cde_element_op)
732 a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
735 b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
740 e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
776 p_a_grid_ =
static_cast<const ADataType*
>(p_as);
777 p_b_grid_ =
static_cast<const BDataType*
>(p_bs);
812 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
815 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
819 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
822 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
826 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
829 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
879 return sizeof(EDataType) * e_accum;
898 [&](
auto i) { std::cout <<
"Ds[M, N]: " <<
ds_grid_desc_m_n_[i] << std::endl; });
960 template <
typename Gr
idwiseGemm>
963 if(stream_config.log_level_ > 0)
970 constexpr index_t minimum_occupancy =
978 const index_t num_workgroups_per_Conv_N =
983 std::tie(gdx, gdy, gdz) =
984 GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 );
987 gdz = num_workgroups_per_Conv_N;
// Round GemmK up to the next multiple of KPerBlock (integer ceil-divide, then multiply back).
989 index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
// Ask the gridwise GEMM whether the rounded-up K extent needs a main K-block loop;
// the kernel-selection code further down branches on this flag.
990 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
992 const ADataType* p_a_grid = arg.
p_a_grid_;
993 const BDataType* p_b_grid = arg.
p_b_grid_;
1008 typename GridwiseGemm::Argument gemm_arg{
1026 const auto Run = [&](
const auto& kernel) {
1027 if(stream_config.flush_cache)
1029 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
1032 stream_config.rotating_count,
1033 gemm_arg_.M * gemm_arg_.K *
sizeof(ADataType),
1034 gemm_arg_.K * gemm_arg_.N *
sizeof(BDataType));
1035 rotating_mem.Print();
1037 auto run_flush_cache = [&]() {
1041 rotating_mem.Next();
1048 dim3(gdx, gdy, gdz),
1063 dim3(gdx, gdy, gdz),
1076 if(has_main_k_block_loop)
1079 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
1080 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
1083 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1085 DeviceOp::AGridDesc_AK0_M_AK1,
1086 DeviceOp::BGridDesc_BK0_N_BK1,
1087 DeviceOp::DsGridDesc_M_N,
1088 DeviceOp::EGridDesc_M_N,
1090 InMemoryDataOperationEnum::Set,
1095 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
1097 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
1100 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1102 DeviceOp::AGridDesc_AK0_M_AK1,
1103 DeviceOp::BGridDesc_BK0_N_BK1,
1104 DeviceOp::DsGridDesc_M_N,
1105 DeviceOp::EGridDesc_M_N,
1107 InMemoryDataOperationEnum::Set,
1112 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
1115 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1117 DeviceOp::AGridDesc_AK0_M_AK1,
1118 DeviceOp::BGridDesc_BK0_N_BK1,
1119 DeviceOp::DsGridDesc_M_N,
1120 DeviceOp::EGridDesc_M_N,
1122 InMemoryDataOperationEnum::Set,
1128 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
1130 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
1132 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1135 DeviceOp::AGridDesc_AK0_M_AK1,
1136 DeviceOp::BGridDesc_BK0_N_BK1,
1137 DeviceOp::DsGridDesc_M_N,
1138 DeviceOp::EGridDesc_M_N,
1140 InMemoryDataOperationEnum::Set,
1147 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
1149 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
1151 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1154 DeviceOp::AGridDesc_AK0_M_AK1,
1155 DeviceOp::BGridDesc_BK0_N_BK1,
1156 DeviceOp::DsGridDesc_M_N,
1157 DeviceOp::EGridDesc_M_N,
1159 InMemoryDataOperationEnum::Set,
1166 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
1168 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
1170 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1173 DeviceOp::AGridDesc_AK0_M_AK1,
1174 DeviceOp::BGridDesc_BK0_N_BK1,
1175 DeviceOp::DsGridDesc_M_N,
1176 DeviceOp::EGridDesc_M_N,
1178 InMemoryDataOperationEnum::Set,
1185 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
1187 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
1189 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1192 DeviceOp::AGridDesc_AK0_M_AK1,
1193 DeviceOp::BGridDesc_BK0_N_BK1,
1194 DeviceOp::DsGridDesc_M_N,
1195 DeviceOp::EGridDesc_M_N,
1197 InMemoryDataOperationEnum::Set,
1204 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
1206 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
1208 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1211 DeviceOp::AGridDesc_AK0_M_AK1,
1212 DeviceOp::BGridDesc_BK0_N_BK1,
1213 DeviceOp::DsGridDesc_M_N,
1214 DeviceOp::EGridDesc_M_N,
1216 InMemoryDataOperationEnum::Set,
1223 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
1225 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
1227 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3<
1230 DeviceOp::AGridDesc_AK0_M_AK1,
1231 DeviceOp::BGridDesc_BK0_N_BK1,
1232 DeviceOp::DsGridDesc_M_N,
1233 DeviceOp::EGridDesc_M_N,
1235 InMemoryDataOperationEnum::Set,
1243 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
1245 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1247 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds<
1250 DeviceOp::AGridDesc_AK0_M_AK1,
1251 DeviceOp::BGridDesc_BK0_N_BK1,
1252 DeviceOp::DsGridDesc_M_N,
1253 DeviceOp::EGridDesc_M_N,
1255 InMemoryDataOperationEnum::Set,
1262 const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds<
1265 DeviceOp::AGridDesc_AK0_M_AK1,
1266 DeviceOp::BGridDesc_BK0_N_BK1,
1267 DeviceOp::DsGridDesc_M_N,
1268 DeviceOp::EGridDesc_M_N,
1270 InMemoryDataOperationEnum::Set,
1278 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1281 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1283 DeviceOp::AGridDesc_AK0_M_AK1,
1284 DeviceOp::BGridDesc_BK0_N_BK1,
1285 DeviceOp::DsGridDesc_M_N,
1286 DeviceOp::EGridDesc_M_N,
1288 InMemoryDataOperationEnum::Set,
1296 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1298 DeviceOp::AGridDesc_AK0_M_AK1,
1299 DeviceOp::BGridDesc_BK0_N_BK1,
1300 DeviceOp::DsGridDesc_M_N,
1301 DeviceOp::EGridDesc_M_N,
1303 InMemoryDataOperationEnum::Set,
1314 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
1317 kernel_grouped_conv_fwd_xdl_cshuffle_v3<GridwiseGemm,
1319 DeviceOp::AGridDesc_AK0_M_AK1,
1320 DeviceOp::BGridDesc_BK0_N_BK1,
1321 DeviceOp::DsGridDesc_M_N,
1322 DeviceOp::EGridDesc_M_N,
1324 InMemoryDataOperationEnum::Set,
1333 template <
typename Gr
idwiseGemm>
1336 float avg_time = 0.f;
1351 BDataType* p_b_out_grid =
1355 auto kernel_transpose =
1373 dim3(a_grid_size + b_grid_size),
1390 avg_time += RunGemm<GridwiseGemm>(arg, stream_config);
1400 const EDataType* p_e_in_grid =
1405 EDataType* p_e_out_grid = arg.
p_e_grid_;
1412 Block2TileMapElementwise,
1419 dim3(ElementwiseBlocksize),
1437 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1455 std::cout <<
"The MultiABD is not supported!" <<
" In " << __FILE__ <<
":"
1456 << __LINE__ <<
", in function: " << __func__ << std::endl;
1461 if constexpr(DirectLoad)
1477 <<
"On gfx908 the accumulation data type must be one of fp32 or int32!"
1478 <<
" In " << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1489 std::cout <<
"Current device does not support xdl instructions!" <<
" In "
1490 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1507 std::cout <<
"ComputeDataType for A and B should be same while using TF32"
1515 if constexpr(ConvForwardSpecialization ==
1519 for(
index_t i = 0; i < NDimSpatial; ++i)
1526 if(!(SpatialDim == 1 && ConvStride == 1 &&
LeftPad == 0 &&
RightPad == 0))
1530 std::cout <<
"The input paramters do not align with specialization "
1531 "Filter1x1Stride1Pad0!"
1532 <<
" In " << __FILE__ <<
":" << __LINE__
1533 <<
", in function: " << __func__ << std::endl;
1539 else if constexpr(ConvForwardSpecialization ==
1543 for(
index_t i = 0; i < NDimSpatial; ++i)
1554 <<
"The input paramters do not align with specialization Filter1x1Pad0!"
1555 <<
" In " << __FILE__ <<
":" << __LINE__
1556 <<
", in function: " << __func__ << std::endl;
1572 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
1576 std::cout <<
"[A Layout] The number of input channels is not a multiple of "
1577 "ABlockTransferSrcScalarPerVector!"
1578 <<
" In " << __FILE__ <<
":" << __LINE__
1579 <<
", in function: " << __func__ << std::endl;
1588 std::cout <<
"Unsupported A Layout!" <<
" In " << __FILE__ <<
":" << __LINE__
1589 <<
", in function: " << __func__ << std::endl;
1604 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
1608 std::cout <<
"[B Layout] The number of input channels is not a multiple of "
1609 "BBlockTransferSrcScalarPerVector!"
1610 <<
" In " << __FILE__ <<
":" << __LINE__
1611 <<
", in function: " << __func__ << std::endl;
// Diagnostic that follows the B-layout vector-access check (BBlockTransferSrcScalarPerVector).
// It previously said "A Layout" (copied from the A-layout section); it now names the B layout
// so the log identifies the tensor whose layout was rejected.
1620 std::cout <<
"Unsupported B Layout!" <<
" In " << __FILE__ <<
":" << __LINE__
1621 <<
", in function: " << __func__ << std::endl;
1629 if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1633 std::cout <<
"[NGCHW Layout] The G * C is not a multiple of "
1634 "CDEBlockTransferScalarPerVector_NPerBlock"
1635 <<
" In " << __FILE__ <<
":" << __LINE__
1636 <<
", in function: " << __func__ << std::endl;
1641 if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1645 std::cout <<
"[NGCHW Layout] The G * K is not a multiple of "
1646 "CDEBlockTransferScalarPerVector_NPerBlock"
1647 <<
" In " << __FILE__ <<
":" << __LINE__
1648 <<
", in function: " << __func__ << std::endl;
1658 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1662 std::cout <<
"[NGCHW Layout] The input_spatial_acum is not a multiple of "
1663 "CDEBlockTransferScalarPerVector_NPerBlock"
1664 <<
" In " << __FILE__ <<
":" << __LINE__
1665 <<
", in function: " << __func__ << std::endl;
1670 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1674 std::cout <<
"[NGCHW Layout] The output_spatial_acum is not a multiple of "
1675 "CDEBlockTransferScalarPerVector_NPerBlock"
1676 <<
" In " << __FILE__ <<
":" << __LINE__
1677 <<
", in function: " << __func__ << std::endl;
1686 std::cout <<
"Warning: Workspace for "
1687 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not "
1688 "allocated, use SetWorkSpacePointer."
1700 std::cout <<
"[NGCHW Layout] One of the transposed vectors is exceeding 2GB "
1702 <<
" In " << __FILE__ <<
":" << __LINE__
1703 <<
", in function: " << __func__ << std::endl;
1717 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1721 std::cout <<
"[E Layout] The K is not a multiple of "
1722 "CDEBlockTransferScalarPerVector_NPerBlock"
1723 <<
" In " << __FILE__ <<
":" << __LINE__
1724 <<
", in function: " << __func__ << std::endl;
1733 std::cout <<
"Unsupported E Layout!" <<
" In " << __FILE__ <<
":" << __LINE__
1734 <<
", in function: " << __func__ << std::endl;
1745 <<
"[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!"
1746 <<
" In " << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1762 typename GridwiseGemm64::Argument gemm_arg{
nullptr,
1784 typename GridwiseGemm32::Argument gemm_arg{
nullptr,
1814 const std::array<const void*, NumDTensor>& p_ds,
1816 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1817 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1818 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1819 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1820 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1821 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1822 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1823 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1824 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1825 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1826 const std::array<index_t, NDimSpatial>& input_left_pads,
1827 const std::array<index_t, NDimSpatial>& input_right_pads,
1828 const AElementwiseOperation& a_element_op,
1829 const BElementwiseOperation& b_element_op,
1830 const CDEElementwiseOperation& cde_element_op)
1836 a_g_n_c_wis_lengths,
1837 a_g_n_c_wis_strides,
1840 ds_g_n_k_wos_lengths,
1841 ds_g_n_k_wos_strides,
1842 e_g_n_k_wos_lengths,
1843 e_g_n_k_wos_strides,
1844 conv_filter_strides,
1845 conv_filter_dilations,
1856 const std::array<const void*, NumDTensor>& p_ds,
1858 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1859 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1860 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1861 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1862 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1863 ds_g_n_k_wos_lengths,
1864 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1865 ds_g_n_k_wos_strides,
1866 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1867 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1868 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1869 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1870 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1871 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1872 const AElementwiseOperation& a_element_op,
1873 const BElementwiseOperation& b_element_op,
1874 const CDEElementwiseOperation& cde_element_op)
1876 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1877 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1878 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1879 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1880 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
1881 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
1882 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1883 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1884 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1885 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1886 std::array<index_t, NDimSpatial> input_left_pads_i32;
1887 std::array<index_t, NDimSpatial> input_right_pads_i32;
1889 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
1890 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
1895 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1896 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1898 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1899 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1900 array_convert(conv_filter_strides_i32, conv_filter_strides);
1901 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1909 a_g_n_c_wis_lengths_i32,
1910 a_g_n_c_wis_strides_i32,
1911 b_g_k_c_xs_lengths_i32,
1912 b_g_k_c_xs_strides_i32,
1913 ds_g_n_k_wos_lengths_i32,
1914 ds_g_n_k_wos_strides_i32,
1915 e_g_n_k_wos_lengths_i32,
1916 e_g_n_k_wos_strides_i32,
1917 conv_filter_strides_i32,
1918 conv_filter_dilations_i32,
1919 input_left_pads_i32,
1920 input_right_pads_i32,
1931 const std::array<const void*, NumDTensor>& p_ds,
1933 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1934 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1935 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1936 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1937 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1938 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1939 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1940 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1941 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1942 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1943 const std::array<index_t, NDimSpatial>& input_left_pads,
1944 const std::array<index_t, NDimSpatial>& input_right_pads,
1945 const AElementwiseOperation& a_element_op,
1946 const BElementwiseOperation& b_element_op,
1947 const CDEElementwiseOperation& cde_element_op)
override
1949 return std::make_unique<Argument>(p_a,
1953 a_g_n_c_wis_lengths,
1954 a_g_n_c_wis_strides,
1957 ds_g_n_k_wos_lengths,
1958 ds_g_n_k_wos_strides,
1959 e_g_n_k_wos_lengths,
1960 e_g_n_k_wos_strides,
1961 conv_filter_strides,
1962 conv_filter_dilations,
1970 std::unique_ptr<BaseArgument>
1973 const std::array<const void*, NumDTensor>& p_ds,
1975 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1976 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1977 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1978 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1979 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1980 ds_g_n_k_wos_lengths,
1981 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1982 ds_g_n_k_wos_strides,
1983 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1984 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1985 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1986 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1987 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1988 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1989 const AElementwiseOperation& a_element_op,
1990 const BElementwiseOperation& b_element_op,
1991 const CDEElementwiseOperation& cde_element_op)
override
1993 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1994 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1995 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1996 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1997 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
1998 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
1999 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
2000 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
2001 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
2002 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
2003 std::array<index_t, NDimSpatial> input_left_pads_i32;
2004 std::array<index_t, NDimSpatial> input_right_pads_i32;
2006 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
2007 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
2012 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
2013 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
2015 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
2016 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
2017 array_convert(conv_filter_strides_i32, conv_filter_strides);
2018 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
2022 return std::make_unique<Argument>(p_a,
2026 a_g_n_c_wis_lengths_i32,
2027 a_g_n_c_wis_strides_i32,
2028 b_g_k_c_xs_lengths_i32,
2029 b_g_k_c_xs_strides_i32,
2030 ds_g_n_k_wos_lengths_i32,
2031 ds_g_n_k_wos_strides_i32,
2032 e_g_n_k_wos_lengths_i32,
2033 e_g_n_k_wos_strides_i32,
2034 conv_filter_strides_i32,
2035 conv_filter_dilations_i32,
2036 input_left_pads_i32,
2037 input_right_pads_i32,
2045 return std::make_unique<Invoker>(
Invoker{});
2050 auto str = std::stringstream();
2052 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
2056 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
2064 str <<
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
2066 if constexpr(DirectLoad) {
2067 str <<
"_DirectLoad";
2071 << BlockSize <<
", "
2072 << MPerBlock <<
", "
2073 << NPerBlock <<
", "
2074 << KPerBlock <<
", "
2078 << MXdlPerWave <<
", "
2079 << NXdlPerWave <<
", "
2080 << ABlockTransferSrcScalarPerVector <<
", "
2081 << BBlockTransferSrcScalarPerVector <<
", "
2082 << CDEBlockTransferScalarPerVector_NPerBlock <<
", "
2083 << CShuffleMXdlPerWavePerShuffle <<
", "
2084 << CShuffleNXdlPerWavePerShuffle <<
", "
2085 <<
"BlkGemmPipelineScheduler: "
2086 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
2087 <<
"BlkGemmPipelineVersion: "
2088 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
2095#ifdef CK_EXPERIMENTAL_BUILDER
2096 std::string GetInstanceString()
const override
2098 static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
2099 "Specialization of instance_traits not found. Please check that a "
2100 "specialization exists in file "
2101 "ck_tile/builder/reflect/"
2102 "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp "
2103 "for the given template parameters.");
2104 return ck_tile::reflect::instance_string<DeviceOp>();
2110 auto arg =
dynamic_cast<const Argument*
>(p_arg);
2116 throw std::runtime_error(
2117 "The argument pointer is not an object of "
2118 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2125 auto p_arg_ =
dynamic_cast<Argument*
>(p_arg);
2128 p_arg_->p_workspace_ = p_workspace;
2131 throw std::runtime_error(
2132 "The argument pointer is not an object of "
2133 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, DsLayout, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, DirectLoad ? ABlockTransferSrcScalarPerVectorAligned :ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, DirectLoad ? BBlockTransferSrcScalarPerVectorAligned :BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, AComputeDataType, BComputeDataType, ADataType, BDataType, DoElementwiseBeforeCShuffle, DirectLoad >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1186
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:704
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:941
Argument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:705
ComputePtrOffset compute_ptr_offset_of_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:938
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:887
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:930
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:947
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:933
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:949
const std::array< const void *, NumDTensor > p_ds_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:906
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:914
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:920
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:942
ComputePtrOffset compute_ptr_offset_of_groups_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:937
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:905
NHWGCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:950
GKCYXTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:951
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:904
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:916
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:918
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:947
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:943
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:913
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:921
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:856
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:912
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:915
NGCHWTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:949
GKYXCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:952
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:910
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:907
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:950
index_t num_group_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:924
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:919
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:926
void Print() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:893
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:917
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:840
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:931
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:911
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:934
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:872
index_t conv_N_per_block_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:927
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:946
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:957
float RunGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:961
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:958
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1334
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1434
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:400
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:642
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:636
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:422
static auto MakeArgument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1811
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1441
static constexpr bool isMultiABD
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:409
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2108
static constexpr bool isMultiB
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:407
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:697
static constexpr index_t NumBTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:416
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:445
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:559
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:486
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:550
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:555
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:678
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:619
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:439
static auto MakeArgument(const void *p_as, const void *p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1854
static constexpr index_t ElementwiseBlocksize
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:640
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:624
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:699
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:549
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:660
static constexpr auto I5
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:424
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2043
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:552
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:620
static constexpr index_t NumATensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:415
GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, DsLayout, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, DirectLoad ? ABlockTransferSrcScalarPerVectorAligned :ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, DirectLoad ? BBlockTransferSrcScalarPerVectorAligned :BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, AComputeDataType, BComputeDataType, ADataType, BDataType, DoElementwiseBeforeCShuffle, DirectLoad > GridwiseGemmBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:566
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:420
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:629
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:455
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1928
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:431
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2048
static constexpr bool DoElementwiseBeforeCShuffle
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:411
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 DeviceOp
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:401
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:537
typename uniform_sequence_gen< NumDTensor+1, CDEBlockTransferScalarPerVector_NPerBlock >::type CDEBlockTransferScalarPerVectors
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:427
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > ComputePtrOffset
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:437
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1971
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:626
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1926
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:442
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:515
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:2121
static constexpr bool isMultiD
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:408
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:403
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:417
static constexpr bool isMultiA
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:406
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:1806
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:633
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:421
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:419
static constexpr auto I4
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:423
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp:404
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition flush_cache.hpp:299
#define CK_ENV(name)
Definition utility/env.hpp:129