core/stdarch/crates/core_arch/src/x86_64/
amx.rs

1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
6/// Load tile configuration from a 64-byte memory location specified by mem_addr.
7/// The tile configuration format is specified below, and includes the tile type pallette,
8/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero,
9/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
10/// Any invalid configurations will result in #GP fault.
11///
12/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
13#[inline]
14#[target_feature(enable = "amx-tile")]
15#[cfg_attr(test, assert_instr(ldtilecfg))]
16#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
17pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
18    ldtilecfg(mem_addr);
19}
20
21/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
22/// The tile configuration format is specified below, and includes the tile type pallette,
23/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
24///
25/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
26#[inline]
27#[target_feature(enable = "amx-tile")]
28#[cfg_attr(test, assert_instr(sttilecfg))]
29#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
30pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
31    sttilecfg(mem_addr);
32}
33
34/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
35///
36/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
37#[inline]
38#[rustc_legacy_const_generics(0)]
39#[target_feature(enable = "amx-tile")]
40#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
41#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
42pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
43    static_assert_uimm_bits!(DST, 3);
44    tileloadd64(DST as i8, base, stride);
45}
46
47/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
48///
49/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
50#[inline]
51#[target_feature(enable = "amx-tile")]
52#[cfg_attr(test, assert_instr(tilerelease))]
53#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
54pub unsafe fn _tile_release() {
55    tilerelease();
56}
57
58/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig.
59///
60/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
61#[inline]
62#[rustc_legacy_const_generics(0)]
63#[target_feature(enable = "amx-tile")]
64#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
65#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
66pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
67    static_assert_uimm_bits!(DST, 3);
68    tilestored64(DST as i8, base, stride);
69}
70
71/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration
72/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
73/// likely not be reused in the near future and the data caching can be optimized accordingly.
74///
75/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
76#[inline]
77#[rustc_legacy_const_generics(0)]
78#[target_feature(enable = "amx-tile")]
79#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
80#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
81pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
82    static_assert_uimm_bits!(DST, 3);
83    tileloaddt164(DST as i8, base, stride);
84}
85
86/// Zero the tile specified by tdest.
87///
88/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
89#[inline]
90#[rustc_legacy_const_generics(0)]
91#[target_feature(enable = "amx-tile")]
92#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
93#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
94pub unsafe fn _tile_zero<const DST: i32>() {
95    static_assert_uimm_bits!(DST, 3);
96    tilezero(DST as i8);
97}
98
99/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
100/// accumulating the intermediate single-precision (32-bit) floating-point elements
101/// with elements in dst, and store the 32-bit result back to tile dst.
102///
103/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
104#[inline]
105#[rustc_legacy_const_generics(0, 1, 2)]
106#[target_feature(enable = "amx-bf16")]
107#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
108#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
109pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
110    static_assert_uimm_bits!(DST, 3);
111    static_assert_uimm_bits!(A, 3);
112    static_assert_uimm_bits!(B, 3);
113    tdpbf16ps(DST as i8, A as i8, B as i8);
114}
115
116/// Compute dot-product of bytes in tiles with a source/destination accumulator.
117/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
118/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
119/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
120///
121/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
122#[inline]
123#[rustc_legacy_const_generics(0, 1, 2)]
124#[target_feature(enable = "amx-int8")]
125#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
126#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
127pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
128    static_assert_uimm_bits!(DST, 3);
129    static_assert_uimm_bits!(A, 3);
130    static_assert_uimm_bits!(B, 3);
131    tdpbssd(DST as i8, A as i8, B as i8);
132}
133
134/// Compute dot-product of bytes in tiles with a source/destination accumulator.
135/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
136/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
137/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
138///
139/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
140#[inline]
141#[rustc_legacy_const_generics(0, 1, 2)]
142#[target_feature(enable = "amx-int8")]
143#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
144#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
145pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
146    static_assert_uimm_bits!(DST, 3);
147    static_assert_uimm_bits!(A, 3);
148    static_assert_uimm_bits!(B, 3);
149    tdpbsud(DST as i8, A as i8, B as i8);
150}
151
152/// Compute dot-product of bytes in tiles with a source/destination accumulator.
153/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
154/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
155/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
156///
157/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
158#[inline]
159#[rustc_legacy_const_generics(0, 1, 2)]
160#[target_feature(enable = "amx-int8")]
161#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
162#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
163pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
164    static_assert_uimm_bits!(DST, 3);
165    static_assert_uimm_bits!(A, 3);
166    static_assert_uimm_bits!(B, 3);
167    tdpbusd(DST as i8, A as i8, B as i8);
168}
169
170/// Compute dot-product of bytes in tiles with a source/destination accumulator.
171/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
172/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
173/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
174///
175/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
176#[inline]
177#[rustc_legacy_const_generics(0, 1, 2)]
178#[target_feature(enable = "amx-int8")]
179#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
180#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
181pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
182    static_assert_uimm_bits!(DST, 3);
183    static_assert_uimm_bits!(A, 3);
184    static_assert_uimm_bits!(B, 3);
185    tdpbuud(DST as i8, A as i8, B as i8);
186}
187
188/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
189/// accumulating the intermediate single-precision (32-bit) floating-point elements
190///  with elements in dst, and store the 32-bit result back to tile dst.
191///
192/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
193#[inline]
194#[rustc_legacy_const_generics(0, 1, 2)]
195#[target_feature(enable = "amx-fp16")]
196#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
197#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
198pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
199    static_assert_uimm_bits!(DST, 3);
200    static_assert_uimm_bits!(A, 3);
201    static_assert_uimm_bits!(B, 3);
202    tdpfp16ps(DST as i8, A as i8, B as i8);
203}
204
205/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
206/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
207/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
208/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
209/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
210/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
211/// and then accumulated into the corresponding row and column of dst.
212///
213/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
214#[inline]
215#[rustc_legacy_const_generics(0, 1, 2)]
216#[target_feature(enable = "amx-complex")]
217#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
218#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
219pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
220    static_assert_uimm_bits!(DST, 3);
221    static_assert_uimm_bits!(A, 3);
222    static_assert_uimm_bits!(B, 3);
223    tcmmimfp16ps(DST as i8, A as i8, B as i8);
224}
225
226/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
227/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
228/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
229/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
230/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
231/// the a element is multiplied with the imaginary part of the corresponding b elements.
232/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
233///
234/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
235#[inline]
236#[rustc_legacy_const_generics(0, 1, 2)]
237#[target_feature(enable = "amx-complex")]
238#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
239#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
240pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
241    static_assert_uimm_bits!(DST, 3);
242    static_assert_uimm_bits!(A, 3);
243    static_assert_uimm_bits!(B, 3);
244    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
245}
246
247/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
248/// floating-point elements in tile b, accumulating the intermediate single-precision
249/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
250/// back to tile dst.
251#[inline]
252#[rustc_legacy_const_generics(0, 1, 2)]
253#[target_feature(enable = "amx-fp8")]
254#[cfg_attr(
255    all(test, any(target_os = "linux", target_env = "msvc")),
256    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257)]
258#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
259pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
260    static_assert_uimm_bits!(DST, 3);
261    static_assert_uimm_bits!(A, 3);
262    static_assert_uimm_bits!(B, 3);
263    tdpbf8ps(DST as i8, A as i8, B as i8);
264}
265
266/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
267/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
268/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
269/// back to tile dst.
270#[inline]
271#[rustc_legacy_const_generics(0, 1, 2)]
272#[target_feature(enable = "amx-fp8")]
273#[cfg_attr(
274    all(test, any(target_os = "linux", target_env = "msvc")),
275    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276)]
277#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
278pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
279    static_assert_uimm_bits!(DST, 3);
280    static_assert_uimm_bits!(A, 3);
281    static_assert_uimm_bits!(B, 3);
282    tdpbhf8ps(DST as i8, A as i8, B as i8);
283}
284
285/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
286/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
287/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
288/// back to tile dst.
289#[inline]
290#[rustc_legacy_const_generics(0, 1, 2)]
291#[target_feature(enable = "amx-fp8")]
292#[cfg_attr(
293    all(test, any(target_os = "linux", target_env = "msvc")),
294    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295)]
296#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
297pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
298    static_assert_uimm_bits!(DST, 3);
299    static_assert_uimm_bits!(A, 3);
300    static_assert_uimm_bits!(B, 3);
301    tdphbf8ps(DST as i8, A as i8, B as i8);
302}
303
304/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
305/// floating-point elements in tile b, accumulating the intermediate single-precision
306/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
307/// back to tile dst.
308#[inline]
309#[rustc_legacy_const_generics(0, 1, 2)]
310#[target_feature(enable = "amx-fp8")]
311#[cfg_attr(
312    all(test, any(target_os = "linux", target_env = "msvc")),
313    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314)]
315#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
316pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
317    static_assert_uimm_bits!(DST, 3);
318    static_assert_uimm_bits!(A, 3);
319    static_assert_uimm_bits!(B, 3);
320    tdphf8ps(DST as i8, A as i8, B as i8);
321}
322
323/// Load tile rows from memory specified by base address and stride into destination tile dst
324/// using the tile configuration previously configured via _tile_loadconfig.
325/// Additionally, this intrinsic indicates the source memory location is likely to become
326/// read-shared by multiple processors, i.e., read in the future by at least one other processor
327/// before it is written, assuming it is ever written in the future.
328#[inline]
329#[rustc_legacy_const_generics(0)]
330#[target_feature(enable = "amx-movrs")]
331#[cfg_attr(
332    all(test, any(target_os = "linux", target_env = "msvc")),
333    assert_instr(tileloaddrs, DST = 0)
334)]
335#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
336pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
337    static_assert_uimm_bits!(DST, 3);
338    tileloaddrs64(DST as i8, base, stride);
339}
340
341/// Load tile rows from memory specified by base address and stride into destination tile dst
342/// using the tile configuration previously configured via _tile_loadconfig.
343/// Provides a hint to the implementation that the data would be reused but does not need
344/// to be resident in the nearest cache levels.
345/// Additionally, this intrinsic indicates the source memory location is likely to become
346/// read-shared by multiple processors, i.e., read in the future by at least one other processor
347/// before it is written, assuming it is ever written in the future.
348#[inline]
349#[rustc_legacy_const_generics(0)]
350#[target_feature(enable = "amx-movrs")]
351#[cfg_attr(
352    all(test, any(target_os = "linux", target_env = "msvc")),
353    assert_instr(tileloaddrst1, DST = 0)
354)]
355#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
356pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
357    static_assert_uimm_bits!(DST, 3);
358    tileloaddrst164(DST as i8, base, stride);
359}
360
361/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
362/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
363///  results into a packed single precision tile.
364/// For each possible combination of (row of a, column of b), it performs
365///  - convert to TF32
366///  - multiply the corresponding elements of a and b
367///  - accumulate the results into the corresponding row and column of dst using round-to-nearest-even
368/// rounding mode.
369/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
370/// handled and *not* treated as zero.
371#[inline]
372#[rustc_legacy_const_generics(0, 1, 2)]
373#[target_feature(enable = "amx-tf32")]
374#[cfg_attr(
375    all(test, any(target_os = "linux", target_env = "msvc")),
376    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377)]
378#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
379pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
380    static_assert_uimm_bits!(DST, 3);
381    static_assert_uimm_bits!(A, 3);
382    static_assert_uimm_bits!(B, 3);
383    tmmultf32ps(DST as i8, A as i8, B as i8);
384}
385
386/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
387/// elements to packed single-precision (32-bit) floating-point elements.
388#[inline]
389#[rustc_legacy_const_generics(0)]
390#[target_feature(enable = "amx-avx512,avx10.2")]
391#[cfg_attr(
392    all(test, any(target_os = "linux", target_env = "msvc")),
393    assert_instr(tcvtrowd2ps, TILE = 0)
394)]
395#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
396pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
397    static_assert_uimm_bits!(TILE, 3);
398    tcvtrowd2ps(TILE as i8, row).as_m512()
399}
400
401/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
402/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
403/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
404#[inline]
405#[rustc_legacy_const_generics(0)]
406#[target_feature(enable = "amx-avx512,avx10.2")]
407#[cfg_attr(
408    all(test, any(target_os = "linux", target_env = "msvc")),
409    assert_instr(tcvtrowps2phh, TILE = 0)
410)]
411#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
412pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
413    static_assert_uimm_bits!(TILE, 3);
414    tcvtrowps2phh(TILE as i8, row).as_m512h()
415}
416
417/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
418/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
419/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
420#[inline]
421#[rustc_legacy_const_generics(0)]
422#[target_feature(enable = "amx-avx512,avx10.2")]
423#[cfg_attr(
424    all(test, any(target_os = "linux", target_env = "msvc")),
425    assert_instr(tcvtrowps2phl, TILE = 0)
426)]
427#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
428pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
429    static_assert_uimm_bits!(TILE, 3);
430    tcvtrowps2phl(TILE as i8, row).as_m512h()
431}
432
433/// Moves one row of tile data into a zmm vector register
434#[inline]
435#[rustc_legacy_const_generics(0)]
436#[target_feature(enable = "amx-avx512,avx10.2")]
437#[cfg_attr(
438    all(test, any(target_os = "linux", target_env = "msvc")),
439    assert_instr(tilemovrow, TILE = 0)
440)]
441#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
442pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
443    static_assert_uimm_bits!(TILE, 3);
444    tilemovrow(TILE as i8, row).as_m512i()
445}
446
447#[allow(improper_ctypes)]
448unsafe extern "C" {
449    #[link_name = "llvm.x86.ldtilecfg"]
450    fn ldtilecfg(mem_addr: *const u8);
451    #[link_name = "llvm.x86.sttilecfg"]
452    fn sttilecfg(mem_addr: *mut u8);
453    #[link_name = "llvm.x86.tileloadd64"]
454    fn tileloadd64(dst: i8, base: *const u8, stride: usize);
455    #[link_name = "llvm.x86.tileloaddt164"]
456    fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
457    #[link_name = "llvm.x86.tilerelease"]
458    fn tilerelease();
459    #[link_name = "llvm.x86.tilestored64"]
460    fn tilestored64(dst: i8, base: *mut u8, stride: usize);
461    #[link_name = "llvm.x86.tilezero"]
462    fn tilezero(dst: i8);
463    #[link_name = "llvm.x86.tdpbf16ps"]
464    fn tdpbf16ps(dst: i8, a: i8, b: i8);
465    #[link_name = "llvm.x86.tdpbuud"]
466    fn tdpbuud(dst: i8, a: i8, b: i8);
467    #[link_name = "llvm.x86.tdpbusd"]
468    fn tdpbusd(dst: i8, a: i8, b: i8);
469    #[link_name = "llvm.x86.tdpbsud"]
470    fn tdpbsud(dst: i8, a: i8, b: i8);
471    #[link_name = "llvm.x86.tdpbssd"]
472    fn tdpbssd(dst: i8, a: i8, b: i8);
473    #[link_name = "llvm.x86.tdpfp16ps"]
474    fn tdpfp16ps(dst: i8, a: i8, b: i8);
475    #[link_name = "llvm.x86.tcmmimfp16ps"]
476    fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
477    #[link_name = "llvm.x86.tcmmrlfp16ps"]
478    fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
479    #[link_name = "llvm.x86.tdpbf8ps"]
480    fn tdpbf8ps(dst: i8, a: i8, b: i8);
481    #[link_name = "llvm.x86.tdpbhf8ps"]
482    fn tdpbhf8ps(dst: i8, a: i8, b: i8);
483    #[link_name = "llvm.x86.tdphbf8ps"]
484    fn tdphbf8ps(dst: i8, a: i8, b: i8);
485    #[link_name = "llvm.x86.tdphf8ps"]
486    fn tdphf8ps(dst: i8, a: i8, b: i8);
487    #[link_name = "llvm.x86.tileloaddrs64"]
488    fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
489    #[link_name = "llvm.x86.tileloaddrst164"]
490    fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
491    #[link_name = "llvm.x86.tmmultf32ps"]
492    fn tmmultf32ps(dst: i8, a: i8, b: i8);
493    #[link_name = "llvm.x86.tcvtrowd2ps"]
494    fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
495    #[link_name = "llvm.x86.tcvtrowps2phh"]
496    fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
497    #[link_name = "llvm.x86.tcvtrowps2phl"]
498    fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
499    #[link_name = "llvm.x86.tilemovrow"]
500    fn tilemovrow(tile: i8, row: u32) -> i32x16;
501}
502
503#[cfg(test)]
504mod tests {
505    use crate::core_arch::x86::_mm_cvtness_sbh;
506    use crate::core_arch::x86_64::*;
507    use core::{array, mem::transmute};
508    use stdarch_test::simd_test;
509    #[cfg(target_os = "linux")]
510    use syscalls::{Sysno, syscall};
511
512    #[allow(non_camel_case_types)]
513    #[repr(packed)]
514    #[derive(Copy, Clone, Default, Debug, PartialEq)]
515    struct __tilecfg {
516        /// 0 `or` 1
517        palette: u8,
518        start_row: u8,
519        /// reserved, must be zero
520        reserved_a0: [u8; 14],
521        /// number of bytes of one row in each tile
522        colsb: [u16; 8],
523        /// reserved, must be zero
524        reserved_b0: [u16; 8],
525        /// number of rows in each tile
526        rows: [u8; 8],
527        /// reserved, must be zero
528        reserved_c0: [u8; 8],
529    }
530
531    impl __tilecfg {
532        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
533            Self {
534                palette,
535                start_row,
536                reserved_a0: [0u8; 14],
537                colsb,
538                reserved_b0: [0u16; 8],
539                rows,
540                reserved_c0: [0u8; 8],
541            }
542        }
543
544        const fn as_ptr(&self) -> *const u8 {
545            self as *const Self as *const u8
546        }
547
548        fn as_mut_ptr(&mut self) -> *mut u8 {
549            self as *mut Self as *mut u8
550        }
551    }
552
    /// Non-Linux variant of the test setup hook: it performs no setup and exists so the
    /// tests can call `_init_amx()` unconditionally.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}
