device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File

Composable Kernel: device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
25template <
26 typename InDataType,
27 typename WeiDataType,
28 typename OutDataType,
29 typename AccDataType,
30 typename InElementwiseOperation,
31 typename WeiElementwiseOperation,
32 typename OutElementwiseOperation,
33 ConvolutionForwardSpecialization ConvForwardSpecialization,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 ck::index_t MPerXDL,
40 ck::index_t NPerXDL,
41 ck::index_t MXdlPerWave,
42 ck::index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 bool BBlockLdsAddExtraN,
57 index_t CShuffleMXdlPerWavePerShuffle,
58 index_t CShuffleNXdlPerWavePerShuffle,
59 typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
60 index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
62 : public DeviceConvFwd<2,
63 ck::tensor_layout::convolution::NHWC,
64 ck::tensor_layout::convolution::KYXC,
65 ck::tensor_layout::convolution::NHWK,
66 InDataType,
67 WeiDataType,
68 OutDataType,
69 InElementwiseOperation,
70 WeiElementwiseOperation,
71 OutElementwiseOperation>
72{
74
76 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
77 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
78
79 using ADataType = InDataType;
80 using BDataType = WeiDataType;
81 using CDataType = OutDataType;
82
83 // TODO make A/B datatype different
84 using ABDataType = InDataType;
85
86 static constexpr index_t NDimSpatial = 2;
87
88 static constexpr auto I0 = Number<0>{};
89 static constexpr auto I1 = Number<1>{};
90 static constexpr auto I2 = Number<2>{};
91 static constexpr auto I3 = Number<3>{};
92 static constexpr auto I4 = Number<4>{};
93 static constexpr auto I5 = Number<5>{};
94
95 static constexpr auto K1Number = Number<K1>{};
96 static constexpr auto GemmK1Number = K1Number;
97
98 static auto
100 ck::index_t K,
101 ck::index_t C,
102 std::vector<ck::index_t> input_spatial_lengths,
103 std::vector<ck::index_t> filter_spatial_lengths,
104 std::vector<ck::index_t> output_spatial_lengths,
105 std::vector<ck::index_t> conv_filter_strides,
106 std::vector<ck::index_t> conv_filter_dilations,
107 std::vector<ck::index_t> input_left_pads,
108 std::vector<ck::index_t> input_right_pads)
109 {
110 using namespace ck;
111
112 const index_t Hi = input_spatial_lengths[0];
113 const index_t Wi = input_spatial_lengths[1];
114
115 const index_t Ho = output_spatial_lengths[0];
116 const index_t Wo = output_spatial_lengths[1];
117
118 const index_t Y = filter_spatial_lengths[0];
119 const index_t X = filter_spatial_lengths[1];
120
121 const index_t ConvStrideH = conv_filter_strides[0];
122 const index_t ConvStrideW = conv_filter_strides[1];
123
124 const index_t ConvDilationH = conv_filter_dilations[0];
125 const index_t ConvDilationW = conv_filter_dilations[1];
126
127 const index_t InLeftPadH = input_left_pads[0];
128 const index_t InLeftPadW = input_left_pads[1];
129
130 const index_t InRightPadH = input_right_pads[0];
131 const index_t InRightPadW = input_right_pads[1];
132
133 const index_t GemmMRaw = N * Ho * Wo;
134 const index_t GemmN = K;
135
136 const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
137 const auto GemmMPad = GemmM - GemmMRaw;
138
139 if constexpr(ConvForwardSpecialization ==
141 { // 1x1, stride=1, pad=0
142 const index_t GemmK = Y * X * C;
143 assert(GemmK % GemmK1Number == 0);
144
145 const index_t GemmK0 = GemmK / GemmK1Number;
146
147 // A: input tensor
148 const auto in_gemmmraw_gemmk_grid_desc =
150
151 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
152 in_gemmmraw_gemmk_grid_desc,
154 make_right_pad_transform(GemmMRaw, GemmMPad)),
157
158 // B: weight tensor
159 const auto wei_gemmn_gemmk_grid_desc =
161
162 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
163 wei_gemmn_gemmk_grid_desc,
168
169 // C: output tensor
170 const auto out_gemmmraw_gemmn_grid_desc =
172
173 const auto out_gemmm_gemmn_grid_desc =
174 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
175 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
179
180 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
181 wei_gemmk0_gemmn_gemmk1_grid_desc,
182 out_gemmm_gemmn_grid_desc);
183 }
184 else if constexpr(ConvForwardSpecialization ==
186 { // 1x1, pad=0
187 const index_t GemmK = Y * X * C;
188 assert(GemmK % GemmK1Number == 0);
189
190 const index_t GemmK0 = GemmK / GemmK1Number;
191
192 // A: input tensor
193 const auto in_n_hi_wi_c_grid_desc =
195
196 const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
197 in_n_hi_wi_c_grid_desc,
199 make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
200 make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
204
205 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
206 in_n_ho_wo_c_grid_desc,
208 make_merge_transform(make_tuple(N, Ho, Wo))),
211
212 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
213 in_gemmk0_gemmmraw_gemmk1_grid_desc,
215 make_right_pad_transform(GemmMRaw, GemmMPad),
219
220 // B: weight tensor
221 const auto wei_gemmn_gemmk_grid_desc =
223
224 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
225 wei_gemmn_gemmk_grid_desc,
230
231 // C: output tensor
232 const auto out_gemmmraw_gemmn_grid_desc =
234
235 const auto out_gemmm_gemmn_grid_desc =
236 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
237 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
241
242 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
243 wei_gemmk0_gemmn_gemmk1_grid_desc,
244 out_gemmm_gemmn_grid_desc);
245 }
246 else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
247 { // C = odd value
248 const index_t GemmKRaw = Y * X * C;
249 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
250 const index_t GemmKPad = GemmK - GemmKRaw;
251 const index_t GemmK0 = GemmK / GemmK1Number;
252
253 // A: input tensor
254 const auto in_n_hi_wi_c_grid_desc =
256
257 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
258 in_n_hi_wi_c_grid_desc,
260 make_pad_transform(Hi, InLeftPadH, InRightPadH),
261 make_pad_transform(Wi, InLeftPadW, InRightPadW),
265
266 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
267 in_n_hip_wip_c_grid_desc,
270 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
271 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
275
276 const auto in_gemmkraw_gemmmraw_grid_desc =
277 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
279 make_merge_transform(make_tuple(N, Ho, Wo))),
282
283 const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
284 in_gemmkraw_gemmmraw_grid_desc,
285 make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
286 make_right_pad_transform(GemmMRaw, GemmMPad)),
289
290 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
291 in_gemmk_gemmm_grid_desc,
296
297 // B: weight tensor
298 const auto wei_k_yxc_grid_desc =
300
301 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
302 wei_k_yxc_grid_desc,
304 make_right_pad_transform(GemmKRaw, GemmKPad)),
307
308 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
309 wei_gemmk_gemmn_grid_desc,
314
315 // C: output tensor
316 const auto out_nhowo_k_grid_desc =
318
319 const auto out_gemmmraw_gemmn_grid_desc =
320 transform_tensor_descriptor(out_nhowo_k_grid_desc,
325
326 const auto out_gemmm_gemmn_grid_desc =
327 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
328 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
332
333 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
334 wei_gemmk0_gemmn_gemmk1_grid_desc,
335 out_gemmm_gemmn_grid_desc);
336 }
337 else
338 {
339 const index_t GemmK = Y * X * C;
340 assert(GemmK % GemmK1Number == 0);
341
342 const index_t GemmK0 = GemmK / GemmK1Number;
343
344 // A: input tensor
345 const auto in_n_hi_wi_c_grid_desc =
347
348 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
349 in_n_hi_wi_c_grid_desc,
351 make_pad_transform(Hi, InLeftPadH, InRightPadH),
352 make_pad_transform(Wi, InLeftPadW, InRightPadW),
356
357 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
358 in_n_hip_wip_c_grid_desc,
361 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
362 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
366
367 const auto in_gemmk_gemmmraw_grid_desc =
368 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
370 make_merge_transform(make_tuple(N, Ho, Wo))),
373
374 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
375 in_gemmk_gemmmraw_grid_desc,
380
381 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
382 in_gemmk0_gemmmraw_gemmk1_grid_desc,
384 make_right_pad_transform(GemmMRaw, GemmMPad),
388
389 // B: weight tensor
390 const auto wei_k_yxc_grid_desc =
392
393 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
394 wei_k_yxc_grid_desc,
398
399 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
400 wei_gemmk_gemmn_grid_desc,
405
406 // C: output tensor
407 const auto out_nhowo_k_grid_desc =
409
410 const auto out_gemmmraw_gemmn_grid_desc =
411 transform_tensor_descriptor(out_nhowo_k_grid_desc,
416
417 const auto out_gemmm_gemmn_grid_desc =
418 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
419 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
423
424 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
425 wei_gemmk0_gemmn_gemmk1_grid_desc,
426 out_gemmm_gemmn_grid_desc);
427 }
428 }
429
431 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
432
436
438
439 // GridwiseGemm
440 template <index_t NXdlPerWave_>
442 BlockSize,
443 ABDataType, // TODO: distinguish A/B datatype
444 AccDataType,
445 CDataType, // TODO: Add ShuffleType for DeviceConv2d
446 CDataType,
451 InElementwiseOperation,
452 WeiElementwiseOperation,
453 OutElementwiseOperation,
454 MPerBlock,
455 NPerBlock,
456 K0PerBlock * K1,
457 K1, // AK1
458 K1, // BK1
459 MPerXDL,
460 NPerXDL,
461 MXdlPerWave,
462 NXdlPerWave_,
463 ABlockTransferThreadClusterLengths_K0_M_K1,
464 Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
465 Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
466 2, // ABlockTransferSrcVectorDim,
467 ABlockTransferSrcScalarPerVector,
468 ABlockTransferDstScalarPerVector_K1,
469 false, // AThreadTransferSrcResetCoordinateAfterRun,
470 ABlockLdsAddExtraM,
471 BBlockTransferThreadClusterLengths_K0_N_K1,
472 Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
473 Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
474 2, // BBlockTransferSrcVectorDim,
475 BBlockTransferSrcScalarPerVector,
476 BBlockTransferDstScalarPerVector_K1,
477 false, // BThreadTransferSrcResetCoordinateAfterRun,
478 BBlockLdsAddExtraN,
479 CShuffleMXdlPerWavePerShuffle,
480 CShuffleNXdlPerWavePerShuffle,
481 CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
482 CBlockTransferScalarPerVector_NWaveNPerXdl>;
485
486 // Argument
487 struct Argument : public BaseArgument
488 {
489 Argument(const InDataType* p_in_grid,
490 const WeiDataType* p_wei_grid,
491 OutDataType* p_out_grid,
492 ck::index_t N,
493 ck::index_t K,
494 ck::index_t C,
495 std::vector<ck::index_t> input_spatial_lengths,
496 std::vector<ck::index_t> filter_spatial_lengths,
497 std::vector<ck::index_t> output_spatial_lengths,
498 std::vector<ck::index_t> conv_filter_strides,
499 std::vector<ck::index_t> conv_filter_dilations,
500 std::vector<ck::index_t> input_left_pads,
501 std::vector<ck::index_t> input_right_pads,
502 InElementwiseOperation in_element_op,
503 WeiElementwiseOperation wei_element_op,
504 OutElementwiseOperation out_element_op)
505 : p_a_grid_{p_in_grid},
506 p_b_grid_{p_wei_grid},
507 p_c_grid_{p_out_grid},
512 in_element_op_{in_element_op},
513 wei_element_op_{wei_element_op},
514 out_element_op_{out_element_op},
515 Conv_N_{N},
516 Conv_K_{K},
517 Conv_C_{C},
518 input_spatial_lengths_{input_spatial_lengths},
519 filter_spatial_lengths_{filter_spatial_lengths},
520 output_spatial_lengths_{output_spatial_lengths},
521 conv_filter_strides_{conv_filter_strides},
522 conv_filter_dilations_{conv_filter_dilations},
523 input_left_pads_{input_left_pads},
524 input_right_pads_{input_right_pads}
525 {
526 const auto descs =
528 K,
529 C,
530 input_spatial_lengths,
531 filter_spatial_lengths,
532 output_spatial_lengths,
533 conv_filter_strides,
534 conv_filter_dilations,
535 input_left_pads,
536 input_right_pads);
537
538 a_grid_desc_k0_m_k1_ = descs[I0];
539 b_grid_desc_k0_n_k1_ = descs[I1];
540 c_grid_desc_m_n_ = descs[I2];
541
543 }
551 InElementwiseOperation in_element_op_;
552 WeiElementwiseOperation wei_element_op_;
553 OutElementwiseOperation out_element_op_;
554 // for checking IsSupportedArgument()
558 std::vector<index_t> input_spatial_lengths_;
559 std::vector<index_t> filter_spatial_lengths_;
560 std::vector<index_t> output_spatial_lengths_;
561 std::vector<index_t> conv_filter_strides_;
562 std::vector<index_t> conv_filter_dilations_;
563 std::vector<index_t> input_left_pads_;
564 std::vector<index_t> input_right_pads_;
565 };
566
567 // Invoker
568 struct Invoker : public BaseInvoker
569 {
571
572 template <typename GridwiseGemm>
573 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
574 {
575 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
579 {
580 throw std::runtime_error(
581 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
582 }
583 auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
584 GridwiseGemm::
585 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
586 arg.c_grid_desc_m_n_);
// Optional diagnostics: only when ck::EnvIsEnabled(CK_ENV(CK_LOGGING)) is true, print the
// operation's type string, the convolution problem description stored in `arg`, and the
// lengths of the A/B/C grid descriptors (including the 6-D C descriptor built just above).
587 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
588 {
589 std::cout << DeviceOp{}.GetTypeString() << std::endl;
590 std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
591 << arg.Conv_C_ << ", " << std::endl;
592 std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
593 << arg.filter_spatial_lengths_[1] << ", " << std::endl;
594 std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
595 << arg.input_spatial_lengths_[1] << ", " << std::endl;
596 std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
597 << arg.output_spatial_lengths_[1] << ", " << std::endl;
598 std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
599 << arg.conv_filter_strides_[1] << ", " << std::endl;
600 std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
601 << arg.conv_filter_dilations_[1] << ", " << std::endl;
602 std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
603 << arg.input_left_pads_[1] << ", " << std::endl;
// These are the right pads, so they get their own label; otherwise the log shows two
// "InLeftPads" lines that cannot be told apart.
604 std::cout << "InRightPads " << arg.input_right_pads_[0] << ", "
605 << arg.input_right_pads_[1] << ", " << std::endl;
606
607 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
608 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
609 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
610
611 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
612 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
613 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
614
615 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
616 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
617
618 std::cout
619 << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
620 "nwavenperxdl_{ "
621 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
622 .GetLength(I0)
623 << ", "
624 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
625 .GetLength(I1)
626 << ", "
627 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
628 .GetLength(I2)
629 << ", "
630 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
631 .GetLength(I3)
632 << ", "
633 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
634 .GetLength(I4)
635 << ", "
636 << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
637 .GetLength(I5)
638 << "}" << std::endl;
639 }
640 const index_t grid_size =
642
643 const auto K =
644 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
645
646 float ave_time = 0;
647
648 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
649 {
650 const auto kernel = kernel_gemm_xdlops_v3r1<
651 GridwiseGemm,
652 ADataType, // TODO: distinguish A/B datatype
653 CDataType,
657 typename GridwiseGemm::
658 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
659 InElementwiseOperation,
660 WeiElementwiseOperation,
661 OutElementwiseOperation,
663 true>;
664
665 ave_time = launch_and_time_kernel(
666 stream_config,
667 kernel,
668 dim3(grid_size),
669 dim3(BlockSize),
670 0,
671 arg.p_a_grid_,
672 arg.p_b_grid_,
673 arg.p_c_grid_,
676 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
677 arg.in_element_op_,
678 arg.wei_element_op_,
679 arg.out_element_op_,
681 }
682 else
683 {
684 const auto kernel = kernel_gemm_xdlops_v3r1<
685 GridwiseGemm,
686 ADataType, // TODO: distinguish A/B datatype
687 CDataType,
691 typename GridwiseGemm::
692 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
693 InElementwiseOperation,
694 WeiElementwiseOperation,
695 OutElementwiseOperation,
697 false>;
698
699 ave_time = launch_and_time_kernel(
700 stream_config,
701 kernel,
702 dim3(grid_size),
703 dim3(BlockSize),
704 0,
705 arg.p_a_grid_,
706 arg.p_b_grid_,
707 arg.p_c_grid_,
710 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
711 arg.in_element_op_,
712 arg.wei_element_op_,
713 arg.out_element_op_,
715 }
716
717 return ave_time;
718 }
719
721
// Type-erased entry point (overrides the base-class Run): downcasts the generic argument
// to this operation's Argument, forwards it together with stream_config to Run, and
// returns that call's float result.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg
// must really point to an Argument of this operation — confirm callers guarantee this.
722 float Run(const BaseArgument* p_arg,
723 const StreamConfig& stream_config = StreamConfig{}) override
724 {
725 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
726 }
727 };
728
// Compile-time hook reporting whether this template instantiation is valid.
// Currently a placeholder: it performs no checks and always returns true (see TODO).
729 static constexpr bool IsValidCompilationParameter()
730 {
731 // TODO: properly implement this check
732 return true;
733 }
734
735 static bool IsSupportedArgument(const Argument& arg)
736 {
738 {
739 return false;
740 }
741 if constexpr(ConvForwardSpecialization ==
743 {
744 // check if it's 1x1, stride=1 conv
745 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
746 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
747 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
748 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
749 {
750 return false;
751 }
752 }
753 else if constexpr(ConvForwardSpecialization ==
755 {
756 // check if it's 1x1 conv
757 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
758 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
759 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
760 {
761 return false;
762 }
763 }
764
765 // vector load A/B matrix from global memory
766 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
767 arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
768 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
769 {
770 return false;
771 }
772
773 // vector store C matrix into global memory
774 if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
775 {
776 return false;
777 }
778
779 // Gridwise GEMM size
780 if(get_warp_size() == 64)
781 {
782 if constexpr(NXdlPerWave64 > 0)
783 {
788 }
789 }
790 else
791 {
792 if constexpr(NXdlPerWave32 > 0)
793 {
798 }
799 }
800 return false;
801 }
802
// Type-erased overload (overrides the base-class method): downcasts p_arg to this
// operation's Argument and returns the result of the static IsSupportedArgument check.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so an
// argument of a different dynamic type is not reported as "unsupported" here — confirm
// callers only pass Arguments created by this operation.
803 bool IsSupportedArgument(const BaseArgument* p_arg) override
804 {
805 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
806 }
807
// Builds an Argument by value from typed tensor pointers (input, weight, output), the
// N/K/C sizes, the per-dimension convolution vectors (spatial lengths, strides, dilations,
// left/right pads) and the three element-wise operators. Every parameter is passed to the
// Argument constructor unchanged and in the same order; no validation happens here.
808 static auto MakeArgument(const InDataType* p_in_grid,
809 const WeiDataType* p_wei_grid,
810 OutDataType* p_out_grid,
811 ck::index_t N,
812 ck::index_t K,
813 ck::index_t C,
814 std::vector<ck::index_t> input_spatial_lengths,
815 std::vector<ck::index_t> filter_spatial_lengths,
816 std::vector<ck::index_t> output_spatial_lengths,
817 std::vector<ck::index_t> conv_filter_strides,
818 std::vector<ck::index_t> conv_filter_dilations,
819 std::vector<ck::index_t> input_left_pads,
820 std::vector<ck::index_t> input_right_pads,
821 InElementwiseOperation in_element_op,
822 WeiElementwiseOperation wei_element_op,
823 OutElementwiseOperation out_element_op)
824 {
825 return Argument{p_in_grid,
826 p_wei_grid,
827 p_out_grid,
828 N,
829 K,
830 C,
831 input_spatial_lengths,
832 filter_spatial_lengths,
833 output_spatial_lengths,
834 conv_filter_strides,
835 conv_filter_dilations,
836 input_left_pads,
837 input_right_pads,
838 in_element_op,
839 wei_element_op,
840 out_element_op};
841 }
842
// Returns a default-constructed Invoker by value.
843 static auto MakeInvoker() { return Invoker{}; }
844
// Type-erased counterpart of MakeArgument (overrides the base-class method).
// The untyped tensor pointers are static_cast to InDataType / WeiDataType / OutDataType,
// all remaining parameters are forwarded unchanged, and the Argument is allocated with
// std::make_unique and returned through a std::unique_ptr<BaseArgument>.
// The casts are unchecked: the caller is responsible for passing buffers of those types.
845 std::unique_ptr<BaseArgument>
846 MakeArgumentPointer(const void* p_in_grid,
847 const void* p_wei_grid,
848 void* p_out_grid,
849 ck::index_t N,
850 ck::index_t K,
851 ck::index_t C,
852 std::vector<ck::index_t> input_spatial_lengths,
853 std::vector<ck::index_t> filter_spatial_lengths,
854 std::vector<ck::index_t> output_spatial_lengths,
855 std::vector<ck::index_t> conv_filter_strides,
856 std::vector<ck::index_t> conv_filter_dilations,
857 std::vector<ck::index_t> input_left_pads,
858 std::vector<ck::index_t> input_right_pads,
859 InElementwiseOperation in_element_op,
860 WeiElementwiseOperation wei_element_op,
861 OutElementwiseOperation out_element_op) override
862 {
863 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
864 static_cast<const WeiDataType*>(p_wei_grid),
865 static_cast<OutDataType*>(p_out_grid),
866 N,
867 K,
868 C,
869 input_spatial_lengths,
870 filter_spatial_lengths,
871 output_spatial_lengths,
872 conv_filter_strides,
873 conv_filter_dilations,
874 input_left_pads,
875 input_right_pads,
876 in_element_op,
877 wei_element_op,
878 out_element_op);
879 }
880
// Type-erased counterpart of MakeInvoker (overrides the base-class method): allocates an
// Invoker with std::make_unique, constructing it from a default-constructed temporary,
// and returns it as a std::unique_ptr<BaseInvoker>.
881 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
882 {
883 return std::make_unique<Invoker>(Invoker{});
884 }
885
// Returns a human-readable identifier for this instantiation: the fixed operation name
// followed by a "<...>" list of comma-separated tuning parameters.
// Field order as emitted: BlockSize, MPerBlock, NPerBlock, K0PerBlock, the convolution
// specialization string, K1, MXdlPerWave, NXdlPerWave, the A/B source and destination
// scalar-per-vector values, the two CShuffle per-shuffle counts, and the C block transfer
// scalar-per-vector value. Note the specialization string sits between K0PerBlock and K1,
// and MPerXDL / NPerXDL are not part of the string.
886 std::string GetTypeString() const override
887 {
888 auto str = std::stringstream();
889
890 // clang-format off
891 str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
892 << "<"
893 << BlockSize << ", "
894 << MPerBlock << ", "
895 << NPerBlock << ", "
896 << K0PerBlock << ", "
897 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
898 << K1 << ", "
899 << MXdlPerWave << ", "
900 << NXdlPerWave << ", "
901 << ABlockTransferSrcScalarPerVector << ", "
902 << ABlockTransferDstScalarPerVector_K1 << ", "
903 << BBlockTransferSrcScalarPerVector << ", "
904 << BBlockTransferDstScalarPerVector_K1 << ", "
905 << CShuffleMXdlPerWavePerShuffle << ", "
906 << CShuffleNXdlPerWavePerShuffle << ", "
907 << CBlockTransferScalarPerVector_NWaveNPerXdl
908 << ">";
909 // clang-format on
910
911 return str.str();
912 }
913};
914
915} // namespace device
916} // namespace tensor_operation
917} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ OddC
Definition convolution_forward_specialization.hpp:19
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_gemm_xdlops_v3r1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_v3r1.hpp:36
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:24
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:38
Definition gridwise_gemm_xdlops_v3r1.hpp:127
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:547
const BDataType * p_b_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:545
WeiElementwiseOperation wei_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:552
std::vector< index_t > output_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:560
Argument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:489
std::vector< index_t > input_left_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:563
std::vector< index_t > filter_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:559
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:549
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:548
std::vector< index_t > conv_filter_dilations_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:562
std::vector< index_t > input_right_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:564
std::vector< index_t > conv_filter_strides_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:561
OutElementwiseOperation out_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:553
InElementwiseOperation in_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:551
Block2CTileMap block_2_ctile_map_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:550
std::vector< index_t > input_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:558
CDataType * p_c_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:546
const ADataType * p_a_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:544
DeviceOp::Argument Argument
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:570
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:573
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:722
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:434
static constexpr auto I5
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:93
static auto MakeArgument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:808
static constexpr auto I3
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:91
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:803
OutDataType CDataType
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:81
WeiDataType BDataType
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:80
static constexpr auto GemmK1Number
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:96
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:73
static constexpr auto NXdlPerWave32
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:77
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:735
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, ABDataType, AccDataType, CDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock *K1, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl > GridwiseGemmBase
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:441
InDataType ABDataType
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:84
std::string GetTypeString() const override
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:886
static constexpr index_t NDimSpatial
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:86
static constexpr auto I0
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:88
static constexpr auto I2
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:90
static auto MakeInvoker()
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:843
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:433
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:729
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:435
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:483
static constexpr auto I4
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:92
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})) ABCGridDescs
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:430
static constexpr auto I1
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:89
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:76
BlockToCTileMap_M00_N0_M01< MPerBlock, NPerBlock, CGridDesc_M_N > Block2CTileMap
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:437
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:99
GridwiseGemmBase< math::max(NXdlPerWave32, 1)> GridwiseGemm32
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:484
static constexpr auto K1Number
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:95
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:881
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, const void *p_wei_grid, void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:846
InDataType ADataType
Definition device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:79
Definition device_conv_fwd.hpp:25
#define CK_ENV(name)
Definition utility/env.hpp:129