device_batched_contraction_multiple_d_wmma_cshuffle.hpp Source File
device_batched_contraction_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
40 // dimension in a dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
54 // is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
55 // strides from input tensor extents so finer dimension information is lost. Merging dimensions is
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__global__ void kernel_contraction_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:133
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:891
remove_cvref_t< decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:888
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:894
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:809
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_ &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:819
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &e_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:850
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_ &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:840
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const EGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:608
Definition utility/sequence.hpp:43
Definition utility/sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:498
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:510
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:520
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:515
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, index_t batch_stride_B, DsGridDesc_G_M_N ds_grid_desc_g_m_n, EGridDesc_G_M_N e_grid_desc_g_m_n)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:499
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:532
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:606
index_t a_kz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:728
index_t e_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:733
AGridDesc a_grid_desc_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:701
index_t b_kz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:730
const ADataType * p_a_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:695
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:704
index_t a_mz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:727
std::array< index_t, NumDTensor > ds_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:731
EDataType * p_e_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:698
BGridDesc b_grid_desc_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:702
GridwiseOp::DsGridPointer p_ds_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:697
index_t a_batch_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:735
BElementwiseOperation b_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:722
index_t M01_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:717
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_strides, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:607
const BDataType * p_b_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:696
index_t b_batch_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:736
index_t N01_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:718
AElementwiseOperation a_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:721
GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:711
DsGridDesc_G_M_N ds_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:705
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:739
index_t e_mz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:732
CDEElementwiseOperation cde_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:723
EGridDesc_G_M_N e_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:706
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:703
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:714
GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:709
index_t b_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:729
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:749
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:752
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:823
DeviceOp::Argument Argument
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:750
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:119
decltype(MakeEGridDescriptor_G_M_N({}, {})) EGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:495
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:990
static constexpr auto I6
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:129
static auto MakeBGridDescriptor(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:248
static constexpr auto I0
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:123
decltype(DeviceOp::MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:545
static constexpr auto I2
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:125
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:492
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}))> DsGridDesc_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:491
static constexpr index_t NumDTensor
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:121
static constexpr auto I1
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:124
decltype(DeviceOp::MakeBGridDescriptor({}, {})) BGridDesc
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:546
static constexpr auto WmmaK
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:135
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:549
static constexpr auto BEnableLds_manu
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:147
static constexpr auto AEnableLds_auto
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:140
static auto MakeInvoker()
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1025
static constexpr auto matrix_padder
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:152
static constexpr auto AEnableLds
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:149
static constexpr auto BEnableLds_auto
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:142
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:466
static constexpr auto K1Number
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:131
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1028
DeviceBatchedContractionMultipleD_Wmma_CShuffle DeviceOp
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:120
std::string GetTypeString() const override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1034
static auto MakeEGridDescriptor_G_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:396
static constexpr auto MWaves
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:133
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:340
remove_cvref_t< decltype(MakeDsGridDescriptor_G_M_N({}, {}))> DsGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:494
static constexpr bool IsValidCompilationParameter()
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:830
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:947
static constexpr auto I5
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:128
static auto MakeDsGridDescriptor_G_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:478
static constexpr auto NWaves
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:134
static constexpr auto I4
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:127
static constexpr auto MaxVectorLoadA
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:137
static constexpr auto AEnableLds_manu
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:146
static constexpr auto BEnableLds
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:150
static auto MakeAGridDescriptor(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:156
static constexpr auto I3
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:126
static constexpr auto MaxVectorLoadB
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:138
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:953
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:836
Definition device_batched_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180