tensor_descriptor_utils.hpp Source File

tensor_descriptor_utils.hpp Source File

Composable Kernel: tensor_descriptor_utils.hpp Source File
tensor_descriptor_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
19
20namespace ck_tile {
21
30template <ck_tile::index_t NumDimG,
31 ck_tile::index_t NumDimM,
32 ck_tile::index_t NumDimN,
33 ck_tile::index_t NumDimK>
35{
42 CK_TILE_HOST static constexpr auto
43 Make_A_GridDescriptor_M_K(const std::vector<ck_tile::index_t>& A_dims = {},
44 const std::vector<ck_tile::index_t>& A_strides = {})
45 {
46 const auto to_tuple = [&](auto& vec, auto start, auto end) {
47 return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
48 };
49
50 // Remove G Dimensions
51 const auto A_dims_M_K =
53 const auto A_strides_M_K =
55
56 // dimension Ids for M and K
57 constexpr auto A_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
58 constexpr auto A_dims_K_ids =
60
61 // Dimensions for M [M0, M1, ...] and K [K0, K1, ...]
62 const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
63 const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
64
65 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor
66 const auto A_grid_desc_Ms_Ks =
67 ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
68
69 // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
70 // = K0 * K1 * K2 * ...]
71 const auto A_grid_desc_Mflat_Kflat = ck_tile::transform_tensor_descriptor(
72 A_grid_desc_Ms_Ks,
74 make_tuple(A_dims_M_ids, A_dims_K_ids),
75 make_tuple(sequence<0>{}, sequence<1>{}));
76
77 return A_grid_desc_Mflat_Kflat;
78 }
79
86 CK_TILE_HOST static constexpr auto
87 Make_B_GridDescriptor_N_K(const std::vector<ck_tile::index_t>& B_dims = {},
88 const std::vector<ck_tile::index_t>& B_strides = {})
89 {
90 const auto to_tuple = [&](auto& vec, auto start, auto end) {
91 return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
92 };
93
94 // Remove G Dimensions
95 const auto B_dims_N_K =
97 const auto B_strides_N_K =
99
100 // dimension Ids for N and K
101 constexpr auto B_dims_N_ids = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
102 constexpr auto B_dims_K_ids =
104
105 // Dimensions for N [N0, N1, ...] and K [K0, K1, ...]
106 const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
107 const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
108
109 // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor
110 const auto B_grid_desc_Ns_Ks =
111 ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
112
113 // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
114 // = K0 * K1 * K2 * ...]
115 const auto B_grid_desc_Nflat_Kflat = ck_tile::transform_tensor_descriptor(
116 B_grid_desc_Ns_Ks,
118 make_tuple(B_dims_N_ids, B_dims_K_ids),
119 make_tuple(sequence<0>{}, sequence<1>{}));
120
121 return B_grid_desc_Nflat_Kflat;
122 }
123
130 CK_TILE_HOST static constexpr auto
131 Make_E_GridDescriptor_M_N(const std::vector<ck_tile::index_t>& E_dims = {},
132 const std::vector<ck_tile::index_t>& E_strides = {})
133 {
134 const auto to_tuple = [&](auto& vec, auto start, auto end) {
135 return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
136 };
137
138 // Remove G dimensions
139 const auto E_dims_M_N =
141 const auto E_strides_M_N =
143
144 // dimension Ids for M and N
145 constexpr auto E_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
146 constexpr auto E_dims_N_ids =
148
149 // Dimensions for M and N
150 const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
151 const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
152
153 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor
154 const auto E_grid_desc_Ms_Ns =
155 ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
156
157 // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
158 // N_total = N0 * N1 * N2 * ...]
159 const auto E_grid_desc_Mflat_Nflat = ck_tile::transform_tensor_descriptor(
160 E_grid_desc_Ms_Ns,
162 make_tuple(E_dims_M_ids, E_dims_N_ids),
163 make_tuple(sequence<0>{}, sequence<1>{}));
164
165 return E_grid_desc_Mflat_Nflat;
166 }
167};
168
169} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Utility class for creating tensor descriptors in batched contraction operations.
Definition tensor_descriptor_utils.hpp:35
static CK_TILE_HOST constexpr auto Make_E_GridDescriptor_M_N(const std::vector< ck_tile::index_t > &E_dims={}, const std::vector< ck_tile::index_t > &E_strides={})
Creates a tensor descriptor for output tensor E with batch dimensions removed.
Definition tensor_descriptor_utils.hpp:131
static CK_TILE_HOST constexpr auto Make_B_GridDescriptor_N_K(const std::vector< ck_tile::index_t > &B_dims={}, const std::vector< ck_tile::index_t > &B_strides={})
Creates a tensor descriptor for input tensor B with batch dimensions removed.
Definition tensor_descriptor_utils.hpp:87
static CK_TILE_HOST constexpr auto Make_A_GridDescriptor_M_K(const std::vector< ck_tile::index_t > &A_dims={}, const std::vector< ck_tile::index_t > &A_strides={})
Creates a tensor descriptor for input tensor A with batch dimensions removed.
Definition tensor_descriptor_utils.hpp:43
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302