batched_transpose_lds_problem.hpp Source File

batched_transpose_lds_problem.hpp Source File

Composable Kernel: batched_transpose_lds_problem.hpp Source File
batched_transpose_lds_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10// Supports 2D transpose: the data is first stored to LDS,
11// then read back with ds_read_b*_tr_b* instructions to obtain the transposed data.
12template <typename DataType_,
13 typename BlockTile, // sequence<block_x, block_y>
14 typename NumWarps,
15 bool kPadM_,
16 bool kPadN_>
18{
20
21 static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
22 static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
23 static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
24 static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
25
27 // warps per block
28 static constexpr index_t kLeadNumWarps = kColWarps_;
30
33
36
37 static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
38 "block dim should be divided by warp count!");
39 static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
40 "block dim should be divided by warp count!");
41 // rows/cols per warp
44
45 static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0,
46 "xdl dim should be divided by quad dim!");
47 static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0,
48 "xdl dim should be divided by quad dim!");
49 // xdl rows/cols are divided into quadrants.
52
55
56 // definitions to adapt to BatchedTransposeKernel
57
58 // FIXME: support padding
59 static constexpr bool kPadM = kPadM_;
60 static constexpr bool kPadN = kPadN_;
61
62 static constexpr auto kMPerBlock = kSecondSizePerBlock;
63 static constexpr auto kNPerBlock = kLeadSizePerBlock;
64
65 // 128 bits (16 bytes) is the max single-instruction bandwidth for load/store
66 static constexpr index_t MaxLoadStoreSize = 16;
67 static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
68 static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
69 static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType);
70};
71
72} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
Definition batched_transpose_lds_problem.hpp:18
static constexpr index_t kRowPerBlock_
Definition batched_transpose_lds_problem.hpp:23
static constexpr index_t kBlockSize
Definition batched_transpose_lds_problem.hpp:26
static constexpr index_t kSecondSizePerWarp
Definition batched_transpose_lds_problem.hpp:43
static constexpr auto VectorSizeOutput
Definition batched_transpose_lds_problem.hpp:68
static constexpr index_t kQuadrantSecondDim
Definition batched_transpose_lds_problem.hpp:35
static constexpr index_t kRowWarps_
Definition batched_transpose_lds_problem.hpp:21
static constexpr index_t kQuadNumPerLeadDim
Definition batched_transpose_lds_problem.hpp:50
static constexpr index_t MaxLoadStoreSize
Definition batched_transpose_lds_problem.hpp:66
static constexpr auto kMPerBlock
Definition batched_transpose_lds_problem.hpp:62
static constexpr index_t kColPerBlock_
Definition batched_transpose_lds_problem.hpp:24
static constexpr bool kPadM
Definition batched_transpose_lds_problem.hpp:59
static constexpr index_t kIterationsInSecondDim
Definition batched_transpose_lds_problem.hpp:53
static constexpr index_t kLeadSizePerWarp
Definition batched_transpose_lds_problem.hpp:42
static constexpr index_t kQuadNumPerSecondDim
Definition batched_transpose_lds_problem.hpp:51
static constexpr index_t kLeadNumWarps
Definition batched_transpose_lds_problem.hpp:28
static constexpr auto LDSVectorSize
Definition batched_transpose_lds_problem.hpp:69
static constexpr auto kNPerBlock
Definition batched_transpose_lds_problem.hpp:63
static constexpr index_t kQuadrantLeadDim
Definition batched_transpose_lds_problem.hpp:34
static constexpr index_t kColWarps_
Definition batched_transpose_lds_problem.hpp:22
static constexpr index_t kSecondSizePerBlock
Definition batched_transpose_lds_problem.hpp:32
static constexpr index_t kLeadSizePerBlock
Definition batched_transpose_lds_problem.hpp:31
static constexpr bool kPadN
Definition batched_transpose_lds_problem.hpp:60
static constexpr index_t kSecondNumWarps
Definition batched_transpose_lds_problem.hpp:29
static constexpr auto VectorSizeInput
Definition batched_transpose_lds_problem.hpp:67
remove_cvref_t< DataType_ > DataType
Definition batched_transpose_lds_problem.hpp:19
Definition amd_transpose_load_encoding.hpp:14