tensor_descriptor_helper.hpp Source File

tensor_descriptor_helper.hpp Source File#

Composable Kernel: tensor_descriptor_helper.hpp Source File
tensor_descriptor_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11
/*
 * These functions create tensor descriptors at run time. If they are not evaluated as constexpr,
 * you will likely see scratch memory being used while these tensor descriptors are constructed.
 * It is therefore better to call these functions on the host and then pass the constructed tensor
 * descriptors to the GPU. If the tensor descriptors being constructed are constexpr, these
 * functions can be called on the GPU without worrying about scratch memory usage.
 */
19
#if CK_WORKAROUND_SWDEV_275126
// Free-function form of the element-space-size reduction used by make_naive_tensor_descriptor
// when CK_WORKAROUND_SWDEV_275126 is set (the recursive-lambda form crashes some compilers).
//
// Adds (lengths[k] - 1) * strides[k] to acc_old for every dimension k from I up to
// Lengths::Size() - 1 and returns the final accumulator. The accumulator type may change from
// step to step (e.g. a compile-time LongNumber<> stays compile-time while all inputs are).
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     Number<I> i,
                                                                     AccOld acc_old)
{
    // contribution of dimension I: offset of its last element along this axis
    const auto acc = acc_old + (lengths[i] - Number<1>{}) * strides[i];

    if constexpr(I + 1 < Lengths::Size())
    {
        // more dimensions remain: fold the next one in
        return calculate_element_space_size_impl(lengths, strides, Number<I + 1>{}, acc);
    }
    else
    {
        return acc;
    }
}
#endif
39
// Create a descriptor for a tensor with arbitrary per-dimension strides.
//
// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) LongNumber<>
//
// The returned descriptor holds a single embed transform built from (lengths, strides). Its lower
// side is hidden dimension 0; its upper side, which is also the set of visible dimensions, is
// hidden dimensions 1..N (arithmetic_sequence_gen<1, N + 1, 1>). element_space_size is
//   1 + sum_k (lengths[k] - 1) * strides[k]
// which -- assuming the embed transform maps an index to sum_k idx[k] * strides[k] and all
// strides are non-negative -- is one past the largest reachable offset.
template <typename... Lengths,
          typename... Strides,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
                                                                const Tuple<Strides...>& strides)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    // Hidden dimension 0 is the transform's single lower dimension.
    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    // Hidden dimensions 1..N are the transform's upper dimensions.
    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    // The upper dimensions are exactly the dimensions exposed to callers.
    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

#if !CK_WORKAROUND_SWDEV_275126
    // Reduction over dimensions 0..N-1 written as a self-applying generic lambda, so each step
    // may change the accumulator type (LongNumber<> stays compile-time while inputs are).
    // The rocm-4.1 compiler crashes on this recursive lambda; CK_WORKAROUND_SWDEV_275126
    // selects the equivalent free function calculate_element_space_size_impl instead.
    auto f = [&](auto fs, auto i, auto acc_old) {
        auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

        if constexpr(i.value < N - 1)
        {
            return fs(fs, i + Number<1>{}, acc_new);
        }
        else
        {
            return acc_new;
        }
    };

    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
    const auto element_space_size =
        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
92
// Create a descriptor for a packed tensor: the element space size is the product of all
// lengths, i.e. there is no padding anywhere in the layout.
//
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) LongNumber<>
//
// Uses a single unmerge transform whose lower side is hidden dimension 0 and whose upper side
// (also the visible dimensions) is hidden dimensions 1..N.
// NOTE(review): the dimension order of the packed layout (presumably row-major, last dimension
// contiguous) is defined by make_unmerge_transform -- confirm there.
template <typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    // Hidden dimension 0 is the transform's single lower dimension.
    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    // Hidden dimensions 1..N are the transform's upper dimensions.
    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    // The upper dimensions are exactly the dimensions exposed to callers.
    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    // Product of all lengths, seeded with LongNumber<1> so the result is 64-bit
    // (long_index_t, or LongNumber<> when every length is compile-time).
    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
123
// Create a descriptor for a tensor whose last dimension is contiguous and whose second-to-last
// stride (the "row pitch") is lengths[N-1] rounded up to a multiple of `align`.
//
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// align could be:
//   1) index_t, or
//   2) Number<>
//
// With N = sizeof...(Lengths), the strides are:
//   strides[N-1] = 1
//   strides[N-2] = math::integer_least_multiple(lengths[N-1], align)
//   strides[i]   = strides[N-2] * lengths[i+1] * ... * lengths[N-2]     for i < N-2
// and the result is make_naive_tensor_descriptor(lengths, strides).
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = Number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    // Padded pitch of dimension N-2.
    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                // innermost dimension is contiguous
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                // Returned as-is so it keeps whatever type integer_least_multiple produced,
                // which lets both run-time (index_t) and compile-time (Number<>) inputs work.
                return stride_n_minus_2;
            }
            else
            {
                // stride_n_minus_2 * prod(lengths[i+1 .. N-2]); the end index Number<N - 1>
                // is exclusive, matching container_reduce's default end of Container::Size().
                return container_reduce(lengths,
                                        math::multiplies{},
                                        stride_n_minus_2,
                                        i + I1,
                                        Number<N - 1>{},
                                        I1);
            }
        },
        Number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}
164
165} // namespace ck
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths &lengths, const Strides &strides, number< I > i, AccOld acc_old)
Definition tile/core/tensor/tensor_descriptor.hpp:239
Definition ck.hpp:268
integral_constant< long_index_t, N > LongNumber
Definition number.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
Definition utility/sequence.hpp:43
Definition tensor_description/tensor_descriptor.hpp:28
Definition utility/tuple.hpp:117
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/math.hpp:34