25 template <
typename BRes,
33 const BCoords& cached_coords_b,
35 const OCoords& cached_coords_o,
36 const OFlags& o_flags,
39 const ScaleTensor& scale_,
43 static_assert(BCoords::size() == 8);
44 static_assert(OCoords::size() == 8);
49 static_assert(ScaleTensor::size() == 2);
55 register float v_c0
asm(
"v64");
56 register float v_c1
asm(
"v65");
57 register float v_c2
asm(
"v66");
58 register float v_c3
asm(
"v67");
59 register float v_c4
asm(
"v68");
60 register float v_c5
asm(
"v69");
61 register float v_c6
asm(
"v70");
62 register float v_c7
asm(
"v71");
63 register float v_c8
asm(
"v72");
64 register float v_c9
asm(
"v73");
65 register float v_c10
asm(
"v74");
66 register float v_c11
asm(
"v75");
67 register float v_c12
asm(
"v76");
68 register float v_c13
asm(
"v77");
69 register float v_c14
asm(
"v78");
70 register float v_c15
asm(
"v79");
71 register float v_c16
asm(
"v80");
72 register float v_c17
asm(
"v81");
73 register float v_c18
asm(
"v82");
74 register float v_c19
asm(
"v83");
75 register float v_c20
asm(
"v84");
76 register float v_c21
asm(
"v85");
77 register float v_c22
asm(
"v86");
78 register float v_c23
asm(
"v87");
79 register float v_c24
asm(
"v88");
80 register float v_c25
asm(
"v89");
81 register float v_c26
asm(
"v90");
82 register float v_c27
asm(
"v91");
83 register float v_c28
asm(
"v92");
84 register float v_c29
asm(
"v93");
85 register float v_c30
asm(
"v94");
86 register float v_c31
asm(
"v95");
93 int lane_id = threadIdx.x % 64;
94 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
104 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
111 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
116#pragma clang diagnostic push
117#pragma clang diagnostic ignored "-Winline-asm"
119#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
121#undef CK_TILE_FLATMM_UK_MFMA
160 [v_sld_y_os]
"v"(sld_y_os),
161 [v_sfl_sld]
"v"(sfl_sld),
162 [v_sfl_sst]
"v"(sfl_sst),
163 [s_res_o0]
"s"(res_o[0]),
164 [s_res_o1]
"s"(res_o[1]),
167 [s_res_b0]
"s"(res_b[0]),
168 [s_res_b1]
"s"(res_b[1]),
169 [s_res_b2]
"s"(res_b[2]),
170 [s_res_b3]
"s"(res_b[3]),
188 [s_tile_os_o]
"s"(tile_stride_o_bytes),
189 [s_tile_os_b]
"s"(tile_stride_b_bytes),
192 [v_nan_lo]
"v"(nan_lo),
193 [v_nan_hi]
"v"(nan_hi),
203 "memory",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
"a8",
"a9",
204 "a10",
"a11",
"a12",
"a13",
"a14",
"a15",
"a16",
"a17",
"a18",
"a19",
205 "a20",
"a21",
"a22",
"a23",
"a24",
"a25",
"a26",
"a27",
"a28",
"a29",
206 "a30",
"a31",
"a32",
"a33",
"a34",
"a35",
"a36",
"a37",
"a38",
"a39",
207 "a40",
"a41",
"a42",
"a43",
"a44",
"a45",
"a46",
"a47",
"a48",
"a49",
208 "a50",
"a51",
"a52",
"a53",
"a54",
"a55",
"a56",
"a57",
"a58",
"a59",
209 "a60",
"a61",
"a62",
"a63",
"a64",
"a65",
"a66",
"a67",
"a68",
"a69",
210 "a70",
"a71",
"a72",
"a73",
"a74",
"a75",
"a76",
"a77",
"a78",
"a79",
211 "a80",
"a81",
"a82",
"a83",
"a84",
"a85",
"a86",
"a87",
"a88",
"a89",
212 "a90",
"a91",
"a92",
"a93",
"a94",
"a95",
"a96",
"a97",
"a98",
"a99",
213 "a100",
"a101",
"a102",
"a103",
"a104",
"a105",
"a106",
"a107",
214 "a108",
"a109",
"a110",
"a111",
"a112",
"a113",
"a114",
"a115",
215 "a116",
"a117",
"a118",
"a119",
"a120",
"a121",
"a122",
"a123",
216 "a124",
"a125",
"a126",
"a127",
"a128",
"a129",
"a130",
"a131",
217 "a132",
"a133",
"a134",
"a135",
"a136",
"a137",
"a138",
"a139",
218 "a140",
"a141",
"a142",
"a143",
"a144",
"a145",
"a146",
"a147",
219 "a148",
"a149",
"a150",
"a151",
"a152",
"a153",
"a154",
"a155",
220 "a156",
"a157",
"a158",
"a159",
"a160",
"a161",
"a162",
"a163",
221 "a164",
"a165",
"a166",
"a167",
"a168",
"a169",
"a170",
"a171",
222 "a172",
"a173",
"a174",
"a175",
"a176",
"a177",
"a178",
"a179",
223 "a180",
"a181",
"a182",
"a183",
"a184",
"a185",
"a186",
"a187",
224 "a188",
"a189",
"a190",
"a191",
"a192",
"a193",
"a194",
"a195",
225 "a196",
"a197",
"a198",
"a199",
"a200",
"a201",
"a202",
"a203",
226 "a204",
"a205",
"a206",
"a207",
"a208",
"a209",
"a210",
"a211",
227 "a212",
"a213",
"a214",
"a215",
"a216",
"a217",
"a218",
"a219",
228 "a220",
"a221",
"a222",
"a223",
"a224",
"a225",
"a226",
"a227",
229 "a228",
"a229",
"a230",
"a231",
"a232",
"a233",
"a234",
"a235",
230 "a236",
"a237",
"a238",
"a239",
"a240",
"a241",
"a242",
"a243",
231 "a244",
"a245",
"a246",
"a247",
"a248",
"a249",
"a250",
"a251",
232 "a252",
"a253",
"a254",
"a255",
233 "s8",
"s9",
"s12",
"s13",
"s14",
"s15",
"s38",
"s39",
"s52",
"s86",
234 "s36",
"s37",
"s59",
"s80",
235 "v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
237 "v64",
"v65",
"v66",
"v67",
"v68",
"v69",
"v70",
"v71",
238 "v72",
"v73",
"v74",
"v75",
"v76",
"v77",
"v78",
"v79",
239 "v80",
"v81",
"v82",
"v83",
"v84",
"v85",
"v86",
"v87",
240 "v88",
"v89",
"v90",
"v91",
"v92",
"v93",
"v94",
"v95",
241 "v128",
"v129",
"v130",
"v131",
242 "v132",
"v133",
"v134",
"v135",
"v136",
"v137",
"v138",
"v139",
243 "v140",
"v141",
"v142",
"v143",
"v144",
"v145",
"v146",
"v147",
244 "v148",
"v149",
"v150",
"v151",
"v152",
"v153",
"v154",
"v155",
245 "v156",
"v157",
"v158",
"v159",
"v160",
"v161",
"v162",
"v163",
246 "v164",
"v165",
"v166",
"v167",
"v168",
"v169",
"v170",
"v171",
247 "v172",
"v173",
"v174",
"v175",
"v176",
"v177",
"v178",
"v179",
248 "v180",
"v181",
"v182",
"v183",
"v184",
"v185",
"v186",
"v187",
249 "v188",
"v189",
"v190",
"v191",
"v192",
"v193",
"v194",
"v195",
250 "v196",
"v197",
"v198",
"v199",
"v200",
"v201",
"v202",
"v203",
251 "v204",
"v205",
"v206",
"v207",
"v208",
"v209",
"v210",
"v211",
252 "v212",
"v213",
"v214",
"v215",
"v216",
"v217",
"v218",
"v219",
253 "v220",
"v221",
"v222",
"v223",
"v224",
"v225",
"v226",
"v227",
254 "v228",
"v229",
"v230",
"v231",
"v232",
"v233",
"v234",
"v235",
255 "v236",
"v237",
"v238",
"v239",
"v240",
"v241",
"v242",
"v243",
256 "v244",
"v245",
"v246",
"v247",
"v248",
"v249",
"v250",
"v251",
257 "v252",
"v253",
"v254",
"v255"
259#pragma clang diagnostic pop
272 template <
typename BRes,
277 typename ScaleTensor>
280 const BCoords& cached_coords_b,
282 const OCoords& cached_coords_o,
283 const OFlags& o_flags,
286 const ScaleTensor& scale_,
290 static_assert(BCoords::size() == 8);
291 static_assert(OCoords::size() == 8);
296 static_assert(ScaleTensor::size() == 2);
302 register float v_c0
asm(
"v64");
303 register float v_c1
asm(
"v65");
304 register float v_c2
asm(
"v66");
305 register float v_c3
asm(
"v67");
306 register float v_c4
asm(
"v68");
307 register float v_c5
asm(
"v69");
308 register float v_c6
asm(
"v70");
309 register float v_c7
asm(
"v71");
310 register float v_c8
asm(
"v72");
311 register float v_c9
asm(
"v73");
312 register float v_c10
asm(
"v74");
313 register float v_c11
asm(
"v75");
314 register float v_c12
asm(
"v76");
315 register float v_c13
asm(
"v77");
316 register float v_c14
asm(
"v78");
317 register float v_c15
asm(
"v79");
318 register float v_c16
asm(
"v80");
319 register float v_c17
asm(
"v81");
320 register float v_c18
asm(
"v82");
321 register float v_c19
asm(
"v83");
322 register float v_c20
asm(
"v84");
323 register float v_c21
asm(
"v85");
324 register float v_c22
asm(
"v86");
325 register float v_c23
asm(
"v87");
326 register float v_c24
asm(
"v88");
327 register float v_c25
asm(
"v89");
328 register float v_c26
asm(
"v90");
329 register float v_c27
asm(
"v91");
330 register float v_c28
asm(
"v92");
331 register float v_c29
asm(
"v93");
332 register float v_c30
asm(
"v94");
333 register float v_c31
asm(
"v95");
340 int lane_id = threadIdx.x % 64;
341 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
351 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
358 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
363#pragma clang diagnostic push
364#pragma clang diagnostic ignored "-Winline-asm"
366#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
368#undef CK_TILE_FLATMM_UK_MFMA
406 [v_sld_y_os]
"v"(sld_y_os),
407 [v_sfl_sld]
"v"(sfl_sld),
408 [v_sfl_sst]
"v"(sfl_sst),
409 [s_res_o0]
"s"(res_o[0]),
410 [s_res_o1]
"s"(res_o[1]),
413 [s_res_b0]
"s"(res_b[0]),
414 [s_res_b1]
"s"(res_b[1]),
415 [s_res_b2]
"s"(res_b[2]),
416 [s_res_b3]
"s"(res_b[3]),
434 [s_tile_os_o]
"s"(tile_stride_o_bytes),
435 [s_tile_os_b]
"s"(tile_stride_b_bytes),
438 [v_nan_lo]
"v"(nan_lo),
439 [v_nan_hi]
"v"(nan_hi),
449 "memory",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
"a8",
"a9",
450 "a10",
"a11",
"a12",
"a13",
"a14",
"a15",
"a16",
"a17",
"a18",
"a19",
451 "a20",
"a21",
"a22",
"a23",
"a24",
"a25",
"a26",
"a27",
"a28",
"a29",
452 "a30",
"a31",
"a32",
"a33",
"a34",
"a35",
"a36",
"a37",
"a38",
"a39",
453 "a40",
"a41",
"a42",
"a43",
"a44",
"a45",
"a46",
"a47",
"a48",
"a49",
454 "a50",
"a51",
"a52",
"a53",
"a54",
"a55",
"a56",
"a57",
"a58",
"a59",
455 "a60",
"a61",
"a62",
"a63",
"a64",
"a65",
"a66",
"a67",
"a68",
"a69",
456 "a70",
"a71",
"a72",
"a73",
"a74",
"a75",
"a76",
"a77",
"a78",
"a79",
457 "a80",
"a81",
"a82",
"a83",
"a84",
"a85",
"a86",
"a87",
"a88",
"a89",
458 "a90",
"a91",
"a92",
"a93",
"a94",
"a95",
"a96",
"a97",
"a98",
"a99",
459 "a100",
"a101",
"a102",
"a103",
"a104",
"a105",
"a106",
"a107",
460 "a108",
"a109",
"a110",
"a111",
"a112",
"a113",
"a114",
"a115",
461 "a116",
"a117",
"a118",
"a119",
"a120",
"a121",
"a122",
"a123",
462 "a124",
"a125",
"a126",
"a127",
"a128",
"a129",
"a130",
"a131",
463 "a132",
"a133",
"a134",
"a135",
"a136",
"a137",
"a138",
"a139",
464 "a140",
"a141",
"a142",
"a143",
"a144",
"a145",
"a146",
"a147",
465 "a148",
"a149",
"a150",
"a151",
"a152",
"a153",
"a154",
"a155",
466 "a156",
"a157",
"a158",
"a159",
"a160",
"a161",
"a162",
"a163",
467 "a164",
"a165",
"a166",
"a167",
"a168",
"a169",
"a170",
"a171",
468 "a172",
"a173",
"a174",
"a175",
"a176",
"a177",
"a178",
"a179",
469 "a180",
"a181",
"a182",
"a183",
"a184",
"a185",
"a186",
"a187",
470 "a188",
"a189",
"a190",
"a191",
"a192",
"a193",
"a194",
"a195",
471 "a196",
"a197",
"a198",
"a199",
"a200",
"a201",
"a202",
"a203",
472 "a204",
"a205",
"a206",
"a207",
"a208",
"a209",
"a210",
"a211",
473 "a212",
"a213",
"a214",
"a215",
"a216",
"a217",
"a218",
"a219",
474 "a220",
"a221",
"a222",
"a223",
"a224",
"a225",
"a226",
"a227",
475 "a228",
"a229",
"a230",
"a231",
"a232",
"a233",
"a234",
"a235",
476 "a236",
"a237",
"a238",
"a239",
"a240",
"a241",
"a242",
"a243",
477 "a244",
"a245",
"a246",
"a247",
"a248",
"a249",
"a250",
"a251",
478 "a252",
"a253",
"a254",
"a255",
479 "s8",
"s9",
"s12",
"s13",
"s14",
"s15",
"s38",
"s39",
"s52",
"s86",
480 "s36",
"s37",
"s56",
"s59",
"s60",
"s80",
481 "v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
483 "v64",
"v65",
"v66",
"v67",
"v68",
"v69",
"v70",
"v71",
484 "v72",
"v73",
"v74",
"v75",
"v76",
"v77",
"v78",
"v79",
485 "v80",
"v81",
"v82",
"v83",
"v84",
"v85",
"v86",
"v87",
486 "v88",
"v89",
"v90",
"v91",
"v92",
"v93",
"v94",
"v95",
487 "v128",
"v129",
"v130",
"v131",
488 "v132",
"v133",
"v134",
"v135",
"v136",
"v137",
"v138",
"v139",
489 "v140",
"v141",
"v142",
"v143",
"v144",
"v145",
"v146",
"v147",
490 "v148",
"v149",
"v150",
"v151",
"v152",
"v153",
"v154",
"v155",
491 "v156",
"v157",
"v158",
"v159",
"v160",
"v161",
"v162",
"v163",
492 "v164",
"v165",
"v166",
"v167",
"v168",
"v169",
"v170",
"v171",
493 "v172",
"v173",
"v174",
"v175",
"v176",
"v177",
"v178",
"v179",
494 "v180",
"v181",
"v182",
"v183",
"v184",
"v185",
"v186",
"v187",
495 "v188",
"v189",
"v190",
"v191",
"v192",
"v193",
"v194",
"v195",
496 "v196",
"v197",
"v198",
"v199",
"v200",
"v201",
"v202",
"v203",
497 "v204",
"v205",
"v206",
"v207",
"v208",
"v209",
"v210",
"v211",
498 "v212",
"v213",
"v214",
"v215",
"v216",
"v217",
"v218",
"v219",
499 "v220",
"v221",
"v222",
"v223",
"v224",
"v225",
"v226",
"v227",
500 "v228",
"v229",
"v230",
"v231",
"v232",
"v233",
"v234",
"v235",
501 "v236",
"v237",
"v238",
"v239",
"v240",
"v241",
"v242",
"v243",
502 "v244",
"v245",
"v246",
"v247",
"v248",
"v249",
"v250",
"v251",
503 "v252",
"v253",
"v254",
"v255"
505#pragma clang diagnostic pop
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
Definition tile/core/algorithm/cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition bfloat16.hpp:113
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
int32_t index_t
Definition integer.hpp:9
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:20
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:32
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:19
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:16
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:267
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:266
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:279