//go:build amd64
// +build amd64

#include "textflag.h"

// REDUCE8 horizontally sums the eight float32 lanes of ysum and leaves the
// total in lane 0 of xsum, which must be the low 128-bit half of ysum.
// Clobbers X8 and the upper half of ysum.
#define REDUCE8(ysum, xsum) \
	VEXTRACTF128 $1, ysum, X8; \
	VADDPS       X8, xsum, xsum; \
	VMOVHLPS     xsum, xsum, X8; \
	VADDPS       X8, xsum, xsum; \
	VMOVSHDUP    xsum, X8; \
	VADDPS       X8, xsum, xsum

// func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
//
// Computes 8 independent dot products of length K:
//
//	out[t] = sum_{i=0..K-1} a[i] * b[t*K+i],  for t = 0..7
//
// b is read as 8 rows of K float32 each, stored back to back (row stride is
// K*4 bytes). The first K&^7 elements of every row are processed 8 at a time
// with AVX/FMA, reusing each 8-wide load of a across all 8 rows. The
// remaining K&7 elements are then folded in with scalar FMAs, so K need not
// be a multiple of 8. For K <= 0, zeros are stored to out[0..7].
//
// Uses AVX and FMA3 instructions; presumably the Go caller gates this on CPU
// feature detection — confirm at the call site.
//
// Under Go's ABI0 every general-purpose register is caller-saved, so no
// registers are preserved here and no local frame is needed.
TEXT ·gemvF32Tile8AVX2(SB), NOSPLIT, $0-32
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ out+16(FP), DX
	MOVQ K+24(FP), CX

	// K <= 0: every dot product is empty.
	TESTQ CX, CX
	JLE   zero

	// BX = row stride in bytes (K * 4).
	MOVQ CX, BX
	SHLQ $2, BX

	// Row pointers: SI = row 0, R8..R13 = rows 1..6, AX = row 7.
	LEAQ (SI)(BX*1), R8
	LEAQ (R8)(BX*1), R9
	LEAQ (R9)(BX*1), R10
	LEAQ (R10)(BX*1), R11
	LEAQ (R11)(BX*1), R12
	LEAQ (R12)(BX*1), R13
	LEAQ (R13)(BX*1), AX

	// BX = elements handled by the vector loop (K rounded down to a multiple of 8).
	MOVQ CX, BX
	ANDQ $-8, BX

	// CX = leftover elements handled by the scalar tail (K mod 8).
	ANDQ $7, CX

	// Zero the per-row accumulators Y0..Y7.
	VXORPS Y0, Y0, Y0
	VXORPS Y1, Y1, Y1
	VXORPS Y2, Y2, Y2
	VXORPS Y3, Y3, Y3
	VXORPS Y4, Y4, Y4
	VXORPS Y5, Y5, Y5
	VXORPS Y6, Y6, Y6
	VXORPS Y7, Y7, Y7

	// K < 8: nothing for the vector loop to do.
	TESTQ BX, BX
	JZ    reduce

loop:
	// Load 8 floats of a once and reuse them against all 8 rows.
	VMOVUPS (DI), Y8

	VMOVUPS     (SI), Y9
	VFMADD231PS Y8, Y9, Y0
	VMOVUPS     (R8), Y9
	VFMADD231PS Y8, Y9, Y1
	VMOVUPS     (R9), Y9
	VFMADD231PS Y8, Y9, Y2
	VMOVUPS     (R10), Y9
	VFMADD231PS Y8, Y9, Y3
	VMOVUPS     (R11), Y9
	VFMADD231PS Y8, Y9, Y4
	VMOVUPS     (R12), Y9
	VFMADD231PS Y8, Y9, Y5
	VMOVUPS     (R13), Y9
	VFMADD231PS Y8, Y9, Y6
	VMOVUPS     (AX), Y9
	VFMADD231PS Y8, Y9, Y7

	// Advance every pointer by 8 floats.
	ADDQ $32, DI
	ADDQ $32, SI
	ADDQ $32, R8
	ADDQ $32, R9
	ADDQ $32, R10
	ADDQ $32, R11
	ADDQ $32, R12
	ADDQ $32, R13
	ADDQ $32, AX

	SUBQ $8, BX
	JNZ  loop

reduce:
	// Collapse each 8-lane accumulator to a scalar in lane 0 of X0..X7.
	REDUCE8(Y0, X0)
	REDUCE8(Y1, X1)
	REDUCE8(Y2, X2)
	REDUCE8(Y3, X3)
	REDUCE8(Y4, X4)
	REDUCE8(Y5, X5)
	REDUCE8(Y6, X6)
	REDUCE8(Y7, X7)

	// K a multiple of 8: no scalar tail.
	TESTQ CX, CX
	JZ    store

tail:
	// Fold the remaining K mod 8 products in one element at a time.
	// The pointers were left just past the vectorized prefix by the loop.
	VMOVSS      (DI), X8
	VMOVSS      (SI), X9
	VFMADD231SS X8, X9, X0
	VMOVSS      (R8), X9
	VFMADD231SS X8, X9, X1
	VMOVSS      (R9), X9
	VFMADD231SS X8, X9, X2
	VMOVSS      (R10), X9
	VFMADD231SS X8, X9, X3
	VMOVSS      (R11), X9
	VFMADD231SS X8, X9, X4
	VMOVSS      (R12), X9
	VFMADD231SS X8, X9, X5
	VMOVSS      (R13), X9
	VFMADD231SS X8, X9, X6
	VMOVSS      (AX), X9
	VFMADD231SS X8, X9, X7

	ADDQ $4, DI
	ADDQ $4, SI
	ADDQ $4, R8
	ADDQ $4, R9
	ADDQ $4, R10
	ADDQ $4, R11
	ADDQ $4, R12
	ADDQ $4, R13
	ADDQ $4, AX

	DECQ CX
	JNZ  tail

store:
	// VEX-encoded stores avoid mixing legacy SSE with dirty AVX upper state.
	VMOVSS X0, 0(DX)
	VMOVSS X1, 4(DX)
	VMOVSS X2, 8(DX)
	VMOVSS X3, 12(DX)
	VMOVSS X4, 16(DX)
	VMOVSS X5, 20(DX)
	VMOVSS X6, 24(DX)
	VMOVSS X7, 28(DX)
	VZEROUPPER
	RET

zero:
	VXORPS X0, X0, X0
	VMOVSS X0, 0(DX)
	VMOVSS X0, 4(DX)
	VMOVSS X0, 8(DX)
	VMOVSS X0, 12(DX)
	VMOVSS X0, 16(DX)
	VMOVSS X0, 20(DX)
	VMOVSS X0, 24(DX)
	VMOVSS X0, 28(DX)
	VZEROUPPER
	RET