amd_xdlops.hpp Source File

amd_xdlops.hpp Source File

Composable Kernel: amd_xdlops.hpp Source File
amd_xdlops.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
6
7namespace ck {
// Common feature macro for the gfx94x-family targets handled in this file:
// gfx942 (MI300 series) and gfx950 (MI350 series). Used below to enable the
// fp64 MFMA path together with gfx90a.
// NOTE: preprocessor macros are not namespace-scoped; although this appears
// inside namespace ck, __gfx94__ is visible to everything that follows in the
// translation unit. The #ifndef guard avoids redefining it if it is already
// defined elsewhere.
#if defined(__gfx942__) || defined(__gfx950__)
#ifndef __gfx94__
#define __gfx94__
#endif
#endif
12
13// fp32
14template <index_t MPerWave, index_t NPerWave>
16
17template <>
19{
20 template <class FloatC>
21 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
22 {
23 reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
24 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
25 reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
26 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
27 }
28};
29
30template <>
32{
33 template <class FloatC>
34 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
35 {
36 reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
37 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
38 }
39};
40
41template <index_t MPerWave, index_t NPerWave>
43
44template <>
46{
47 template <class FloatC>
48 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
49 {
50 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
51 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
52 }
53};
54
55template <index_t MPerWave, index_t NPerWave>
57
58template <>
60{
61 template <class FloatC>
62 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
63 {
64 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
65 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
66 }
67};
68
69template <index_t MPerWave, index_t NPerWave>
71
72template <>
74{
75 template <class FloatC>
76 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
77 {
78 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
79 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
80 }
81};
82
83template <index_t MPerWave, index_t NPerWave>
85
86template <>
88{
89 template <class FloatC>
90 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
91 {
92 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
93 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
94 }
95};
96
97template <>
99{
100 template <class FloatC>
101 __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
102 {
103 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
104 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
105 reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
106 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
107 }
108};
109
110// fp16
111template <index_t MPerWave, index_t NPerWave>
113
114template <>
116{
117 template <class FloatC>
118 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
119 {
120 reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
121 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
122 reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
123 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
124 }
125};
126
127template <>
129{
130 template <class FloatC>
131 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
132 {
133 reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
134 reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
135 }
136};
137
138template <index_t MPerWave, index_t NPerWave>
140
141template <>
143{
144 template <class FloatC>
145 __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
146 {
147#if defined(__gfx950__)
148 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
149 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
150#else
151 ignore = reg_a;
152 ignore = reg_b;
153 ignore = reg_c;
154#endif // defined(__gfx950__)
155 }
156};
157
158template <index_t MPerWave, index_t NPerWave>
160
161template <>
163{
164 template <class FloatC>
165 __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
166 {
167#if defined(__gfx950__)
168 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
169 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
170#else
171 ignore = reg_a;
172 ignore = reg_b;
173 ignore = reg_c;
174#endif // defined(__gfx950__)
175 }
176};
177
178template <index_t MPerWave, index_t NPerWave>
180
181template <>
183{
184 template <class FloatC>
185 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
186 {
187 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
188 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
189 }
190};
191
192template <index_t MPerWave, index_t NPerWave>
194
195template <>
197{
198 template <class FloatC>
199 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
200 {
201 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
202 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
203 }
204};
205
206template <index_t MPerWave, index_t NPerWave>
208
209template <>
211{
212 template <class FloatC>
213 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
214 {
215 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
216 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
217 }
218};
219
220template <index_t MPerWave, index_t NPerWave>
222
223template <>
225{
226 template <class FloatC>
227 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
228 {
229 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
230 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
231 }
232};
233
234template <>
236{
237 template <class FloatC>
238 __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
239 {
240 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
241 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
242 reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
243 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
244 }
245};
246
247// bfp16
248template <index_t MPerWave, index_t NPerWave>
250
251template <>
253{
254 template <class FloatC>
255 __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
256 {
257#if defined(__gfx950__)
258 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
259 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
260#else
261 ignore = reg_a;
262 ignore = reg_b;
263 ignore = reg_c;
264#endif // defined(__gfx950__)
265 }
266};
267
268template <index_t MPerWave, index_t NPerWave>
270
271template <>
273{
274 template <class FloatC>
275 __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
276 {
277#if defined(__gfx950__)
278 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
279 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
280#else
281 ignore = reg_a;
282 ignore = reg_b;
283 ignore = reg_c;
284#endif // defined(__gfx950__)
285 }
286};
287
288template <index_t MPerWave, index_t NPerWave>
290
291template <>
293{
294 template <class FloatC>
295 __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
296 {
297 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
298 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
299 }
300};
301
302template <index_t MPerWave, index_t NPerWave>
304
305template <>
307{
308 template <class FloatC>
309 __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
310 {
311 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
312 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
313 }
314};
315
316template <index_t MPerWave, index_t NPerWave>
318
319template <>
321{
322 template <class FloatC>
323 __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
324 {
325 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
326 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
327 }
328};
329
330template <index_t MPerWave, index_t NPerWave>
332
333template <>
335{
336 template <class FloatC>
337 __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
338 {
339 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
340 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
341 }
342};
343
344template <index_t MPerWave, index_t NPerWave>
346
347template <>
349{
350 template <class FloatC>
351 __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
352 {
353 reg_c.template AsType<int32x16_t>()(Number<0>{}) =
354 __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
355 bit_cast<int32_t>(reg_b),
356 reg_c.template AsType<int32x16_t>()[Number<0>{}],
357 0,
358 0,
359 0);
360 }
361};
362
363template <index_t MPerWave, index_t NPerWave>
365
366template <>
368{
369 template <class FloatC>
370 __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
371 {
372 reg_c.template AsType<int32x4_t>()(Number<0>{}) =
373 __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
374 bit_cast<int32_t>(reg_b),
375 reg_c.template AsType<int32x4_t>()[Number<0>{}],
376 0,
377 0,
378 0);
379 }
380};
381
382template <index_t MPerWave, index_t NPerWave>
384
385template <>
387{
388 template <class FloatC>
389 __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
390 {
391#if defined(__gfx950__)
392 reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
393 reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
394#else
395 ignore = reg_a;
396 ignore = reg_b;
397 ignore = reg_c;
398#endif // defined(__gfx950__)
399 }
400};
401
402template <index_t MPerWave, index_t NPerWave>
404
405template <>
407{
408 template <class FloatC>
409 __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
410 {
411#if defined(__gfx950__)
412 reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
413 reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
414#else
415 ignore = reg_a;
416 ignore = reg_b;
417 ignore = reg_c;
418#endif // defined(__gfx950__)
419 }
420};
421
422template <index_t MPerWave, index_t NPerWave>
424
425template <>
427{
428 template <class FloatC>
429 __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
430 {
431 reg_c.template AsType<int32x16_t>()(Number<0>{}) =
432 __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
433 bit_cast<int64_t>(reg_b),
434 reg_c.template AsType<int32x16_t>()[Number<0>{}],
435 0,
436 0,
437 0);
438 }
439};
440
441template <index_t MPerWave, index_t NPerWave>
443
444template <>
446{
447 template <class FloatC>
448 __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
449 {
450 reg_c.template AsType<int32x4_t>()(Number<0>{}) =
451 __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
452 bit_cast<int64_t>(reg_b),
453 reg_c.template AsType<int32x4_t>()[Number<0>{}],
454 0,
455 0,
456 0);
457 }
458};
459
460template <index_t MPerWave, index_t NPerWave>
462
463template <>
465{
466 template <class FloatC>
467 __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
468 {
469#if defined(__gfx90a__) || defined(__gfx94__)
470 reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
471 reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
472#else
473 ignore = reg_a;
474 ignore = reg_b;
475 ignore = reg_c;
476#endif
477 }
478};
479
480template <index_t MPerWave, index_t NPerWave>
482
489template <>
491{
492 template <class FloatC>
493 __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
494 {
495#if defined(__gfx950__)
496 reg_c.template AsType<float16_t>()(Number<0>{}) =
497 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
498 reg_a,
499 reg_b,
500 reg_c.template AsType<float16_t>()[Number<0>{}],
501 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
502 0, // blgp
503 0,
504 0,
505 0,
506 0);
507#else
508 ignore = reg_a;
509 ignore = reg_b;
510 ignore = reg_c;
511#endif
512 }
513
514 template <class FloatC>
515 __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
516 {
517#if defined(__gfx950__)
518 reg_c.template AsType<float16_t>()(Number<0>{}) =
519 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
520 reg_a,
521 reg_b,
522 reg_c.template AsType<float16_t>()[Number<0>{}],
523 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
524 1, // blgp
525 0,
526 0,
527 0,
528 0);
529#else
530 ignore = reg_a;
531 ignore = reg_b;
532 ignore = reg_c;
533#endif
534 }
535
536 template <class FloatC>
537 __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
538 {
539#if defined(__gfx950__)
540 reg_c.template AsType<float16_t>()(Number<0>{}) =
541 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
542 reg_a,
543 reg_b,
544 reg_c.template AsType<float16_t>()[Number<0>{}],
545 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
546 0, // blgp
547 0,
548 0,
549 0,
550 0);
551#else
552 ignore = reg_a;
553 ignore = reg_b;
554 ignore = reg_c;
555#endif
556 }
557
558 template <class FloatC>
559 __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
560 {
561#if defined(__gfx950__)
562 reg_c.template AsType<float16_t>()(Number<0>{}) =
563 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
564 reg_a,
565 reg_b,
566 reg_c.template AsType<float16_t>()[Number<0>{}],
567 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
568 1, // blgp
569 0,
570 0,
571 0,
572 0);
573#else
574 ignore = reg_a;
575 ignore = reg_b;
576 ignore = reg_c;
577#endif
578 }
579
580 template <class FloatC>
581 __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
582 {
583#if defined(__gfx950__)
584
585 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
586 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
587
588 using arg_type = int32x8_t;
589
590 reg_c.template AsType<float16_t>()(Number<0>{}) =
591 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
592 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
593 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
594 reg_c.template AsType<float16_t>()[Number<0>{}],
595 4, // cbsz
596 4, // blgp
597 0, // OPSEL
598 0,
599 0, // OPSEL
600 0);
601#else
602 ignore = reg_a;
603 ignore = reg_b;
604 ignore = reg_c;
605#endif
606 }
607
608 template <class FloatC>
609 __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
610 {
611#if defined(__gfx950__)
612
613 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
614 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
615
616 using arg_type = int32x8_t;
617
618 reg_c.template AsType<float16_t>()(Number<0>{}) =
619 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
620 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
621 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
622 reg_c.template AsType<float16_t>()[Number<0>{}],
623 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
624 2, // blgp
625 0, // OPSEL
626 0,
627 0, // OPSEL
628 0);
629#else
630 ignore = reg_a;
631 ignore = reg_b;
632 ignore = reg_c;
633#endif
634 }
635
636 template <class FloatC>
637 __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
638 {
639#if defined(__gfx950__)
640
641 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
642 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
643
644 using arg_type = int32x8_t;
645
646 reg_c.template AsType<float16_t>()(Number<0>{}) =
647 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
648 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
649 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
650 reg_c.template AsType<float16_t>()[Number<0>{}],
651 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
652 3, // blgp
653 0, // OPSEL
654 0,
655 0, // OPSEL
656 0);
657#else
658 ignore = reg_a;
659 ignore = reg_b;
660 ignore = reg_c;
661#endif
662 }
663};
664
665template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
667
668template <index_t OpselA, index_t OpselB>
669struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
670{
671 template <class FloatC>
672 __device__ static void Run(const f8x32_t& reg_a,
673 const int32_t& scale_a,
674 const f8x32_t& reg_b,
675 const int32_t& scale_b,
676 FloatC& reg_c)
677 {
678#if defined(__gfx950__)
679 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
680 reg_c.template AsType<float16_t>()(Number<0>{}) =
681 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
682 reg_a,
683 reg_b,
684 reg_c.template AsType<float16_t>()[Number<0>{}],
685 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
686 0, // blgp
687 OpselA, // OPSEL
688 scale_a,
689 OpselB, // OPSEL
690 scale_b);
691 // XXX: Note on the scale_a and scale_b parameters:
692 // If compiler detects that one or both scales are constant values, it will treat that
693 // constant as F32 constant. I.e., if scale_a at some point was declared as
694 // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
695 // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
696
697 // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
698 // when OPSEL is set otherwise.
699#else
700 ignore = reg_a;
701 ignore = scale_a;
702 ignore = reg_b;
703 ignore = scale_b;
704 ignore = reg_c;
705#endif
706 }
707
708 template <class FloatC>
709 __device__ static void Run(const bf8x32_t& reg_a,
710 const int32_t& scale_a,
711 const bf8x32_t& reg_b,
712 const int32_t& scale_b,
713 FloatC& reg_c)
714 {
715#if defined(__gfx950__)
716 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
717 reg_c.template AsType<float16_t>()(Number<0>{}) =
718 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
719 reg_a,
720 reg_b,
721 reg_c.template AsType<float16_t>()[Number<0>{}],
722 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
723 1, // blgp
724 OpselA, // OPSEL
725 scale_a,
726 OpselB, // OPSEL
727 scale_b);
728 // XXX: Note on the scale_a and scale_b parameters:
729 // If compiler detects that one or both scales are constant values, it will treat that
730 // constant as F32 constant. I.e., if scale_a at some point was declared as
731 // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
732 // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
733
734 // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
735 // when OPSEL is set otherwise.
736#else
737 ignore = reg_a;
738 ignore = scale_a;
739 ignore = reg_b;
740 ignore = scale_b;
741 ignore = reg_c;
742#endif
743 }
744
745 template <class FloatC>
746 __device__ static void Run(const bf8x32_t& reg_a,
747 const int32_t& scale_a,
748 const f8x32_t& reg_b,
749 const int32_t& scale_b,
750 FloatC& reg_c)
751 {
752#if defined(__gfx950__)
753 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
754 reg_c.template AsType<float16_t>()(Number<0>{}) =
755 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
756 reg_a,
757 reg_b,
758 reg_c.template AsType<float16_t>()[Number<0>{}],
759 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
760 0, // blgp
761 OpselA, // OPSEL
762 scale_a,
763 OpselB, // OPSEL
764 scale_b);
765 // XXX: Note on the scale_a and scale_b parameters:
766 // If compiler detects that one or both scales are constant values, it will treat that
767 // constant as F32 constant. I.e., if scale_a at some point was declared as
768 // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
769 // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
770
771 // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
772 // when OPSEL is set otherwise.
773#else
774 ignore = reg_a;
775 ignore = scale_a;
776 ignore = reg_b;
777 ignore = scale_b;
778 ignore = reg_c;
779#endif
780 }
781
782 template <class FloatC>
783 __device__ static void Run(const f6x32_t& reg_a,
784 const int32_t scale_a,
785 const f6x32_t& reg_b,
786 const int32_t scale_b,
787 FloatC& reg_c)
788 {
789#if defined(__gfx950__)
790
791 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
792 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
793
794 using arg_type = int32x8_t;
795
796 reg_c.template AsType<float16_t>()(Number<0>{}) =
797 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
798 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
799 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
800 reg_c.template AsType<float16_t>()[Number<0>{}],
801 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
802 2, // blgp
803 OpselA, // OPSEL
804 scale_a,
805 OpselB, // OPSEL
806 scale_b);
807#else
808 ignore = reg_a;
809 ignore = scale_a;
810 ignore = reg_b;
811 ignore = scale_b;
812 ignore = reg_c;
813#endif
814 }
815
816 template <class FloatC>
817 __device__ static void Run(const bf6x32_t& reg_a,
818 const int32_t scale_a,
819 const bf6x32_t& reg_b,
820 const int32_t scale_b,
821 FloatC& reg_c)
822 {
823#if defined(__gfx950__)
824
825 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
826 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
827
828 using arg_type = int32x8_t;
829
830 reg_c.template AsType<float16_t>()(Number<0>{}) =
831 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
832 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
833 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
834 reg_c.template AsType<float16_t>()[Number<0>{}],
835 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
836 3, // blgp
837 OpselA, // OPSEL
838 scale_a,
839 OpselB, // OPSEL
840 scale_b);
841#else
842 ignore = reg_a;
843 ignore = scale_a;
844 ignore = reg_b;
845 ignore = scale_b;
846 ignore = reg_c;
847#endif
848 }
849
850 template <class FloatC>
851 __device__ static void Run(const f4x32_t& reg_a,
852 const int32_t scale_a,
853 const f4x32_t& reg_b,
854 const int32_t scale_b,
855 FloatC& reg_c)
856 {
857#if defined(__gfx950__)
858
859 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
860 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
861
862 using arg_type = int32x8_t;
863
864 reg_c.template AsType<float16_t>()(Number<0>{}) =
865 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
866 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
867 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
868 reg_c.template AsType<float16_t>()[Number<0>{}],
869 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
870 4, // blgp
871 OpselA, // OPSEL
872 scale_a,
873 OpselB, // OPSEL
874 scale_b);
875#else
876 ignore = reg_a;
877 ignore = scale_a;
878 ignore = reg_b;
879 ignore = scale_b;
880 ignore = reg_c;
881#endif
882 }
883};
884
885template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
887
888template <index_t OpselA, index_t OpselB>
889struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
890{
891 template <class FloatC>
892 __device__ static void Run(const f8x32_t& reg_a,
893 const int32_t& scale_a,
894 const f8x32_t& reg_b,
895 const int32_t& scale_b,
896 FloatC& reg_c)
897 {
898#if defined(__gfx950__)
899 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
900 reg_c.template AsType<float4_t>()(Number<0>{}) =
901 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
902 reg_a,
903 reg_b,
904 reg_c.template AsType<float4_t>()[Number<0>{}],
905 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
906 0, // blgp
907 OpselA, // OPSEL
908 scale_a,
909 OpselB, // OPSEL
910 scale_b);
911#else
912 ignore = reg_a;
913 ignore = scale_a;
914 ignore = reg_b;
915 ignore = scale_b;
916 ignore = reg_c;
917#endif
918 }
919
920 template <class FloatC>
921 __device__ static void Run(const bf8x32_t& reg_a,
922 const int32_t& scale_a,
923 const bf8x32_t& reg_b,
924 const int32_t& scale_b,
925 FloatC& reg_c)
926 {
927#if defined(__gfx950__)
928 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
929 reg_c.template AsType<float4_t>()(Number<0>{}) =
930 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
931 reg_a,
932 reg_b,
933 reg_c.template AsType<float4_t>()[Number<0>{}],
934 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
935 1, // blgp
936 OpselA, // OPSEL
937 scale_a,
938 OpselB, // OPSEL
939 scale_b);
940#else
941 ignore = reg_a;
942 ignore = scale_a;
943 ignore = reg_b;
944 ignore = scale_b;
945 ignore = reg_c;
946#endif
947 }
948
949 template <class FloatC>
950 __device__ static void Run(const f8x32_t& reg_a,
951 const int32_t& scale_a,
952 const bf8x32_t& reg_b,
953 const int32_t& scale_b,
954 FloatC& reg_c)
955 {
956#if defined(__gfx950__)
957 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
958 reg_c.template AsType<float4_t>()(Number<0>{}) =
959 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
960 reg_a,
961 reg_b,
962 reg_c.template AsType<float4_t>()[Number<0>{}],
963 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
964 1, // blgp
965 OpselA, // OPSEL
966 scale_a,
967 OpselB, // OPSEL
968 scale_b);
969#else
970 ignore = reg_a;
971 ignore = scale_a;
972 ignore = reg_b;
973 ignore = scale_b;
974 ignore = reg_c;
975#endif
976 }
977
978 template <class FloatC>
979 __device__ static void Run(const bf8x32_t& reg_a,
980 const int32_t& scale_a,
981 const f8x32_t& reg_b,
982 const int32_t& scale_b,
983 FloatC& reg_c)
984 {
985#if defined(__gfx950__)
986 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
987 reg_c.template AsType<float4_t>()(Number<0>{}) =
988 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
989 reg_a,
990 reg_b,
991 reg_c.template AsType<float4_t>()[Number<0>{}],
992 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
993 0, // blgp
994 OpselA, // OPSEL
995 scale_a,
996 OpselB, // OPSEL
997 scale_b);
998#else
999 ignore = reg_a;
1000 ignore = scale_a;
1001 ignore = reg_b;
1002 ignore = scale_b;
1003 ignore = reg_c;
1004#endif
1005 }
1006
1007 template <class FloatC>
1008 __device__ static void Run(const f6x32_t& reg_a,
1009 const int32_t scale_a,
1010 const f6x32_t& reg_b,
1011 const int32_t scale_b,
1012 FloatC& reg_c)
1013 {
1014#if defined(__gfx950__)
1015 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1016 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1017
1018 using arg_type = int32x8_t;
1019
1020 reg_c.template AsType<float4_t>()(Number<0>{}) =
1021 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1022 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1023 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1024 reg_c.template AsType<float4_t>()[Number<0>{}],
1025 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1026 2, // blgp
1027 OpselA, // OPSEL
1028 scale_a,
1029 OpselB, // OPSEL
1030 scale_b);
1031#else
1032 ignore = reg_a;
1033 ignore = scale_a;
1034 ignore = reg_b;
1035 ignore = scale_b;
1036 ignore = reg_c;
1037#endif
1038 }
1039
1040 template <class FloatC>
1041 __device__ static void Run(const f6x16x2_t& reg_a,
1042 const int32_t scale_a,
1043 const f6x16x2_t& reg_b,
1044 const int32_t scale_b,
1045 FloatC& reg_c)
1046 {
1047#if defined(__gfx950__)
1048 using arg_type = int32x8_t;
1049 arg_type arg_a{
1050 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1051 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1052 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1053 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1054 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1055 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1056 0,
1057 0};
1058 arg_type arg_b{
1059 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1060 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1061 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1062 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1063 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1064 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1065 0,
1066 0};
1067
1068 reg_c.template AsType<float4_t>()(Number<0>{}) =
1069 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1070 arg_a,
1071 arg_b,
1072 reg_c.template AsType<float4_t>()[Number<0>{}],
1073 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1074 2, // blgp
1075 OpselA, // OPSEL
1076 scale_a,
1077 OpselB, // OPSEL
1078 scale_b);
1079#else
1080 ignore = reg_a;
1081 ignore = scale_a;
1082 ignore = reg_b;
1083 ignore = scale_b;
1084 ignore = reg_c;
1085#endif
1086 }
1087
1088 template <class FloatC>
1089 __device__ static void Run(const bf6x32_t& reg_a,
1090 const int32_t scale_a,
1091 const bf6x32_t& reg_b,
1092 const int32_t scale_b,
1093 FloatC& reg_c)
1094 {
1095#if defined(__gfx950__)
1096 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1097 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1098
1099 using arg_type = int32x8_t;
1100
1101 reg_c.template AsType<float4_t>()(Number<0>{}) =
1102 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1103 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1104 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1105 reg_c.template AsType<float4_t>()[Number<0>{}],
1106 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1107 3, // blgp
1108 OpselA, // OPSEL
1109 scale_a,
1110 OpselB, // OPSEL
1111 scale_b);
1112#else
1113 ignore = reg_a;
1114 ignore = scale_a;
1115 ignore = reg_b;
1116 ignore = scale_b;
1117 ignore = reg_c;
1118#endif
1119 }
1120
1121 template <class FloatC>
1122 __device__ static void Run(const bf6x16x2_t& reg_a,
1123 const int32_t scale_a,
1124 const bf6x16x2_t& reg_b,
1125 const int32_t scale_b,
1126 FloatC& reg_c)
1127 {
1128#if defined(__gfx950__)
1129 using arg_type = int32x8_t;
1130 arg_type arg_a{
1131 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1132 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1133 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1134 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1135 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1136 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1137 0,
1138 0};
1139 arg_type arg_b{
1140 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1141 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1142 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1143 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1144 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1145 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1146 0,
1147 0};
1148
1149 reg_c.template AsType<float4_t>()(Number<0>{}) =
1150 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1151 arg_a,
1152 arg_b,
1153 reg_c.template AsType<float4_t>()[Number<0>{}],
1154 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1155 3, // blgp
1156 OpselA, // OPSEL
1157 scale_a,
1158 OpselB, // OPSEL
1159 scale_b);
1160#else
1161 ignore = reg_a;
1162 ignore = scale_a;
1163 ignore = reg_b;
1164 ignore = scale_b;
1165 ignore = reg_c;
1166#endif
1167 }
1168
1169 template <class FloatC>
1170 __device__ static void Run(const f4x32_t& reg_a,
1171 const int32_t scale_a,
1172 const f4x32_t& reg_b,
1173 const int32_t scale_b,
1174 FloatC& reg_c)
1175 {
1176#if defined(__gfx950__)
1177 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1178 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1179 using arg_type = int32x8_t;
1180 reg_c.template AsType<float4_t>()(Number<0>{}) =
1181 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1182 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1183 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1184 reg_c.template AsType<float4_t>()[Number<0>{}],
1185 4, // cbsz
1186 4, // blgp
1187 OpselA, // OPSEL
1188 scale_a,
1189 OpselB, // OPSEL
1190 scale_b);
1191#else
1192 ignore = reg_a;
1193 ignore = scale_a;
1194 ignore = reg_b;
1195 ignore = scale_b;
1196 ignore = reg_c;
1197#endif
1198 }
1199};
1200
1201template <index_t MPerWave, index_t NPerWave>
1203
1210template <>
1212{
1213 template <class FloatC>
1214 __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1215 {
1216#if defined(__gfx950__)
1217 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1218 reg_c.template AsType<float4_t>()(Number<0>{}) =
1219 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1220 reg_a,
1221 reg_b,
1222 reg_c.template AsType<float4_t>()[Number<0>{}],
1223 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1224 0, // blgp
1225 0,
1226 0,
1227 0,
1228 0);
1229#else
1230 ignore = reg_a;
1231 ignore = reg_b;
1232 ignore = reg_c;
1233#endif
1234 }
1235
1236 template <class FloatC>
1237 __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1238 {
1239#if defined(__gfx950__)
1240 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1241 reg_c.template AsType<float4_t>()(Number<0>{}) =
1242 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1243 reg_a,
1244 reg_b,
1245 reg_c.template AsType<float4_t>()[Number<0>{}],
1246 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1247 1, // blgp
1248 0,
1249 0,
1250 0,
1251 0);
1252#else
1253 ignore = reg_a;
1254 ignore = reg_b;
1255 ignore = reg_c;
1256#endif
1257 }
1258
1259 template <class FloatC>
1260 __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1261 {
1262#if defined(__gfx950__)
1263 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1264 reg_c.template AsType<float4_t>()(Number<0>{}) =
1265 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1266 reg_a,
1267 reg_b,
1268 reg_c.template AsType<float4_t>()[Number<0>{}],
1269 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1270 0, // blgp
1271 0,
1272 0,
1273 0,
1274 0);
1275#else
1276 ignore = reg_a;
1277 ignore = reg_b;
1278 ignore = reg_c;
1279#endif
1280 }
1281
1282 template <class FloatC>
1283 __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1284 {
1285#if defined(__gfx950__)
1286 // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1287 reg_c.template AsType<float4_t>()(Number<0>{}) =
1288 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1289 reg_a,
1290 reg_b,
1291 reg_c.template AsType<float4_t>()[Number<0>{}],
1292 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1293 1, // blgp
1294 0,
1295 0,
1296 0,
1297 0);
1298#else
1299 ignore = reg_a;
1300 ignore = reg_b;
1301 ignore = reg_c;
1302#endif
1303 }
1304
1305 template <class FloatC>
1306 __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
1307 {
1308#if defined(__gfx950__)
1309 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1310 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1311
1312 using arg_type = int32x8_t;
1313
1314 reg_c.template AsType<float4_t>()(Number<0>{}) =
1315 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1316 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1317 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1318 reg_c.template AsType<float4_t>()[Number<0>{}],
1319 4, // cbsz
1320 4, // blgp
1321 0, // OPSEL
1322 0,
1323 0, // OPSEL
1324 0);
1325#else
1326 ignore = reg_a;
1327 ignore = reg_b;
1328 ignore = reg_c;
1329#endif
1330 }
1331
1332 template <class FloatC>
1333 __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
1334 {
1335#if defined(__gfx950__)
1336 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1337 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1338
1339 using arg_type = int32x8_t;
1340
1341 reg_c.template AsType<float4_t>()(Number<0>{}) =
1342 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1343 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1344 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1345 reg_c.template AsType<float4_t>()[Number<0>{}],
1346 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1347 2, // blgp
1348 0, // OPSEL
1349 0,
1350 0, // OPSEL
1351 0);
1352#else
1353 ignore = reg_a;
1354 ignore = reg_b;
1355 ignore = reg_c;
1356#endif
1357 }
1358
1359 template <class FloatC>
1360 __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
1361 {
1362#if defined(__gfx950__)
1363 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1364 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1365
1366 using arg_type = int32x8_t;
1367
1368 reg_c.template AsType<float4_t>()(Number<0>{}) =
1369 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1370 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1371 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1372 reg_c.template AsType<float4_t>()[Number<0>{}],
1373 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1374 3, // blgp
1375 0, // OPSEL
1376 0,
1377 0, // OPSEL
1378 0);
1379#else
1380 ignore = reg_a;
1381 ignore = reg_b;
1382 ignore = reg_c;
1383#endif
1384 }
1385};
1386
1387template <index_t MPerWave, index_t NPerWave>
1389
1390template <>
1392{
1393 template <class FloatC>
1394 __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1395 {
1396#if defined(__gfx94__)
1397 reg_c.template AsType<float16_t>()(Number<0>{}) =
1398 __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1399 bit_cast<int64_t>(reg_a),
1400 bit_cast<int64_t>(reg_b),
1401 reg_c.template AsType<float16_t>()[Number<0>{}],
1402 0,
1403 0,
1404 0);
1405#else
1406 vector_type<f8_t, 8> reg_a_v(reg_a);
1407 vector_type<f8_t, 8> reg_b_v(reg_b);
1408
1409 static_for<0, 8, 1>{}([&](auto k) {
1410 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1411 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1412
1413 intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1414 });
1415#endif
1416 }
1417};
1418
1419template <index_t MPerWave, index_t NPerWave>
1421
1422template <>
1424{
1425 template <class FloatC>
1426 __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1427 {
1428#if defined(__gfx94__)
1429 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1430 bit_cast<int64_t>(reg_a),
1431 bit_cast<int64_t>(reg_b),
1432 reg_c.template AsType<float4_t>()[Number<0>{}],
1433 0,
1434 0,
1435 0);
1436#else
1437 vector_type<f8_t, 8> reg_a_v(reg_a);
1438 vector_type<f8_t, 8> reg_b_v(reg_b);
1439
1440 static_for<0, 8, 1>{}([&](auto k) {
1441 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1442 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1443
1444 intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1445 });
1446#endif
1447 }
1448};
1449
1450template <index_t MPerWave, index_t NPerWave>
1452
1453template <>
1455{
1456 template <class FloatC>
1457 __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1458 {
1459#if defined(__gfx94__)
1460 reg_c.template AsType<float16_t>()(Number<0>{}) =
1461 __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1462 bit_cast<int64_t>(reg_a),
1463 bit_cast<int64_t>(reg_b),
1464 reg_c.template AsType<float16_t>()[Number<0>{}],
1465 0,
1466 0,
1467 0);
1468#else
1469 vector_type<bf8_t, 8> reg_a_v(reg_a);
1470 vector_type<bf8_t, 8> reg_b_v(reg_b);
1471
1472 static_for<0, 8, 1>{}([&](auto k) {
1473 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1474 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1475
1476 intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1477 });
1478#endif
1479 }
1480};
1481
1482template <index_t MPerWave, index_t NPerWave>
1484
1485template <>
1487{
1488 template <class FloatC>
1489 __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1490 {
1491#if defined(__gfx94__)
1492 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1493 bit_cast<int64_t>(reg_a),
1494 bit_cast<int64_t>(reg_b),
1495 reg_c.template AsType<float4_t>()[Number<0>{}],
1496 0,
1497 0,
1498 0);
1499#else
1500 vector_type<bf8_t, 8> reg_a_v(reg_a);
1501 vector_type<bf8_t, 8> reg_b_v(reg_b);
1502
1503 static_for<0, 8, 1>{}([&](auto k) {
1504 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1505 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1506
1507 intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1508 });
1509#endif
1510 }
1511};
1512
1513template <index_t MPerWave, index_t NPerWave>
1515
1516template <>
1518{
1519 template <class FloatC>
1520 __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1521 {
1522#if defined(__gfx94__)
1523 reg_c.template AsType<float16_t>()(Number<0>{}) =
1524 __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1525 bit_cast<int64_t>(reg_a),
1526 bit_cast<int64_t>(reg_b),
1527 reg_c.template AsType<float16_t>()[Number<0>{}],
1528 0,
1529 0,
1530 0);
1531#else
1532 vector_type<f8_t, 8> reg_a_v(reg_a);
1533 vector_type<bf8_t, 8> reg_b_v(reg_b);
1534
1535 static_for<0, 8, 1>{}([&](auto k) {
1536 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1537 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1538
1539 intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1540 });
1541#endif
1542 }
1543};
1544
1545template <index_t MPerWave, index_t NPerWave>
1547
1548template <>
1550{
1551 template <class FloatC>
1552 __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1553 {
1554#if defined(__gfx94__)
1555 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1556 bit_cast<int64_t>(reg_a),
1557 bit_cast<int64_t>(reg_b),
1558 reg_c.template AsType<float4_t>()[Number<0>{}],
1559 0,
1560 0,
1561 0);
1562#else
1563 vector_type<f8_t, 8> reg_a_v(reg_a);
1564 vector_type<bf8_t, 8> reg_b_v(reg_b);
1565
1566 static_for<0, 8, 1>{}([&](auto k) {
1567 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1568 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1569
1570 intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1571 });
1572#endif
1573 }
1574};
1575
1576template <index_t MPerWave, index_t NPerWave>
1578
1579template <>
1581{
1582 template <class FloatC>
1583 __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1584 {
1585#if defined(__gfx94__)
1586 reg_c.template AsType<float16_t>()(Number<0>{}) =
1587 __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1588 bit_cast<int64_t>(reg_a),
1589 bit_cast<int64_t>(reg_b),
1590 reg_c.template AsType<float16_t>()[Number<0>{}],
1591 0,
1592 0,
1593 0);
1594#else
1595 vector_type<bf8_t, 8> reg_a_v(reg_a);
1596 vector_type<f8_t, 8> reg_b_v(reg_b);
1597
1598 static_for<0, 8, 1>{}([&](auto k) {
1599 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1600 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1601
1602 intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1603 });
1604#endif
1605 }
1606};
1607
1608template <index_t MPerWave, index_t NPerWave>
1610
1611template <>
1613{
1614 template <class FloatC>
1615 __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1616 {
1617#if defined(__gfx94__)
1618 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1619 bit_cast<int64_t>(reg_a),
1620 bit_cast<int64_t>(reg_b),
1621 reg_c.template AsType<float4_t>()[Number<0>{}],
1622 0,
1623 0,
1624 0);
1625#else
1626 vector_type<bf8_t, 8> reg_a_v(reg_a);
1627 vector_type<f8_t, 8> reg_b_v(reg_b);
1628
1629 static_for<0, 8, 1>{}([&](auto k) {
1630 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1631 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1632
1633 intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1634 });
1635#endif
1636 }
1637};
1638
1639/******************* tf32 *************************************/
1640template <index_t MPerWave, index_t NPerWave>
1642
1643template <>
1645{
1646 template <class FloatC>
1647 __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
1648 {
1649#if defined(__gfx94__)
1650 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
1651 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
1652#else
1653 ignore = reg_a;
1654 ignore = reg_b;
1655 ignore = reg_c;
1656#endif
1657 }
1658};
1659
1660template <index_t MPerWave, index_t NPerWave>
1662
1663template <>
1665{
1666 template <class FloatC>
1667 __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
1668 {
1669#if defined(__gfx94__)
1670 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
1671 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
1672#else
1673 ignore = reg_a;
1674 ignore = reg_b;
1675 ignore = reg_c;
1676#endif
1677 }
1678};
1679
1680} // namespace ck
Definition ck.hpp:268
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< int32_t, 4 >::type int32x4_t
Definition dtype_vector.hpp:2168
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
typename vector_type< int32_t, 8 >::type int32x8_t
Definition dtype_vector.hpp:2170
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition dtype_vector.hpp:2161
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition dtype_vector.hpp:2273
typename vector_type< int32_t, 6 >::type int32x6_t
Definition dtype_vector.hpp:2169
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition dtype_vector.hpp:2268
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition dtype_vector.hpp:2160
typename vector_type< bf6x16_pk_t, 2 >::type bf6x16x2_t
Definition dtype_vector.hpp:2272
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition dtype_vector.hpp:2262
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
typename vector_type< f6x16_pk_t, 2 >::type f6x16x2_t
Definition dtype_vector.hpp:2267
signed int int32_t
Definition stdint.h:123
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1360
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1333
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1283
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1237
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1260
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1306
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1214
Definition amd_xdlops.hpp:1202
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:309
Definition amd_xdlops.hpp:303
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:199
Definition amd_xdlops.hpp:193
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:76
Definition amd_xdlops.hpp:70
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:275
Definition amd_xdlops.hpp:269
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1489
Definition amd_xdlops.hpp:1483
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1615
Definition amd_xdlops.hpp:1609
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:165
Definition amd_xdlops.hpp:159
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1552
Definition amd_xdlops.hpp:1546
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1426
Definition amd_xdlops.hpp:1420
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:213
Definition amd_xdlops.hpp:207
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:62
Definition amd_xdlops.hpp:56
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:337
Definition amd_xdlops.hpp:331
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1647
Definition amd_xdlops.hpp:1641
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:255
Definition amd_xdlops.hpp:249
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1457
Definition amd_xdlops.hpp:1451
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1583
Definition amd_xdlops.hpp:1577
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:145
Definition amd_xdlops.hpp:139
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1520
Definition amd_xdlops.hpp:1514
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1394
Definition amd_xdlops.hpp:1388
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:34
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:21
Definition amd_xdlops.hpp:15
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:48
Definition amd_xdlops.hpp:42
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:323
Definition amd_xdlops.hpp:317
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:131
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:118
Definition amd_xdlops.hpp:112
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1667
Definition amd_xdlops.hpp:1661
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:537
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:559
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:637
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:609
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:493
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:581
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:515
Definition amd_xdlops.hpp:481
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:295
Definition amd_xdlops.hpp:289
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:185
Definition amd_xdlops.hpp:179
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:90
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:101
Definition amd_xdlops.hpp:84
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:227
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:238
Definition amd_xdlops.hpp:221
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:467
Definition amd_xdlops.hpp:461
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:370
Definition amd_xdlops.hpp:364
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:448
Definition amd_xdlops.hpp:442
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:409
Definition amd_xdlops.hpp:403
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:429
Definition amd_xdlops.hpp:423
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:389
Definition amd_xdlops.hpp:383
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:351
Definition amd_xdlops.hpp:345
static __device__ void Run(const f6x16x2_t &reg_a, const int32_t scale_a, const f6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1041
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1170
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1008
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:892
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:979
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:950
static __device__ void Run(const bf6x16x2_t &reg_a, const int32_t scale_a, const bf6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1122
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1089
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:921
Definition amd_xdlops.hpp:886
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:672
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:783
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:709
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:746
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:851
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:817
Definition amd_xdlops.hpp:666
Definition functional2.hpp:33
Definition dtype_vector.hpp:10