//go:build amd64
// +build amd64

#include "textflag.h"

// func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
// Computes 8 independent dot products:
//   out[t] = sum_{i=0..K-1} a[i] * b[t*K+i], for t = 0..7
// Vectorizes over K with AVX2/FMA, reusing each vector of a across all
// 8 outputs; a scalar tail loop handles the final K % 8 elements.
// Assumes K >= 0 and that the caller has verified AVX2 and FMA support.
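//
// A matching declaration would live in a sibling .go file; the package
// name below is an assumption for illustration, not part of this file:
//
//	package kernels
//
//	//go:noescape
//	func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)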
TEXT ·gemvF32Tile8AVX2(SB), NOSPLIT, $96-32
	// Conservatively spill the general-purpose registers used below.
	// ABI0 treats these as caller-saved scratch, so this is defensive
	// rather than strictly required.
	MOVQ AX, 0(SP)
	MOVQ BX, 8(SP)
	MOVQ CX, 16(SP)
	MOVQ DX, 24(SP)
	MOVQ DI, 32(SP)
	MOVQ SI, 40(SP)
	MOVQ R8, 48(SP)
	MOVQ R9, 56(SP)
	MOVQ R10, 64(SP)
	MOVQ R11, 72(SP)
	MOVQ R12, 80(SP)
	MOVQ R13, 88(SP)
	MOVQ a+0(FP), DI
	MOVQ b+8(FP), SI
	MOVQ out+16(FP), DX
	MOVQ K+24(FP), CX
	// strideBytes = K * 4 (row stride of b in bytes).
	MOVQ CX, BX
	SHLQ $2, BX
	// b1..b7 = b + t*strideBytes, for t = 1..7.
	MOVQ SI, R8
	ADDQ BX, R8
	MOVQ R8, R9
	ADDQ BX, R9
	MOVQ R9, R10
	ADDQ BX, R10
	MOVQ R10, R11
	ADDQ BX, R11
	MOVQ R11, R12
	ADDQ BX, R12
	MOVQ R12, R13
	ADDQ BX, R13
	MOVQ R13, AX
	ADDQ BX, AX
	// Zero accumulators Y0..Y7. This also zeroes X0..X7, which the
	// tail path below relies on when K < 8.
	VXORPS Y0, Y0, Y0
	VXORPS Y1, Y1, Y1
	VXORPS Y2, Y2, Y2
	VXORPS Y3, Y3, Y3
	VXORPS Y4, Y4, Y4
	VXORPS Y5, Y5, Y5
	VXORPS Y6, Y6, Y6
	VXORPS Y7, Y7, Y7
	// kMain = K &^ 7 (largest multiple of 8 not exceeding K).
	ANDQ $-8, CX
	JLE tail // fewer than 8 elements: skip the vector loop

loop:
	// Load 8 floats of a once and reuse them against all 8 rows of b.
	VMOVUPS (DI), Y8
	VMOVUPS (SI), Y9
	VFMADD231PS Y8, Y9, Y0 // Y0 += Y9 * Y8
	VMOVUPS (R8), Y9
	VFMADD231PS Y8, Y9, Y1
	VMOVUPS (R9), Y9
	VFMADD231PS Y8, Y9, Y2
	VMOVUPS (R10), Y9
	VFMADD231PS Y8, Y9, Y3
	VMOVUPS (R11), Y9
	VFMADD231PS Y8, Y9, Y4
	VMOVUPS (R12), Y9
	VFMADD231PS Y8, Y9, Y5
	VMOVUPS (R13), Y9
	VFMADD231PS Y8, Y9, Y6
	VMOVUPS (AX), Y9
	VFMADD231PS Y8, Y9, Y7
	// Advance all nine pointers by 8 floats (32 bytes).
	ADDQ $32, DI
	ADDQ $32, SI
	ADDQ $32, R8
	ADDQ $32, R9
	ADDQ $32, R10
	ADDQ $32, R11
	ADDQ $32, R12
	ADDQ $32, R13
	ADDQ $32, AX
	SUBQ $8, CX
	JNZ loop

	// Horizontally reduce each YMM accumulator to a scalar in X0..X7:
	// fold the upper 128-bit lane onto the lower, then sum 4 -> 2 -> 1
	// lanes. Stores happen after the scalar tail below.
	VEXTRACTF128 $1, Y0, X8
	VADDPS X8, X0, X0
	VMOVHLPS X0, X0, X8
	VADDPS X8, X0, X0
	VPSHUFD $0xB1, X0, X8
	VADDPS X8, X0, X0
	VEXTRACTF128 $1, Y1, X8
	VADDPS X8, X1, X1
	VMOVHLPS X1, X1, X8
	VADDPS X8, X1, X1
	VPSHUFD $0xB1, X1, X8
	VADDPS X8, X1, X1
	VEXTRACTF128 $1, Y2, X8
	VADDPS X8, X2, X2
	VMOVHLPS X2, X2, X8
	VADDPS X8, X2, X2
	VPSHUFD $0xB1, X2, X8
	VADDPS X8, X2, X2
	VEXTRACTF128 $1, Y3, X8
	VADDPS X8, X3, X3
	VMOVHLPS X3, X3, X8
	VADDPS X8, X3, X3
	VPSHUFD $0xB1, X3, X8
	VADDPS X8, X3, X3
	VEXTRACTF128 $1, Y4, X8
	VADDPS X8, X4, X4
	VMOVHLPS X4, X4, X8
	VADDPS X8, X4, X4
	VPSHUFD $0xB1, X4, X8
	VADDPS X8, X4, X4
	VEXTRACTF128 $1, Y5, X8
	VADDPS X8, X5, X5
	VMOVHLPS X5, X5, X8
	VADDPS X8, X5, X5
	VPSHUFD $0xB1, X5, X8
	VADDPS X8, X5, X5
	VEXTRACTF128 $1, Y6, X8
	VADDPS X8, X6, X6
	VMOVHLPS X6, X6, X8
	VADDPS X8, X6, X6
	VPSHUFD $0xB1, X6, X8
	VADDPS X8, X6, X6
	VEXTRACTF128 $1, Y7, X8
	VADDPS X8, X7, X7
	VMOVHLPS X7, X7, X8
	VADDPS X8, X7, X7
	VPSHUFD $0xB1, X7, X8
	VADDPS X8, X7, X7

tail:
	// rem = K & 7. The a/b pointers already sit at the first tail
	// element, and X0..X7 hold the partial sums (zero when K < 8).
	MOVQ BX, CX
	SHRQ $2, CX
	ANDQ $7, CX
	JZ store
tailloop:
	VMOVSS (DI), X8
	VFMADD231SS (SI), X8, X0
	VFMADD231SS (R8), X8, X1
	VFMADD231SS (R9), X8, X2
	VFMADD231SS (R10), X8, X3
	VFMADD231SS (R11), X8, X4
	VFMADD231SS (R12), X8, X5
	VFMADD231SS (R13), X8, X6
	VFMADD231SS (AX), X8, X7
	ADDQ $4, DI
	ADDQ $4, SI
	ADDQ $4, R8
	ADDQ $4, R9
	ADDQ $4, R10
	ADDQ $4, R11
	ADDQ $4, R12
	ADDQ $4, R13
	ADDQ $4, AX
	DECQ CX
	JNZ tailloop
store:
	// VEX-encoded stores avoid SSE/AVX transition stalls while the
	// upper YMM state is still dirty.
	VMOVSS X0, 0(DX)
	VMOVSS X1, 4(DX)
	VMOVSS X2, 8(DX)
	VMOVSS X3, 12(DX)
	VMOVSS X4, 16(DX)
	VMOVSS X5, 20(DX)
	VMOVSS X6, 24(DX)
	VMOVSS X7, 28(DX)
	VZEROUPPER
	// Restore the spilled registers and return.
	MOVQ 0(SP), AX
	MOVQ 8(SP), BX
	MOVQ 16(SP), CX
	MOVQ 24(SP), DX
	MOVQ 32(SP), DI
	MOVQ 40(SP), SI
	MOVQ 48(SP), R8
	MOVQ 56(SP), R9
	MOVQ 64(SP), R10
	MOVQ 72(SP), R11
	MOVQ 80(SP), R12
	MOVQ 88(SP), R13
	RET
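
// For cross-checking in tests, a pure-Go reference of the same contract
// might look like this sketch (gemvF32Tile8Ref is a hypothetical name;
// the SIMD accumulation order above differs from this scalar order, so
// compare float32 results with a small tolerance):
//
//	func gemvF32Tile8Ref(a, b, out []float32, K int) {
//		for t := 0; t < 8; t++ {
//			var acc float32
//			for i := 0; i < K; i++ {
//				acc += a[i] * b[t*K+i]
//			}
//			out[t] = acc
//		}
//	}
//
// Callers would typically gate on CPU features at runtime, e.g. via
// golang.org/x/sys/cpu:
//
//	if cpu.X86.HasAVX2 && cpu.X86.HasFMA {
//		gemvF32Tile8AVX2(&a[0], &b[0], &out[0], k)
//	} else {
//		gemvF32Tile8Ref(a, b, out, k)
//	}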