556
557    #[cfg(target_os = "linux")]
558    #[target_feature(enable = "amx-tile")]
559    #[inline]
560    unsafe fn _init_amx() {
561        let mut ret: usize;
562        let mut xfeatures: usize = 0;
563        ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
564            .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
565        if ret != 0 {
566            panic!("Failed to get XFEATURES");
567        } else {
568            match 0b11 & (xfeatures >> 17) {
569                0 => panic!("AMX is not available"),
570                1 => {
571                    ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
572                        .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
573                    if ret != 0 {
574                        panic!("Failed to enable AMX");
575                    }
576                }
577                3 => {}
578                _ => unreachable!(),
579            }
580        }
581    }
582
583    #[simd_test(enable = "amx-tile")]
584    unsafe fn test_tile_loadconfig() {
585        let config = __tilecfg::default();
586        _tile_loadconfig(config.as_ptr());
587        _tile_release();
588    }
589
590    #[simd_test(enable = "amx-tile")]
591    unsafe fn test_tile_storeconfig() {
592        let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
593        _tile_loadconfig(config.as_ptr());
594        let mut _config = __tilecfg::default();
595        _tile_storeconfig(_config.as_mut_ptr());
596        _tile_release();
597        assert_eq!(config, _config);
598    }
599
600    #[simd_test(enable = "amx-tile")]
601    unsafe fn test_tile_zero() {
602        _init_amx();
603        let mut config = __tilecfg::default();
604        config.palette = 1;
605        config.colsb[0] = 64;
606        config.rows[0] = 16;
607        _tile_loadconfig(config.as_ptr());
608        _tile_zero::<0>();
609        let mut out = [[1_i8; 64]; 16];
610        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
611        _tile_release();
612        assert_eq!(out, [[0; 64]; 16]);
613    }
614
615    #[simd_test(enable = "amx-tile")]
616    unsafe fn test_tile_stored() {
617        _init_amx();
618        let mut config = __tilecfg::default();
619        config.palette = 1;
620        config.colsb[0] = 64;
621        config.rows[0] = 16;
622        _tile_loadconfig(config.as_ptr());
623        _tile_zero::<0>();
624        let mut out = [[1_i8; 64]; 16];
625        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
626        _tile_release();
627        assert_eq!(out, [[0; 64]; 16]);
628    }
629
630    #[simd_test(enable = "amx-tile")]
631    unsafe fn test_tile_loadd() {
632        _init_amx();
633        let mut config = __tilecfg::default();
634        config.palette = 1;
635        config.colsb[0] = 64;
636        config.rows[0] = 16;
637        _tile_loadconfig(config.as_ptr());
638        _tile_zero::<0>();
639        let mat = [1_i8; 1024];
640        _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
641        let mut out = [[0_i8; 64]; 16];
642        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
643        _tile_release();
644        assert_eq!(out, [[1; 64]; 16]);
645    }
646
647    #[simd_test(enable = "amx-tile")]
648    unsafe fn test_tile_stream_loadd() {
649        _init_amx();
650        let mut config = __tilecfg::default();
651        config.palette = 1;
652        config.colsb[0] = 64;
653        config.rows[0] = 16;
654        _tile_loadconfig(config.as_ptr());
655        _tile_zero::<0>();
656        let mat = [1_i8; 1024];
657        _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
658        let mut out = [[0_i8; 64]; 16];
659        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
660        _tile_release();
661        assert_eq!(out, [[1; 64]; 16]);
662    }
663
    // Runs `_tile_release` on its own, with no tile configuration loaded beforehand.
    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _tile_release();
    }
