reference_pool.hpp Source File

reference_pool.hpp Source File#

Composable Kernel: reference_pool.hpp Source File
reference_pool.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <thread>
10#include <cmath>
11
12namespace ck_tile {
13
14template <typename InDataType,
15 typename ComputeDataType,
16 typename OutDataType,
17 typename IndexDataType,
18 typename ReduceOp,
19 typename TensorShape,
20 typename WindowShape,
21 bool OutputIndex = false>
24 HostTensor<IndexDataType>& output_index,
26 ReduceOp reduce_op)
27{
32
35
38
41
44
45 const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
46 const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
47 // Right padding is handled implicitly by bounds checking
48
49 auto f = [&](auto n, auto ho, auto wo, auto c) {
50 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
51
52 IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
53
54 for(ck_tile::index_t y = 0; y < Y; ++y)
55 {
56 // Calculate input height index with stride, dilation, and padding
57 ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
58
59 for(ck_tile::index_t x = 0; x < X; ++x)
60 {
61 // Calculate input width index with stride, dilation, and padding
62 ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
63
64 if(hi >= 0 && hi < H && wi >= 0 && wi < W)
65 {
66 const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
67
68 if constexpr(OutputIndex)
69 {
70 IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
71 bool changed = false;
72 v_acc = reduce_op(v_acc, v_in, changed);
73 if(changed)
74 {
75 current_index = flat_index;
76 }
77 }
78 else
79 {
80 v_acc = reduce_op(v_acc, v_in);
81 }
82 }
83 // For positions outside bounds, we implicitly use identity value
84 }
85 }
86
87 output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
88
89 if constexpr(OutputIndex)
90 {
91 output_index(n, ho, wo, c) = current_index;
92 }
93 };
94
95 // Parallelize over all output dimensions
96 make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
97}
98
99template <typename InDataType,
100 typename ComputeDataType,
101 typename OutDataType,
102 typename IndexDataType,
103 typename ReduceOp,
104 typename TensorShape,
105 typename WindowShape,
106 bool OutputIndex = false>
109 HostTensor<IndexDataType>& output_index,
111 ReduceOp reduce_op)
112{
113 const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
114 const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
115 const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
116 const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
117 const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
118
119 const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
120 const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
121 const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
122
126
130
134
135 const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
136 const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
137 const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
138 // Right padding is handled implicitly by bounds checking
139
140 auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
141 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
142
143 IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
144
145 for(ck_tile::index_t z = 0; z < Z; ++z)
146 {
147 // Calculate input depth index with stride, dilation, and padding
148 ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
149
150 for(ck_tile::index_t y = 0; y < Y; ++y)
151 {
152 // Calculate input height index with stride, dilation, and padding
153 ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
154
155 for(ck_tile::index_t x = 0; x < X; ++x)
156 {
157 // Calculate input width index with stride, dilation, and padding
158 ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
159
160 if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
161 {
162 const ComputeDataType v_in =
163 type_convert<ComputeDataType>(input(n, di, hi, wi, c));
164
165 if constexpr(OutputIndex)
166 {
167 IndexDataType flat_index =
168 input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
169 bool changed = false;
170 v_acc = reduce_op(v_acc, v_in, changed);
171 if(changed)
172 {
173 current_index = flat_index;
174 }
175 }
176 else
177 {
178 v_acc = reduce_op(v_acc, v_in);
179 }
180 }
181 // For positions outside bounds, we implicitly use identity value
182 }
183 }
184 }
185
186 output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
187
188 if constexpr(OutputIndex)
189 {
190
191 output_index(n, do_, ho, wo, c) = current_index;
192 }
193 };
194
195 // Parallelize over all output dimensions
196 make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
197}
198} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition reduce_operator.hpp:11
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition reference_pool.hpp:22
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition reference_pool.hpp:107
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition tile/host/host_tensor.hpp:531
Kernel arguments for pooling operations.
Definition pool_kernel.hpp:63
TensorShape output_shape
Definition pool_kernel.hpp:68
WindowShape window_lengths
Definition pool_kernel.hpp:71
WindowShape window_dilations
Definition pool_kernel.hpp:73
WindowShape input_left_pads
Definition pool_kernel.hpp:74
TensorShape input_shape
Definition pool_kernel.hpp:67
WindowShape window_strides
Definition pool_kernel.hpp:72