block_fmha_bwd_convert_dq.hpp Source File

block_fmha_bwd_convert_dq.hpp Source File

Composable Kernel: block_fmha_bwd_convert_dq.hpp Source File
block_fmha_bwd_convert_dq.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
13{
16
// Compile-time configuration, each value forwarded unchanged from the Problem
// template parameter.
// kM0 is checked by both operator() overloads against dimension 0 of the dQ
// output window's lengths.
17 static constexpr index_t kM0 = Problem::kM0;
18 static constexpr index_t kN0 = Problem::kN0;
19
20 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
21 static constexpr index_t kBlockSize = Problem::kBlockSize;
22 static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
23
// Boolean switches forwarded from Problem. Within this struct only
// kPadHeadDimQ is consumed (by the two alignment constants below).
24 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
25 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
26 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
27 static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
28
// Alignment of the dQ accumulator / dQ output accesses: forced to 1 when
// kPadHeadDimQ is set, otherwise whatever the Policy reports for this Problem.
29 static constexpr index_t kAlignmentQGradAcc =
30 kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
31 static constexpr index_t kAlignmentQGrad =
32 kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
33
// Returns 0 unconditionally (presumably the shared-memory byte requirement of
// this pipeline, judging by the name -- confirm against callers).
34 CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
35
36 // Convert only
// Convert-only overload: loads one tile from the accumulator window, casts it
// to QGradDataType and stores it through the output window.
//
// dq_acc_dram_block_window_tmp: source window; only its bottom tensor view,
//                               window lengths and window origin are read.
// dq_dram_block_window_tmp:     destination window handed to store_tile.
//
// NOTE(review): listing lines 38, 44 and 46 (return type and the operands of
// the first static_assert) are absent from this extract.
37 template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
39 operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
40 QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
41 {
42 static_assert(
43 std::is_same_v<AccDataType,
45 std::is_same_v<QGradDataType,
47 "wrong!");
48
// The output window's first length must equal kM0.
49 static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
50
// Re-create the accumulator window over the same view/lengths/origin, now
// bound to the Policy's PostQGrad tile distribution.
51 auto dq_acc_dram_window =
52 make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
53 dq_acc_dram_block_window_tmp.get_window_lengths(),
54 dq_acc_dram_block_window_tmp.get_window_origin(),
55 Policy::template MakePostQGradDramTileDistribution<Problem>());
56
// Load, convert element type, then write out.
57 auto dq_acc = load_tile(dq_acc_dram_window);
58 const auto dq = cast_tile<QGradDataType>(dq_acc);
59
60 store_tile(dq_dram_block_window_tmp, dq);
61 }
62
63 // Reduce + Convert
// Reduce + convert overload: sums successive tiles read from the accumulator
// window (stepping the window by {1, 0, 0} after each load), converts the sum
// to QGradDataType and stores it through the output window.
//
// dq_acc_dram_block_window_tmp: source window; indexed with three coordinates
//                               here (presumably split x M x head-dim --
//                               TODO confirm against the caller).
// dq_dram_block_window_tmp:     destination window handed to store_tile.
// nsplits:                      controls the loop trip count, see the
//                               NOTE(review) at the do/while below.
//
// NOTE(review): listing lines 65, 72, 74 and 134 (return type, static_assert
// operands, declaration of `dq`) are absent from this extract.
64 template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
66 operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
67 QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
68 index_t nsplits) const
69 {
70 static_assert(
71 std::is_same_v<AccDataType,
73 std::is_same_v<QGradDataType,
75 "wrong!");
76
// The output window's first length must equal kM0.
77 static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
78
// Re-create the accumulator window over the same view/lengths/origin, bound
// to the Policy's PostQGradAcc tile distribution.
79 auto dq_acc_dram_window =
80 make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
81 dq_acc_dram_block_window_tmp.get_window_lengths(),
82 dq_acc_dram_block_window_tmp.get_window_origin(),
83 Policy::template MakePostQGradAccDramTileDistribution<Problem>());
84
// Running sum: same type as a loaded tile, cleared via clear_tile before use.
85 auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
86 clear_tile(dq_acc);
87
88 constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
89 index_t i_total_loops = 0;
// Prime the pipeline: fetch the first tile and advance the window by one
// step along dimension 0.
90 auto dq_acc_buf = load_tile(dq_acc_dram_window);
91 move_tile_window(dq_acc_dram_window, {1, 0, 0});
92
// Each iteration adds the previously loaded tile into dq_acc, then loads the
// next tile and advances the window again.
// NOTE(review): as a do/while, the body always runs once, so together with
// the trailing accumulation below at least two tiles are loaded and summed.
// For nsplits == 2 that is exactly two tiles; for nsplits <= 1 a second tile
// is still read and added -- confirm callers guarantee nsplits >= 2 or that
// the extra read is harmless.
93 do
94 {
95 sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
96 sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
97 sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
98 constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
99 dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
100 });
101 });
102 });
103
104 dq_acc_buf = load_tile(dq_acc_dram_window);
105 move_tile_window(dq_acc_dram_window, {1, 0, 0});
106
107 i_total_loops += 1;
108 } while(i_total_loops < (nsplits - 1));
109
// Fold in the tile fetched by the final loop iteration.
110 sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
111 sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
112 sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
113 constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
114 dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
115 });
116 });
117 });
118
119 // declare dq
// dq_converted uses the same (PostQGradAcc) distribution as dq_acc, so the
// dq_acc spans index both tensors in the conversion sweep below.
120 constexpr auto dq_converted_dstr =
121 Policy::template MakePostQGradAccDramTileDistribution<Problem>();
122 auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
123
// Element-wise conversion of the accumulated sum to QGradDataType.
124 sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
125 sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
126 sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
127 constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
128 dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
129 });
130 });
131 });
132
// The converted values are handed to `dq` by assigning the whole thread
// buffer. `dq` is declared on the missing listing line 134 -- presumably
// built from dq_dstr (the PostQGrad distribution); verify against the
// upstream header.
133 constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
135 dq.get_thread_buffer() = dq_converted.get_thread_buffer();
136
137 store_tile(dq_dram_block_window_tmp, dq);
138 }
139};
140
141} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_convert_dq.hpp:13
static constexpr bool kPadSeqLenQ
Definition block_fmha_bwd_convert_dq.hpp:25
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_convert_dq.hpp:34
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_convert_dq.hpp:27
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_convert_dq.hpp:14
static constexpr index_t kAlignmentQGradAcc
Definition block_fmha_bwd_convert_dq.hpp:29
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_convert_dq.hpp:15
static constexpr index_t kBlockSize
Definition block_fmha_bwd_convert_dq.hpp:21
static constexpr index_t kM0
Definition block_fmha_bwd_convert_dq.hpp:17
static constexpr bool kPadHeadDimQ
Definition block_fmha_bwd_convert_dq.hpp:26
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_convert_dq.hpp:20
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_convert_dq.hpp:31
static constexpr index_t kN0
Definition block_fmha_bwd_convert_dq.hpp:18
CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp &dq_acc_dram_block_window_tmp, QGradDramBlockWindowTmp &dq_dram_block_window_tmp, index_t nsplits) const
Definition block_fmha_bwd_convert_dq.hpp:66
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_convert_dq.hpp:22
CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp &dq_acc_dram_block_window_tmp, QGradDramBlockWindowTmp &dq_dram_block_window_tmp) const
Definition block_fmha_bwd_convert_dq.hpp:39
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_convert_dq.hpp:24