668
669    #[simd_test(enable = "amx-bf16,avx512f")]
670    unsafe fn test_tile_dpbf16ps() {
671        _init_amx();
672        let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
673        let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
674        let ones: [u8; 1024] = transmute([bf16_1; 512]);
675        let twos: [u8; 1024] = transmute([bf16_2; 512]);
676        let mut res = [[0f32; 16]; 16];
677        let mut config = __tilecfg::default();
678        config.palette = 1;
679        (0..=2).for_each(|i| {
680            config.colsb[i] = 64;
681            config.rows[i] = 16;
682        });
683        _tile_loadconfig(config.as_ptr());
684        _tile_zero::<0>();
685        _tile_loadd::<1>(&ones as *const u8, 64);
686        _tile_loadd::<2>(&twos as *const u8, 64);
687        _tile_dpbf16ps::<0, 1, 2>();
688        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
689        _tile_release();
690        assert_eq!(res, [[64f32; 16]; 16]);
691    }
692
693    #[simd_test(enable = "amx-int8")]
694    unsafe fn test_tile_dpbssd() {
695        _init_amx();
696        let ones = [-1_i8; 1024];
697        let twos = [-2_i8; 1024];
698        let mut res = [[0_i32; 16]; 16];
699        let mut config = __tilecfg::default();
700        config.palette = 1;
701        (0..=2).for_each(|i| {
702            config.colsb[i] = 64;
703            config.rows[i] = 16;
704        });
705        _tile_loadconfig(config.as_ptr());
706        _tile_zero::<0>();
707        _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
708        _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
709        _tile_dpbssd::<0, 1, 2>();
710        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
711        _tile_release();
712        assert_eq!(res, [[128_i32; 16]; 16]);
713    }
714
715    #[simd_test(enable = "amx-int8")]
716    unsafe fn test_tile_dpbsud() {
717        _init_amx();
718        let ones = [-1_i8; 1024];
719        let twos = [2_u8; 1024];
720        let mut res = [[0_i32; 16]; 16];
721        let mut config = __tilecfg::default();
722        config.palette = 1;
723        (0..=2).for_each(|i| {
724            config.colsb[i] = 64;
725            config.rows[i] = 16;
726        });
727        _tile_loadconfig(config.as_ptr());
728        _tile_zero::<0>();
729        _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
730        _tile_loadd::<2>(&twos as *const u8, 64);
731        _tile_dpbsud::<0, 1, 2>();
732        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
733        _tile_release();
734        assert_eq!(res, [[-128_i32; 16]; 16]);
735    }
736
737    #[simd_test(enable = "amx-int8")]
738    unsafe fn test_tile_dpbusd() {
739        _init_amx();
740        let ones = [1_u8; 1024];
741        let twos = [-2_i8; 1024];
742        let mut res = [[0_i32; 16]; 16];
743        let mut config = __tilecfg::default();
744        config.palette = 1;
745        (0..=2).for_each(|i| {
746            config.colsb[i] = 64;
747            config.rows[i] = 16;
748        });
749        _tile_loadconfig(config.as_ptr());
750        _tile_zero::<0>();
751        _tile_loadd::<1>(&ones as *const u8, 64);
752        _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
753        _tile_dpbusd::<0, 1, 2>();
754        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
755        _tile_release();
756        assert_eq!(res, [[-128_i32; 16]; 16]);
757    }
758
759    #[simd_test(enable = "amx-int8")]
760    unsafe fn test_tile_dpbuud() {
761        _init_amx();
762        let ones = [1_u8; 1024];
763        let twos = [2_u8; 1024];
764        let mut res = [[0_i32; 16]; 16];
765        let mut config = __tilecfg::default();
766        config.palette = 1;
767        (0..=2).for_each(|i| {
768            config.colsb[i] = 64;
769            config.rows[i] = 16;
770        });
771        _tile_loadconfig(config.as_ptr());
772        _tile_zero::<0>();
773        _tile_loadd::<1>(&ones as *const u8, 64);
774        _tile_loadd::<2>(&twos as *const u8, 64);
775        _tile_dpbuud::<0, 1, 2>();
776        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
777        _tile_release();
778        assert_eq!(res, [[128_i32; 16]; 16]);
779    }
780
781    #[simd_test(enable = "amx-fp16")]
782    unsafe fn test_tile_dpfp16ps() {
783        _init_amx();
784        let ones = [1f16; 512];
785        let twos = [2f16; 512];
786        let mut res = [[0f32; 16]; 16];
787        let mut config = __tilecfg::default();
788        config.palette = 1;
789        (0..=2).for_each(|i| {
790            config.colsb[i] = 64;
791            config.rows[i] = 16;
792        });
793        _tile_loadconfig(config.as_ptr());
794        _tile_zero::<0>();
795        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
796        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
797        _tile_dpfp16ps::<0, 1, 2>();
798        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
799        _tile_release();
800        assert_eq!(res, [[64f32; 16]; 16]);
801    }
802
803    #[simd_test(enable = "amx-complex")]
804    unsafe fn test_tile_cmmimfp16ps() {
805        _init_amx();
806        let ones = [1f16; 512];
807        let twos = [2f16; 512];
808        let mut res = [[0f32; 16]; 16];
809        let mut config = __tilecfg::default();
810        config.palette = 1;
811        (0..=2).for_each(|i| {
812            config.colsb[i] = 64;
813            config.rows[i] = 16;
814        });
815        _tile_loadconfig(config.as_ptr());
816        _tile_zero::<0>();
817        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
818        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
819        _tile_cmmimfp16ps::<0, 1, 2>();
820        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
821        _tile_release();
822        assert_eq!(res, [[64f32; 16]; 16]);
823    }
824
825    #[simd_test(enable = "amx-complex")]
826    unsafe fn test_tile_cmmrlfp16ps() {
827        _init_amx();
828        let ones = [1f16; 512];
829        let twos = [2f16; 512];
830        let mut res = [[0f32; 16]; 16];
831        let mut config = __tilecfg::default();
832        config.palette = 1;
833        (0..=2).for_each(|i| {
834            config.colsb[i] = 64;
835            config.rows[i] = 16;
836        });
837        _tile_loadconfig(config.as_ptr());
838        _tile_zero::<0>();
839        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
840        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
841        _tile_cmmrlfp16ps::<0, 1, 2>();
842        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
843        _tile_release();
844        assert_eq!(res, [[0f32; 16]; 16]);
845    }
846
847    const BF8_ONE: u8 = 0x3c;
848    const BF8_TWO: u8 = 0x40;
849    const HF8_ONE: u8 = 0x38;
850    const HF8_TWO: u8 = 0x40;
851
852    #[simd_test(enable = "amx-fp8")]
853    unsafe fn test_tile_dpbf8ps() {
854        _init_amx();
855        let ones = [BF8_ONE; 1024];
856        let twos = [BF8_TWO; 1024];
857        let mut res = [[0.0_f32; 16]; 16];
858        let mut config = __tilecfg::default();
859        config.palette = 1;
860        (0..=2).for_each(|i| {
861            config.colsb[i] = 64;
862            config.rows[i] = 16;
863        });
864        _tile_loadconfig(config.as_ptr());
865        _tile_zero::<0>();
866        _tile_loadd::<1>(&ones as *const u8, 64);
867        _tile_loadd::<2>(&twos as *const u8, 64);
868        _tile_dpbf8ps::<0, 1, 2>();
869        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
870        _tile_release();
871        assert_eq!(res, [[128.0_f32; 16]; 16]);
872    }
873
874    #[simd_test(enable = "amx-fp8")]
875    unsafe fn test_tile_dpbhf8ps() {
876        _init_amx();
877        let ones = [BF8_ONE; 1024];
878        let twos = [HF8_TWO; 1024];
879        let mut res = [[0.0_f32; 16]; 16];
880        let mut config = __tilecfg::default();
881        config.palette = 1;
882        (0..=2).for_each(|i| {
883            config.colsb[i] = 64;
884            config.rows[i] = 16;
885        });
886        _tile_loadconfig(config.as_ptr());
887        _tile_zero::<0>();
888        _tile_loadd::<1>(&ones as *const u8, 64);
889        _tile_loadd::<2>(&twos as *const u8, 64);
890        _tile_dpbhf8ps::<0, 1, 2>();
891        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
892        _tile_release();
893        assert_eq!(res, [[128.0_f32; 16]; 16]);
894    }
895
896    #[simd_test(enable = "amx-fp8")]
897    unsafe fn test_tile_dphbf8ps() {
898        _init_amx();
899        let ones = [HF8_ONE; 1024];
900        let twos = [BF8_TWO; 1024];
901        let mut res = [[0.0_f32; 16]; 16];
902        let mut config = __tilecfg::default();
903        config.palette = 1;
904        (0..=2).for_each(|i| {
905            config.colsb[i] = 64;
906            config.rows[i] = 16;
907        });
908        _tile_loadconfig(config.as_ptr());
909        _tile_zero::<0>();
910        _tile_loadd::<1>(&ones as *const u8, 64);
911        _tile_loadd::<2>(&twos as *const u8, 64);
912        _tile_dphbf8ps::<0, 1, 2>();
913        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
914        _tile_release();
915        assert_eq!(res, [[128.0_f32; 16]; 16]);
916    }
917
918    #[simd_test(enable = "amx-fp8")]
919    unsafe fn test_tile_dphf8ps() {
920        _init_amx();
921        let ones = [HF8_ONE; 1024];
922        let twos = [HF8_TWO; 1024];
923        let mut res = [[0.0_f32; 16]; 16];
924        let mut config = __tilecfg::default();
925        config.palette = 1;
926        (0..=2).for_each(|i| {
927            config.colsb[i] = 64;
928            config.rows[i] = 16;
929        });
930        _tile_loadconfig(config.as_ptr());
931        _tile_zero::<0>();
932        _tile_loadd::<1>(&ones as *const u8, 64);
933        _tile_loadd::<2>(&twos as *const u8, 64);
934        _tile_dphf8ps::<0, 1, 2>();
935        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
936        _tile_release();
937        assert_eq!(res, [[128.0_f32; 16]; 16]);
938    }
939
940    #[simd_test(enable = "amx-movrs")]
941    unsafe fn test_tile_loaddrs() {
942        _init_amx();
943        let mut config = __tilecfg::default();
944        config.palette = 1;
945        config.colsb[0] = 64;
946        config.rows[0] = 16;
947        _tile_loadconfig(config.as_ptr());
948        _tile_zero::<0>();
949        let mat = [1_i8; 1024];
950        _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
951        let mut out = [[0_i8; 64]; 16];
952        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
953        _tile_release();
954        assert_eq!(out, [[1; 64]; 16]);
955    }
956
957    #[simd_test(enable = "amx-movrs")]
958    unsafe fn test_tile_stream_loaddrs() {
959        _init_amx();
960        let mut config = __tilecfg::default();
961        config.palette = 1;
962        config.colsb[0] = 64;
963        config.rows[0] = 16;
964        _tile_loadconfig(config.as_ptr());
965        _tile_zero::<0>();
966        let mat = [1_i8; 1024];
967        _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
968        let mut out = [[0_i8; 64]; 16];
969        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
970        _tile_release();
971        assert_eq!(out, [[1; 64]; 16]);
972    }
973
974    #[simd_test(enable = "amx-avx512,avx10.2")]
975    unsafe fn test_tile_movrow() {
976        _init_amx();
977        let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
978
979        let mut config = __tilecfg::default();
980        config.palette = 1;
981        config.colsb[0] = 64;
982        config.rows[0] = 16;
983        _tile_loadconfig(config.as_ptr());
984        _tile_loadd::<0>(array.as_ptr().cast(), 64);
985        for i in 0..16 {
986            let row = _tile_movrow::<0>(i);
987            assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
988        }
989    }
990
991    #[simd_test(enable = "amx-avx512,avx10.2")]
992    unsafe fn test_tile_cvtrowd2ps() {
993        _init_amx();
994        let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
995
996        let mut config = __tilecfg::default();
997        config.palette = 1;
998        config.colsb[0] = 64;
999        config.rows[0] = 16;
1000        _tile_loadconfig(config.as_ptr());
1001        _tile_loadd::<0>(array.as_ptr().cast(), 64);
1002        for i in 0..16 {
1003            let row = _tile_cvtrowd2ps::<0>(i);
1004            assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1005        }
1006    }
1007
1008    #[simd_test(enable = "amx-avx512,avx10.2")]
1009    unsafe fn test_tile_cvtrowps2phh() {
1010        _init_amx();
1011        let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1012
1013        let mut config = __tilecfg::default();
1014        config.palette = 1;
1015        config.colsb[0] = 64;
1016        config.rows[0] = 16;
1017        _tile_loadconfig(config.as_ptr());
1018        _tile_loadd::<0>(array.as_ptr().cast(), 64);
1019        for i in 0..16 {
1020            let row = _tile_cvtrowps2phh::<0>(i);
1021            assert_eq!(
1022                *row.as_f16x32().as_array(),
1023                array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1024            );
1025        }
1026    }
1027
1028    #[simd_test(enable = "amx-avx512,avx10.2")]
1029    unsafe fn test_tile_cvtrowps2phl() {
1030        _init_amx();
1031        let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1032
1033        let mut config = __tilecfg::default();
1034        config.palette = 1;
1035        config.colsb[0] = 64;
1036        config.rows[0] = 16;
1037        _tile_loadconfig(config.as_ptr());
1038        _tile_loadd::<0>(array.as_ptr().cast(), 64);
1039        for i in 0..16 {
1040            let row = _tile_cvtrowps2phl::<0>(i);
1041            assert_eq!(
1042                *row.as_f16x32().as_array(),
1043                array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1044            );
1045        }
1046    }
1047
1048    #[simd_test(enable = "amx-tf32")]
1049    unsafe fn test_tile_mmultf32ps() {
1050        _init_amx();
1051        let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1052        let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1053        let mut res = [[0.0; 16]; 16];
1054
1055        let mut config = __tilecfg::default();
1056        config.palette = 1;
1057        (0..=2).for_each(|i| {
1058            config.colsb[i] = 64;
1059            config.rows[i] = 16;
1060        });
1061        _tile_loadconfig(config.as_ptr());
1062        _tile_zero::<0>();
1063        _tile_loadd::<1>(a.as_ptr().cast(), 64);
1064        _tile_loadd::<2>(b.as_ptr().cast(), 64);
1065        _tile_mmultf32ps::<0, 1, 2>();
1066        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1067        _tile_release();
1068
1069        let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1070        assert_eq!(res, expected);
1071    }
1072}