aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVishal Verma <vishal.l.verma@intel.com>2021-11-15 12:54:42 -0700
committerVishal Verma <vishal.l.verma@intel.com>2021-11-15 16:37:59 -0700
commit279f06b6a9f3b5d8425ee8edf52884284c87f594 (patch)
tree3702520de2a3b477c0c0e900b45ddb7b360b1dbd
parentfa55b7dcdc43c1aa1ba12bca9d2dd4318c2a0dbf (diff)
downloadlinux-amx-example.tar.gz
Documentation/x86: Add an example program for AMX usageamx-example
Add an example program that demonstrates use of the arch_prctl based API for dynamically requesting the use of AMX instructions, and then performing a simple matrix dot product. Cc: Dave Hansen <dave.hansen@linux.intel.com> Cc: Chang Seok Bae <chang.seok.bae@intel.com> Signed-off-by: Vishal Verma <vishal.l.verma@intel.com>
-rw-r--r--Documentation/x86/amx-example.rst376
1 files changed, 376 insertions, 0 deletions
diff --git a/Documentation/x86/amx-example.rst b/Documentation/x86/amx-example.rst
new file mode 100644
index 00000000000000..4b7477e0a50724
--- /dev/null
+++ b/Documentation/x86/amx-example.rst
@@ -0,0 +1,376 @@
+Using arch_prctl to request AMX capabilities
+============================================
+
+Intel AMX (Advanced Matrix Extensions) is a dynamically enabled feature
+that requires a process to request and obtain prior permission from the
+kernel before it can be used.
+
+The following is an example program that obtains the necessary permissions,
+sets up the sigaltstack, and then performs a simple matrix dot product.
+
+Toolchain notes
+---------------
+
+This requires at least gcc-11.2, glibc-2.34 and binutils-2.37.
+
+Example Program
+---------------
+
+.. code-block:: C
+
+
+ /*
+ * AMX usage example
+ *
+ * This performs the following high level steps:
+ *
+ * 1. Detect AMX tile architecture support - CPUID.0x7.0.EDX.AMX_TILE[bit 24] = 1
+ * 2. Setup sigaltstack with proper size (see setup_sigaltstack())
+ * 3. Request permission to use AMX tile data (see request_perm_xtile_data())
+ * 4. Load data and compute dot product (see load_rand_tiledata() and mult_abc())
+ */
+
+ #define _GNU_SOURCE
+ #include <err.h>
+ #include <errno.h>
+ #include <stdio.h>
+ #include <string.h>
+ #include <stdbool.h>
+ #include <unistd.h>
+ #include <x86intrin.h>
+ #include <immintrin.h>
+
+ #include <sys/auxv.h>
+ #include <sys/mman.h>
+ #include <sys/syscall.h>
+ #include <sys/signal.h>
+
+ #define fatal_error(msg, ...) err(1, "[FAIL]\t" msg, ##__VA_ARGS__)
+
+ #ifndef AT_MINSIGSTKSZ
+ # define AT_MINSIGSTKSZ 51
+ #endif
+
+ #define XFEATURE_XTILECFG 17
+ #define XFEATURE_XTILEDATA 18
+ #define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
+ #define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
+ #define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
+
+ #define ARCH_GET_XCOMP_PERM 0x1022
+ #define ARCH_REQ_XCOMP_PERM 0x1023
+
+ #define TILE_M 8
+ #define TILE_K 8
+ #define TILE_N 8
+ #define MAX_ELEMENTS ((TILE_M * TILE_K) + (TILE_K * TILE_N) + (TILE_M * TILE_N))
+ #define BYTES_PER_ELEMENT 4
+
+ struct tile_buffer {
+ union {
+ struct {
+ uint32_t a[TILE_M * TILE_K];
+ uint32_t b[TILE_K * TILE_N];
+ uint32_t c[TILE_M * TILE_N];
+ };
+ uint32_t bytes[0];
+ };
+ };
+
+ static inline void cpuid(uint32_t *eax, uint32_t *ebx, uint32_t *ecx, uint32_t *edx)
+ {
+ asm volatile("cpuid;"
+ : "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx)
+ : "0" (*eax), "2" (*ecx));
+ }
+
+ #define CPUID_LEAF_XFEATURE_ENUM 0x07
+ #define CPUID_LEAF_TILE_INFO 0x1d
+ #define CPUID_LEAF_TMUL_INFO 0x1e
+
+ #define CPUID_LEAF_XFEATURE_AMX_TILE_SHIFT 24
+ #define CPUID_LEAF_XFEATURE_AMX_INT8_SHIFT 25
+
+ #define CPUID_SUBLEAF_TILE_ECX_PALETTE_1 1
+
+ #define CPUID_TILE_BYTES_MASK 0xffff
+ #define CPUID_TILE_BYTES_PER_TILE_SHIFT 16
+ #define CPUID_TILES_MAX_SHIFT 16
+ #define CPUID_TILE_BYTES_PER_ROW_MASK 0xffff
+ #define CPUID_TILE_ROWS_MAX_MASK 0xffff
+
+ #define CPUID_TMUL_MAXK_MASK 0xff
+ #define CPUID_TMUL_MAXN_MASK 0xffff
+ #define CPUID_TMUL_MAXN_SHIFT 0x8
+
+ #define TMM0 0
+ #define TMM1 1
+ #define TMM2 2
+ #define TMM3 3
+ #define TMM4 4
+ #define TMM5 5
+ #define TMM6 6
+ #define TMM7 7
+
+ static uint32_t max_palette;
+ static uint32_t total_tile_bytes, bytes_per_tile, max_tiles;
+ static uint32_t bytes_per_row, max_rows;
+ static uint32_t tmul_maxk, tmul_maxn;
+
+ static void amx_check_cpuid(void)
+ {
+ uint32_t eax, ebx, ecx, edx;
+
+ eax = CPUID_LEAF_XFEATURE_ENUM;
+ ecx = 0;
+ cpuid(&eax, &ebx, &ecx, &edx);
+ if (!((edx >> CPUID_LEAF_XFEATURE_AMX_TILE_SHIFT) & 0x1))
+ fatal_error("CPUID: AMX Tile architecture not supported");
+ if (!((edx >> CPUID_LEAF_XFEATURE_AMX_INT8_SHIFT) & 0x1))
+ fatal_error("CPUID: AMX-INT8 operations not supported");
+
+ eax = CPUID_LEAF_TILE_INFO;
+ ecx = 0;
+ cpuid(&eax, &ebx, &ecx, &edx);
+
+ max_palette = eax;
+ printf("CPUID Tile Info leaf:\n");
+ printf(" max_palette: %u\n", max_palette);
+
+ if (!max_palette)
+ fatal_error("AMX support missing (max_palette = 0)");
+
+ eax = CPUID_LEAF_TILE_INFO;
+ ecx = CPUID_SUBLEAF_TILE_ECX_PALETTE_1;
+ cpuid(&eax, &ebx, &ecx, &edx);
+
+ total_tile_bytes = eax & CPUID_TILE_BYTES_MASK;
+ bytes_per_tile = eax >> CPUID_TILE_BYTES_PER_TILE_SHIFT;
+ bytes_per_row = ebx & CPUID_TILE_BYTES_PER_ROW_MASK;
+ max_tiles = ebx >> CPUID_TILES_MAX_SHIFT;
+ max_rows = ecx & CPUID_TILE_ROWS_MAX_MASK;
+ printf(" total_tile_bytes: %u\n", total_tile_bytes);
+ printf(" bytes_per_tile: %u\n", bytes_per_tile);
+ printf(" bytes_per_row: %u\n", bytes_per_row);
+ printf(" max_tiles: %u\n", max_tiles);
+ printf(" max_rows: %u\n", max_rows);
+
+ eax = CPUID_LEAF_TMUL_INFO;
+ ecx = 0;
+ cpuid(&eax, &ebx, &ecx, &edx);
+
+ tmul_maxk = ebx & CPUID_TMUL_MAXK_MASK;
+ tmul_maxn = (ebx >> CPUID_TMUL_MAXN_SHIFT) & CPUID_TMUL_MAXN_MASK;
+ printf("CPUID TMUL Info leaf:\n");
+ printf(" tmul_maxk: %u\n", tmul_maxk);
+ printf(" tmul_maxn: %u\n", tmul_maxn);
+ }
+
+ static struct tilecfg {
+ uint8_t palette; /* byte 0 */
+ uint8_t start_row; /* byte 1 */
+ char rsvd1[14]; /* bytes 2-15 */
+ uint16_t tile_colsb[8]; /* bytes 16-31 */
+ char rsvd2[16]; /* bytes 32-47 */
+ uint8_t tile_rows[8]; /* bytes 48-55 */
+ char rsvd3[8]; /* bytes 56-63 */
+ } __attribute__((packed)) tilecfg;
+
+ static void print_tilecfg(struct tilecfg *t)
+ {
+ int i;
+
+ printf("TILECFG:\n");
+ printf(" palette: %d\n", t->palette);
+ printf(" start_row: %d\n", t->start_row);
+ for(i = 0; i < 8; i++)
+ printf(" tmm%d: [ %d x %d ]\n", i, t->tile_rows[i], t->tile_colsb[i]);
+ }
+
+ static void load_tile_config(struct tilecfg *t)
+ {
+ t->palette = 1;
+ t->start_row = 0;
+
+ t->tile_rows[TMM0] = TILE_M; /* tmm0 -> A: src1 matrix, MxK */
+ t->tile_colsb[TMM0] = TILE_K * BYTES_PER_ELEMENT;
+
+ t->tile_rows[TMM1] = TILE_K; /* tmm1 -> B: src2 matrix, KxN */
+ t->tile_colsb[TMM1] = TILE_N * BYTES_PER_ELEMENT;
+
+ t->tile_rows[TMM2] = TILE_M; /* tmm2 -> C: dst matrix, MxN */
+ t->tile_colsb[TMM2] = TILE_N * BYTES_PER_ELEMENT;
+
+ _tile_loadconfig(t);
+ }
+
+ static void get_stored_tilecfg(struct tilecfg *t)
+ {
+ _tile_storeconfig(t);
+ }
+
+ static void set_rand_tiledata(struct tile_buffer *tbuf)
+ {
+ int data;
+ int i;
+
+ /*
+ * Ensure that 'data' is never 0. This ensures that
+ * the registers are never in their initial configuration
+ * and thus never tracked as being in the init state.
+ */
+
+ for (i = 0; i < (MAX_ELEMENTS - (TILE_M * TILE_N)); i++) {
+ data = (rand() % 0xff) | 1;
+ tbuf->bytes[i] = data;
+ }
+ }
+
+ static void load_rand_tiledata(struct tile_buffer *tbuf)
+ {
+ set_rand_tiledata(tbuf);
+
+ _tile_release();
+ printf("TILERELEASE Done\n");
+
+ load_tile_config(&tilecfg);
+ printf("LDTILECFG Done\n");
+
+ get_stored_tilecfg(&tilecfg);
+ print_tilecfg(&tilecfg);
+
+ _tile_loadd(TMM0, &tbuf->a[0], TILE_K * BYTES_PER_ELEMENT);
+ printf("TILELOADD tmm0 Done\n");
+ _tile_loadd(TMM1, &tbuf->b[0], TILE_N * BYTES_PER_ELEMENT);
+ printf("TILELOADD tmm1 Done\n");
+ _tile_loadd(TMM2, &tbuf->c[0], TILE_N * BYTES_PER_ELEMENT);
+ printf("TILELOADD tmm2 Done\n");
+ }
+
+ static void mult_abc(struct tile_buffer *tbuf)
+ {
+ _tile_dpbuud(TMM2, TMM0, TMM1);
+ printf("TDPBUUD (tmm2 += tmm0 . tmm1) Done\n");
+
+ _tile_stored(TMM2, &tbuf->c[0], TILE_N * BYTES_PER_ELEMENT);
+ printf("TILESTORED (tmm2-> 'C') Done\n");
+ }
+
+ static void request_perm_xtile_data()
+ {
+ unsigned long bitmask;
+ long rc;
+
+ rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+ if (rc)
+ fatal_error("XTILE_DATA request failed: %ld", rc);
+
+ rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
+ if (rc)
+ fatal_error("prctl(ARCH_GET_XCOMP_PERM) error: %ld", rc);
+
+ if (bitmask & XFEATURE_MASK_XTILE)
+ printf("ARCH_REQ_XCOMP_PERM XTILE_DATA successful.\n");
+ }
+
+ static void setup_sigaltstack()
+ {
+ unsigned long minsigstksz, new_size;
+ void *altstack;
+ stack_t ss;
+ int rc;
+
+ minsigstksz = getauxval(AT_MINSIGSTKSZ);
+ printf("AT_MINSIGSTKSZ = %lu\n", minsigstksz);
+ /*
+ * getauxval() itself can return 0 for failure or
+ * success. But, in this case, AT_MINSIGSTKSZ
+ * will always return a >=0 value if implemented.
+ * Just check for 0.
+ */
+ if (minsigstksz == 0)
+ fatal_error("no support for AT_MINSIGSTKSZ");
+
+ new_size = minsigstksz * 2;
+ altstack = mmap(NULL, new_size, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0);
+ if (altstack == MAP_FAILED)
+ fatal_error("mmap() for altstack");
+
+ memset(&ss, 0, sizeof(ss));
+ ss.ss_size = new_size;
+ ss.ss_sp = altstack;
+
+ rc = sigaltstack(&ss, NULL);
+ if (rc)
+ fatal_error("sigaltstack failed: %d", rc);
+
+ }
+
+ static void print_abc(struct tile_buffer *tbuf)
+ {
+ int i, j;
+
+ /* printf("Raw Buffer\n [");
+ for (i = 0; i < MAX_ELEMENTS; i++)
+ printf(" %u", tbuf->bytes[i]);
+ printf(" ]\n\n");
+ */
+
+ printf("Matrix A:\n");
+ for (i = 0; i < TILE_M; i++) {
+ printf(" [");
+ for (j = 0; j < TILE_K; j++)
+ printf(" %03u", tbuf->a[(i * TILE_K) + j]);
+ printf(" ]\n");
+ }
+ printf("\n");
+
+ printf("Matrix B:\n");
+ for (i = 0; i < TILE_K; i++) {
+ printf(" [");
+ for (j = 0; j < TILE_N; j++)
+ printf(" %03u", tbuf->b[(i * TILE_N) + j]);
+ printf(" ]\n");
+ }
+ printf("\n");
+
+ printf("Matrix C:\n");
+ for (i = 0; i < TILE_M; i++) {
+ printf(" [");
+ for (j = 0; j < TILE_N; j++)
+ printf(" %06u", tbuf->c[(i * TILE_N) + j]);
+ printf(" ]\n");
+ }
+ printf("\n");
+ }
+
+ int main(void)
+ {
+ struct tile_buffer *tile;
+
+ amx_check_cpuid();
+ tile = aligned_alloc(64, MAX_ELEMENTS * BYTES_PER_ELEMENT);
+ if (!tile)
+ fatal_error("failed to allocate tile");
+
+ setup_sigaltstack();
+
+ /* Load tile configuration and tile data for matrices */
+ request_perm_xtile_data();
+ load_rand_tiledata(tile);
+
+ printf("\nA, B, C matrices before dot product:\n");
+ print_abc(tile);
+
+ /* compute the dot product, store result in the memory for 'C' */
+ mult_abc(tile);
+
+ /* print multiplication result */
+ printf("\nA, B, C matrices after dot product (C = A . B):\n");
+ print_abc(tile);
+
+ free(tile);
+ printf("All done\n");
+ return 0;
+ }