//go:build amd64
// +build amd64

#include "textflag.h"

// func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
// Computes 8 independent dot products:
//   out[t] = sum_{i=0..kMain-1} a[i] * b[t*K+i], for t = 0..7, kMain = K &^ 15
// Vectorizes over K with AVX-512/FMA and reuses each A vector across all 8
// outputs. Only the 16-aligned prefix kMain is accumulated; any K % 16 tail
// must be handled by the caller. If kMain <= 0, all outputs are set to zero.
TEXT ·gemvF32Tile8AVX512(SB), NOSPLIT, $96-32
	// Preserve general-purpose registers (Go ABI + ABI wrappers).
	MOVQ AX, 0(SP)
	MOVQ BX, 8(SP)
	MOVQ CX, 16(SP)
	MOVQ DX, 24(SP)
	MOVQ DI, 32(SP)
	MOVQ SI, 40(SP)
	MOVQ R8, 48(SP)
	MOVQ R9, 56(SP)
	MOVQ R10, 64(SP)
	MOVQ R11, 72(SP)
	MOVQ R12, 80(SP)
	MOVQ R13, 88(SP)

	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ out+16(FP), DX
	MOVQ K+24(FP), CX

	// strideBytes = K * 4 (row stride of b in bytes)
	MOVQ CX, BX
	SHLQ $2, BX

	// kMain = K &^ 15 (round down to a multiple of 16 floats)
	ANDQ $-16, CX
	JLE  zero

	// b1..b7 row pointers: R8..R13 and AX hold b + t*strideBytes for t = 1..7.
	MOVQ SI, R8
	ADDQ BX, R8
	MOVQ R8, R9
	ADDQ BX, R9
	MOVQ R9, R10
	ADDQ BX, R10
	MOVQ R10, R11
	ADDQ BX, R11
	MOVQ R11, R12
	ADDQ BX, R12
	MOVQ R12, R13
	ADDQ BX, R13
	MOVQ R13, AX
	ADDQ BX, AX

	// Zero accumulators Z0..Z7 (one per output).
	VXORPS Z0, Z0, Z0
	VXORPS Z1, Z1, Z1
	VXORPS Z2, Z2, Z2
	VXORPS Z3, Z3, Z3
	VXORPS Z4, Z4, Z4
	VXORPS Z5, Z5, Z5
	VXORPS Z6, Z6, Z6
	VXORPS Z7, Z7, Z7

loop:
	VMOVUPS (DI), Z8       // a[i : i+16], shared by all 8 FMAs below
	VMOVUPS (SI), Z9
	VFMADD231PS Z8, Z9, Z0 // Z0 += a * b0
	VMOVUPS (R8), Z9
	VFMADD231PS Z8, Z9, Z1 // Z1 += a * b1
	VMOVUPS (R9), Z9
	VFMADD231PS Z8, Z9, Z2 // Z2 += a * b2
	VMOVUPS (R10), Z9
	VFMADD231PS Z8, Z9, Z3 // Z3 += a * b3
	VMOVUPS (R11), Z9
	VFMADD231PS Z8, Z9, Z4 // Z4 += a * b4
	VMOVUPS (R12), Z9
	VFMADD231PS Z8, Z9, Z5 // Z5 += a * b5
	VMOVUPS (R13), Z9
	VFMADD231PS Z8, Z9, Z6 // Z6 += a * b6
	VMOVUPS (AX), Z9
	VFMADD231PS Z8, Z9, Z7 // Z7 += a * b7
	ADDQ $64, DI
	ADDQ $64, SI
	ADDQ $64, R8
	ADDQ $64, R9
	ADDQ $64, R10
	ADDQ $64, R11
	ADDQ $64, R12
	ADDQ $64, R13
	ADDQ $64, AX
	SUBQ $16, CX
	JNZ  loop
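	// Horizontal-sum pattern applied to each accumulator below:
	// VEXTRACTF32X8 $1 pulls the high 256 bits so VADDPS folds 16 lanes to 8,
	// VEXTRACTF128 $1 folds 8 lanes to 4, then VPSHUFD $0x4E (swap 64-bit
	// halves) and VPSHUFD $0xB1 (swap adjacent 32-bit lanes), each followed
	// by VADDPS, leave the full sum in lane 0 for MOVSS to store. Note that
	// VEXTRACTF32X8 and ZMM-width VXORPS require AVX512DQ on top of AVX512F.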
	// Reduce each accumulator to scalar and store.
	// Z0 -> out[0]
	VEXTRACTF32X8 $1, Z0, Y8
	VADDPS Y8, Y0, Y0
	VEXTRACTF128 $1, Y0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0x4E, X0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0xB1, X0, X8
	VADDPS X8, X0, X0
	MOVSS X0, 0(DX)

	// Z1 -> out[1]
	VEXTRACTF32X8 $1, Z1, Y8
	VADDPS Y8, Y1, Y1
	VEXTRACTF128 $1, Y1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0x4E, X1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0xB1, X1, X8
	VADDPS X8, X1, X1
	MOVSS X1, 4(DX)

	// Z2 -> out[2]
	VEXTRACTF32X8 $1, Z2, Y8
	VADDPS Y8, Y2, Y2
	VEXTRACTF128 $1, Y2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0x4E, X2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0xB1, X2, X8
	VADDPS X8, X2, X2
	MOVSS X2, 8(DX)

	// Z3 -> out[3]
	VEXTRACTF32X8 $1, Z3, Y8
	VADDPS Y8, Y3, Y3
	VEXTRACTF128 $1, Y3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0x4E, X3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0xB1, X3, X8
	VADDPS X8, X3, X3
	MOVSS X3, 12(DX)

	// Z4 -> out[4]
	VEXTRACTF32X8 $1, Z4, Y8
	VADDPS Y8, Y4, Y4
	VEXTRACTF128 $1, Y4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0x4E, X4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0xB1, X4, X8
	VADDPS X8, X4, X4
	MOVSS X4, 16(DX)

	// Z5 -> out[5]
	VEXTRACTF32X8 $1, Z5, Y8
	VADDPS Y8, Y5, Y5
	VEXTRACTF128 $1, Y5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0x4E, X5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0xB1, X5, X8
	VADDPS X8, X5, X5
	MOVSS X5, 20(DX)

	// Z6 -> out[6]
	VEXTRACTF32X8 $1, Z6, Y8
	VADDPS Y8, Y6, Y6
	VEXTRACTF128 $1, Y6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0x4E, X6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0xB1, X6, X8
	VADDPS X8, X6, X6
	MOVSS X6, 24(DX)

	// Z7 -> out[7]
	VEXTRACTF32X8 $1, Z7, Y8
	VADDPS Y8, Y7, Y7
	VEXTRACTF128 $1, Y7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0x4E, X7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0xB1, X7, X8
	VADDPS X8, X7, X7
	MOVSS X7, 28(DX)

	VZEROUPPER
	JMP epilogue

zero:
	VXORPS X0, X0, X0
	MOVSS X0, 0(DX)
	MOVSS X0, 4(DX)
	MOVSS X0, 8(DX)
	MOVSS X0, 12(DX)
	MOVSS X0, 16(DX)
	MOVSS X0, 20(DX)
	MOVSS X0, 24(DX)
	MOVSS X0, 28(DX)
	VZEROUPPER
	JMP epilogue

epilogue:
	MOVQ 0(SP), AX
	MOVQ 8(SP), BX
	MOVQ 16(SP), CX
	MOVQ 24(SP), DX
	MOVQ 32(SP), DI
	MOVQ 40(SP), SI
	MOVQ 48(SP), R8
	MOVQ 56(SP), R9
	MOVQ 64(SP), R10
	MOVQ 72(SP), R11
	MOVQ 80(SP), R12
	MOVQ 88(SP), R13
	RET
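
// Go-side usage (an illustrative sketch, not part of this file). The
// declaration below matches the TEXT symbol above; the gemv8 wrapper and its
// preconditions are hypothetical and show one way to drive the kernel. Since
// the kernel accumulates only over kMain = K &^ 15 elements, the wrapper
// finishes the K % 16 tail in Go. Callers should also gate on CPU support,
// e.g. via golang.org/x/sys/cpu (cpu.X86.HasAVX512F && cpu.X86.HasAVX512DQ).
//
//	//go:noescape
//	func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
//
//	// gemv8 computes out[t] = dot(a[:K], b[t*K:(t+1)*K]) for t = 0..7.
//	// Assumes len(a) >= K, len(b) >= 8*K, len(out) >= 8, and K >= 1.
//	func gemv8(a, b, out []float32, K int) {
//		kMain := K &^ 15
//		gemvF32Tile8AVX512(&a[0], &b[0], &out[0], K) // stores zeros when kMain == 0
//		for t := 0; t < 8; t++ {
//			for i := kMain; i < K; i++ {
//				out[t] += a[i] * b[t*K+i] // scalar tail
//			}
//		}
//	}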