- //go:build amd64
- // +build amd64
- #include "textflag.h"
- // ============================================================================
- // Q8_K Dequantization - AVX2
- // BlockQ8_K layout:
- // offset 0: D (float32)
- // offset 4: QS[256] (int8)
- // ============================================================================
- // func dequantQ8KAVX2(b *BlockQ8_K, out *float32)
- TEXT ·dequantQ8KAVX2(SB), NOSPLIT, $0-16
- MOVQ b+0(FP), DI
- MOVQ out+8(FP), SI
- // Broadcast scale d to Y0
- VBROADCASTSS (DI), Y0
- // QS pointer = b + 4
- LEAQ 4(DI), R8
- MOVQ SI, R9
- // Process 256 elements, 32 at a time (unrolled)
- MOVQ $0, CX
- loop_q8:
- CMPQ CX, $256
- JGE done_q8
- // Load 8 int8, sign-extend to 8 int32, convert to float, multiply by scale
- VPMOVSXBD (R8), Y1
- VCVTDQ2PS Y1, Y1
- VMULPS Y0, Y1, Y1
- VMOVUPS Y1, (R9)
- VPMOVSXBD 8(R8), Y2
- VCVTDQ2PS Y2, Y2
- VMULPS Y0, Y2, Y2
- VMOVUPS Y2, 32(R9)
- VPMOVSXBD 16(R8), Y3
- VCVTDQ2PS Y3, Y3
- VMULPS Y0, Y3, Y3
- VMOVUPS Y3, 64(R9)
- VPMOVSXBD 24(R8), Y4
- VCVTDQ2PS Y4, Y4
- VMULPS Y0, Y4, Y4
- VMOVUPS Y4, 96(R9)
- ADDQ $32, R8
- ADDQ $128, R9
- ADDQ $32, CX
- JMP loop_q8
- done_q8:
- VZEROUPPER
- RET
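- // A rough scalar Go equivalent of dequantQ8KAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dequantQ8KRef(d float32, qs []int8, out []float32) {
- //       for i := 0; i < 256; i++ {
- //           out[i] = d * float32(qs[i])
- //       }
- //   }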
- // =========================================================================
- // Q3_K Fused Dot Inner Loop - AVX2
- // Computes: sum_i x[i] * (dl * qv_i) for 16 elements, where
- // qv_i = (q[i] >> shift) & 3, minus 4 when the high-mask bit is clear ((hm[i] & m) == 0).
- // func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32
- TEXT ·dotQ3KInnerAVX2Fused(SB), NOSPLIT, $0-48
- MOVQ q+0(FP), DI
- MOVQ hm+8(FP), SI
- MOVQ x+16(FP), DX
- // Load dl (float32) -> Y0
- VBROADCASTSS dl+24(FP), Y0
- // Load shift -> X2
- MOVQ shift+32(FP), AX
- MOVD AX, X2
- // Load 16 bytes from q -> 16 words in Y3
- VPMOVZXBW (DI), Y3
- // Shift right by variable 'shift'
- VPSRLW X2, Y3, Y3
- // Mask with 3 (0x0003)
- MOVL $3, BX
- MOVD BX, X4
- VPBROADCASTW X4, Y4
- VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3
- // Handle HMask
- MOVBLZX m+28(FP), BX
- MOVD BX, X5
- VPBROADCASTB X5, X5
- VMOVDQU (SI), X6
- VPAND X5, X6, X6
- VPXOR X7, X7, X7
- VPCMPEQB X7, X6, X6
- VPMOVSXBW X6, Y6
- MOVL $-4, BX
- MOVD BX, X7
- VPBROADCASTW X7, Y7
- VPAND Y7, Y6, Y6
- VPADDW Y6, Y3, Y3
- // Split into low/high 8 words -> int32
- // Low half of Y3 is already in X3.
- VPMOVSXWD X3, Y8
- VEXTRACTI128 $1, Y3, X4
- VPMOVSXWD X4, Y9
- // Convert to float and scale by dl
- VCVTDQ2PS Y8, Y8
- VCVTDQ2PS Y9, Y9
- VMULPS Y0, Y8, Y8
- VMULPS Y0, Y9, Y9
- // Load x and accumulate dot
- VMOVUPS (DX), Y10
- VMOVUPS 32(DX), Y11
- VMULPS Y10, Y8, Y8
- VMULPS Y11, Y9, Y9
- VADDPS Y9, Y8, Y8
- // Horizontal sum Y8 -> scalar in X8
- VEXTRACTF128 $1, Y8, X1
- VADDPS X1, X8, X8
- VHADDPS X8, X8, X8
- VHADDPS X8, X8, X8
- VMOVSS X8, ret+40(FP)
- VZEROUPPER
- RET
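- // A rough scalar Go equivalent of dotQ3KInnerAVX2Fused above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dotQ3KInnerRef(q, hm []byte, x []float32, dl float32, m uint8, shift uint) float32 {
- //       var sum float32
- //       for i := 0; i < 16; i++ {
- //           v := int32((q[i] >> shift) & 3)
- //           if hm[i]&m == 0 {
- //               v -= 4
- //           }
- //           sum += x[i] * (dl * float32(v))
- //       }
- //       return sum
- //   }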
- // ==========================================================================
- // Q2_K Fused Dot - AVX2
- // Computes: sum_i x[i] * (dl*val_i - ml), for i in [0..15]
- // where val_i = (q[i] >> shift) & 3
- // func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32
- TEXT ·dotQ2KInnerAVX2Fused(SB), NOSPLIT, $0-40
- MOVQ q+0(FP), DI
- MOVQ x+8(FP), SI
- VBROADCASTSS dl+16(FP), Y0
- VBROADCASTSS ml+20(FP), Y1
- MOVQ shift+24(FP), CX
- // Mask for 2 bits
- MOVL $0x03030303, AX
- MOVD AX, X7
- VPBROADCASTD X7, Y7
- // Shift amount
- MOVD CX, X6
- // Accumulator
- VXORPS Y15, Y15, Y15
- // Low 8 bytes
- VPMOVZXBD (DI), Y2
- VPSRLD X6, Y2, Y2
- VPAND Y7, Y2, Y2
- VCVTDQ2PS Y2, Y2
- VMULPS Y0, Y2, Y2
- VSUBPS Y1, Y2, Y2
- VMOVUPS (SI), Y4
- VFMADD231PS Y4, Y2, Y15
- // High 8 bytes
- VPMOVZXBD 8(DI), Y3
- VPSRLD X6, Y3, Y3
- VPAND Y7, Y3, Y3
- VCVTDQ2PS Y3, Y3
- VMULPS Y0, Y3, Y3
- VSUBPS Y1, Y3, Y3
- VMOVUPS 32(SI), Y5
- VFMADD231PS Y5, Y3, Y15
- // Reduce Y15 -> scalar
- VEXTRACTF128 $1, Y15, X1
- VADDPS X1, X15, X0
- VHADDPS X0, X0, X0
- VHADDPS X0, X0, X0
- MOVSS X0, ret+32(FP)
- VZEROUPPER
- RET
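- // A rough scalar Go equivalent of dotQ2KInnerAVX2Fused above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dotQ2KInnerRef(q []byte, x []float32, dl, ml float32, shift uint) float32 {
- //       var sum float32
- //       for i := 0; i < 16; i++ {
- //           sum += x[i] * (dl*float32((q[i]>>shift)&3) - ml)
- //       }
- //       return sum
- //   }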
- // ==========================================================================
- // Q5_K Fused Dot - AVX2
- // Computes, over 32 qs bytes:
- //   sum_i x[i]    * (d1*float32((qs[i] & 0x0F) + f1_i) - m1)
- //       + x[i+32] * (d2*float32((qs[i] >> 4)   + f2_i) - m2)
- // where f1_i = 16 if bit u1 of qh[i] is set, else 0; f2_i likewise with u2.
- // func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
- TEXT ·dotQ5KInnerAVX2(SB), NOSPLIT, $0-64
- MOVQ qs+0(FP), DI
- MOVQ qh+8(FP), SI
- MOVQ x+16(FP), DX
- VBROADCASTSS d1+24(FP), Y0
- VBROADCASTSS m1+28(FP), Y1
- VBROADCASTSS d2+32(FP), Y2
- VBROADCASTSS m2+36(FP), Y3
- // low-nibble mask 0x0F
- MOVL $0x0F0F0F0F, AX
- MOVD AX, X6
- VPBROADCASTD X6, Y6
- // int32(16) constant
- MOVL $16, AX
- MOVD AX, X7
- VPBROADCASTD X7, Y7
- // mask1 = 1 << u1
- MOVQ u1+40(FP), CX
- MOVQ $1, AX
- SHLQ CL, AX
- MOVL AX, BX
- MOVD BX, X4
- VPBROADCASTD X4, Y4
- // mask2 = 1 << u2
- MOVQ u2+48(FP), CX
- MOVQ $1, AX
- SHLQ CL, AX
- MOVL AX, BX
- MOVD BX, X5
- VPBROADCASTD X5, Y5
- // Y8 = zero, Y9 = all-ones (dword masks for the flag select)
- VXORPS Y8, Y8, Y8
- VPCMPEQD Y9, Y9, Y9
- // Accumulator
- VXORPS Y15, Y15, Y15
- MOVQ $0, CX
- dot_q5k_loop:
- CMPQ CX, $32
- JGE dot_q5k_reduce
- // Load 8 qs bytes
- VPMOVZXBD (DI), Y10
- // low nibble -> int32
- VPAND Y6, Y10, Y11
- // high nibble -> int32
- VPSRLD $4, Y10, Y10
- // Load 8 qh bytes and copy
- VPMOVZXBD (SI), Y12
- VMOVAPS Y12, Y13
- // flag1 int32: (qh & mask1) ? 16 : 0
- VPAND Y4, Y12, Y12
- VPCMPEQD Y8, Y12, Y12
- VPXOR Y9, Y12, Y12
- VPAND Y7, Y12, Y12
-
- // low dequant: (float(lowNib + flag1) * d1) - m1
- VCVTDQ2PS Y11, Y11
- VCVTDQ2PS Y12, Y12
- VADDPS Y12, Y11, Y11
- VMULPS Y0, Y11, Y11
- VSUBPS Y1, Y11, Y11
- VMOVUPS (DX), Y14
- VFMADD231PS Y14, Y11, Y15
- // flag2 int32: (qh & mask2) ? 16 : 0
- VPAND Y5, Y13, Y13
- VPCMPEQD Y8, Y13, Y13
- VPXOR Y9, Y13, Y13
- VPAND Y7, Y13, Y13
- // high dequant: (float(highNib + flag2) * d2) - m2
- VCVTDQ2PS Y10, Y10
- VCVTDQ2PS Y13, Y13
- VADDPS Y13, Y10, Y10
- VMULPS Y2, Y10, Y10
- VSUBPS Y3, Y10, Y10
- VMOVUPS 128(DX), Y14
- VFMADD231PS Y14, Y10, Y15
- ADDQ $8, DI
- ADDQ $8, SI
- ADDQ $32, DX
- ADDQ $8, CX
- JMP dot_q5k_loop
- dot_q5k_reduce:
- VEXTRACTF128 $1, Y15, X1
- VADDPS X1, X15, X0
- VHADDPS X0, X0, X0
- VHADDPS X0, X0, X0
- MOVSS X0, ret+56(FP)
- VZEROUPPER
- RET
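- // A rough scalar Go equivalent of dotQ5KInnerAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dotQ5KInnerRef(qs, qh []byte, x []float32, d1, m1, d2, m2 float32, u1, u2 uint) float32 {
- //       var sum float32
- //       for i := 0; i < 32; i++ {
- //           lo, hi := float32(qs[i]&0x0F), float32(qs[i]>>4)
- //           if (qh[i]>>u1)&1 != 0 {
- //               lo += 16
- //           }
- //           if (qh[i]>>u2)&1 != 0 {
- //               hi += 16
- //           }
- //           sum += x[i]*(d1*lo-m1) + x[i+32]*(d2*hi-m2)
- //       }
- //       return sum
- //   }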
- // ==========================================================================
- // Q8_K Fused Dot - AVX2
- // Computes: sum_i x[i] * (d * float32(qs[i]))
- // func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
- TEXT ·dotQ8KAVX2(SB), NOSPLIT, $0-24
- MOVQ b+0(FP), DI
- MOVQ x+8(FP), SI
- // Broadcast scale d to Y0 (avoids a legacy-SSE MOVSS load in VEX code)
- VBROADCASTSS (DI), Y0
- // QS pointer = b + 4
- LEAQ 4(DI), R8
- // Accumulator Y1
- VXORPS Y1, Y1, Y1
- MOVQ $0, CX
- dot_q8_256_loop:
- CMPQ CX, $256
- JGE dot_q8_256_reduce
- // 8x int8 -> 8x int32 -> 8x float
- VPMOVSXBD (R8), Y2
- VCVTDQ2PS Y2, Y2
- VMULPS Y0, Y2, Y2
- // load 8 floats from x
- VMOVUPS (SI), Y3
- VFMADD231PS Y3, Y2, Y1
- ADDQ $8, R8
- ADDQ $32, SI
- ADDQ $8, CX
- JMP dot_q8_256_loop
- dot_q8_256_reduce:
- // horizontal add ymm1 -> scalar x1
- VEXTRACTF128 $1, Y1, X2
- VADDPS X2, X1, X1
- VHADDPS X1, X1, X1
- VHADDPS X1, X1, X1
- MOVSS X1, ret+16(FP)
- VZEROUPPER
- RET
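- // A rough scalar Go equivalent of dotQ8KAVX2 above, for reference only
- // (hypothetical helper; d and qs are the block fields, x is a slice):
- //
- //   func dotQ8KRef(d float32, qs []int8, x []float32) float32 {
- //       var sum float32
- //       for i := 0; i < 256; i++ {
- //           sum += x[i] * (d * float32(qs[i]))
- //       }
- //       return sum
- //   }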
- // ============================================================================
- // Q4_K Inner Loop - AVX2 (Vectorized nibble extraction)
- // Processes 32 packed bytes (64 4-bit quants) with pre-computed scales and mins.
- // ============================================================================
- // func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)
- TEXT ·dequantQ4KInnerAVX2(SB), NOSPLIT, $0-32
- MOVQ qs+0(FP), DI
- MOVQ out+8(FP), SI
-
- // Broadcast d1, m1, d2, m2
- VBROADCASTSS d1+16(FP), Y0 // d1
- VBROADCASTSS m1+20(FP), Y1 // m1
- VBROADCASTSS d2+24(FP), Y2 // d2
- VBROADCASTSS m2+28(FP), Y3 // m2
- // Mask for low nibble (0x0F repeated)
- MOVL $0x0F0F0F0F, AX
- MOVD AX, X7
- VPBROADCASTD X7, Y7
- // Process 32 bytes (64 quants), 8 bytes at a time
- MOVQ $0, CX
- loop_q4k:
- CMPQ CX, $32
- JGE done_q4k
- // Load 8 bytes from QS as unsigned
- VPMOVZXBD (DI), Y4 // 8 bytes -> 8 uint32
- // Extract low nibbles: v1 = val & 0xF
- VPAND Y7, Y4, Y5
- VCVTDQ2PS Y5, Y5
- VFMSUB132PS Y0, Y1, Y5 // out[i] = v1*d1 - m1
- VMOVUPS Y5, (SI)
- // Extract high nibbles: v2 = val >> 4
- VPSRLD $4, Y4, Y4
- VCVTDQ2PS Y4, Y4
- VFMSUB132PS Y2, Y3, Y4 // out[i+32] = v2*d2 - m2
- VMOVUPS Y4, 128(SI) // 32 * 4 bytes offset
- ADDQ $8, DI
- ADDQ $32, SI
- ADDQ $8, CX
- JMP loop_q4k
- done_q4k:
- VZEROUPPER
- RET
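- // A rough scalar Go equivalent of dequantQ4KInnerAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dequantQ4KInnerRef(qs []byte, out []float32, d1, m1, d2, m2 float32) {
- //       for i := 0; i < 32; i++ {
- //           out[i] = float32(qs[i]&0x0F)*d1 - m1
- //           out[i+32] = float32(qs[i]>>4)*d2 - m2
- //       }
- //   }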
- // func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32
- TEXT ·dotQ4KInnerAVX2(SB), NOSPLIT, $0-40
- MOVQ qs+0(FP), DI
- MOVQ x+8(FP), SI
- // Broadcast d1, m1, d2, m2
- VBROADCASTSS d1+16(FP), Y0
- VBROADCASTSS m1+20(FP), Y1
- VBROADCASTSS d2+24(FP), Y2
- VBROADCASTSS m2+28(FP), Y3
- // Mask for low nibble (0x0F repeated)
- MOVL $0x0F0F0F0F, AX
- MOVD AX, X7
- VPBROADCASTD X7, Y7
- // Accumulators:
- // Y12 = sum(x_low * v1)
- // Y13 = sum(x_low)
- // Y14 = sum(x_high * v2)
- // Y15 = sum(x_high)
- VXORPS Y12, Y12, Y12
- VXORPS Y13, Y13, Y13
- VXORPS Y14, Y14, Y14
- VXORPS Y15, Y15, Y15
- // Process 32 bytes as 4x8-byte chunks
- MOVQ $0, CX
- dot_q4k_loop:
- CMPQ CX, $32
- JGE dot_q4k_reduce
- // Load 8 bytes from QS as unsigned dwords
- VPMOVZXBD (DI), Y8
- // Low nibble values -> float
- VPAND Y7, Y8, Y9
- VCVTDQ2PS Y9, Y9
- // x low: 8 floats
- VMOVUPS (SI), Y10
- VADDPS Y10, Y13, Y13
- VFMADD231PS Y10, Y9, Y12
- // High nibble values -> float
- VPSRLD $4, Y8, Y8
- VCVTDQ2PS Y8, Y8
- // x high: offset by 32 floats (128 bytes)
- VMOVUPS 128(SI), Y11
- VADDPS Y11, Y15, Y15
- VFMADD231PS Y11, Y8, Y14
- ADDQ $8, DI
- ADDQ $32, SI
- ADDQ $8, CX
- JMP dot_q4k_loop
- dot_q4k_reduce:
- // result = d1*sum(x1*v1) - m1*sum(x1) + d2*sum(x2*v2) - m2*sum(x2)
- VMULPS Y0, Y12, Y12
- VMULPS Y1, Y13, Y13
- VSUBPS Y13, Y12, Y12
- VMULPS Y2, Y14, Y14
- VMULPS Y3, Y15, Y15
- VSUBPS Y15, Y14, Y14
- VADDPS Y14, Y12, Y12
- // Horizontal add ymm12 -> scalar in X0
- VEXTRACTF128 $1, Y12, X1
- VADDPS X1, X12, X0
- VHADDPS X0, X0, X0
- VHADDPS X0, X0, X0
- MOVSS X0, ret+32(FP)
- VZEROUPPER
- RET
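- // The kernel above hoists d1/m1/d2/m2 out of the loop by accumulating sum(x*v)
- // and sum(x) separately. A rough scalar Go equivalent, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dotQ4KInnerRef(qs []byte, x []float32, d1, m1, d2, m2 float32) float32 {
- //       var sum float32
- //       for i := 0; i < 32; i++ {
- //           sum += x[i]*(d1*float32(qs[i]&0x0F)-m1) + x[i+32]*(d2*float32(qs[i]>>4)-m2)
- //       }
- //       return sum
- //   }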
- // ============================================================================
- // Q2_K Inner Loop - AVX2
- // Processes 16 2-bit values with scale and min applied
- // ============================================================================
- // func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)
- TEXT ·dequantQ2KInnerAVX2(SB), NOSPLIT, $0-32
- MOVQ q+0(FP), DI
- MOVQ out+8(FP), SI
- VBROADCASTSS dl+16(FP), Y0
- VBROADCASTSS ml+20(FP), Y1
- MOVQ shift+24(FP), CX
- // Mask for 2 bits
- MOVL $0x03030303, AX
- MOVD AX, X7
- VPBROADCASTD X7, Y7
- // Load 16 bytes, extract 2-bit values
- VPMOVZXBD (DI), Y2 // First 8 bytes -> 8 int32
- VPMOVZXBD 8(DI), Y3 // Next 8 bytes -> 8 int32
- // Shift right by 'shift' and mask
- MOVD CX, X6
- VPSRLD X6, Y2, Y2
- VPAND Y7, Y2, Y2
- VPSRLD X6, Y3, Y3
- VPAND Y7, Y3, Y3
- // Convert to float and compute: dl*val - ml
- VCVTDQ2PS Y2, Y2
- VFMSUB132PS Y0, Y1, Y2
- VMOVUPS Y2, (SI)
- VCVTDQ2PS Y3, Y3
- VFMSUB132PS Y0, Y1, Y3
- VMOVUPS Y3, 32(SI)
- VZEROUPPER
- RET
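- // A rough scalar Go equivalent of dequantQ2KInnerAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dequantQ2KInnerRef(q []byte, out []float32, dl, ml float32, shift uint) {
- //       for i := 0; i < 16; i++ {
- //           out[i] = dl*float32((q[i]>>shift)&3) - ml
- //       }
- //   }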
- // ============================================================================
- // Q3_K Inner Loop - AVX2
- // Processes 16 output elements (consuming 16 bytes from q)
- // ============================================================================
- // func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)
- TEXT ·dequantQ3KInnerAVX2(SB), NOSPLIT, $0-40
- MOVQ q+0(FP), DI
- MOVQ hm+8(FP), SI
- MOVQ out+16(FP), DX
-
- // Load dl (float32) -> Y0
- VBROADCASTSS dl+24(FP), Y0
- // Load shift -> X2
- MOVQ shift+32(FP), AX
- MOVD AX, X2
- // Load 16 bytes from q -> 16 words in Y3
- VPMOVZXBW (DI), Y3
- // Shift right by variable 'shift'
- VPSRLW X2, Y3, Y3
- // Mask with 3 (0x0003)
- MOVL $3, BX
- MOVD BX, X4
- VPBROADCASTW X4, Y4
- VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3
- // Handle HMask
- // Load `m` (byte)
- MOVBLZX m+28(FP), BX
- MOVD BX, X5
- VPBROADCASTB X5, X5 // X5 = m repeated
-
- // Load 16 bytes hm
- VMOVDQU (SI), X6
-
- // Check (hm & m) == 0
- VPAND X5, X6, X6 // X = hm & m
- VPXOR X7, X7, X7 // Zero
- VPCMPEQB X7, X6, X6 // X6 = (hm&m == 0) ? FF : 00
- // Expand byte mask to word mask (-1 or 0)
- VPMOVSXBW X6, Y6 // Y6 = -1 or 0
- // We want to subtract 4 if mask is -1.
- // Add (mask & -4).
- MOVL $-4, BX // 0xFFFFFFFC
- MOVD BX, X7
- VPBROADCASTW X7, Y7 // Y7 = -4 repeated
-
- VPAND Y7, Y6, Y6 // Y6 = -4 or 0
- VPADDW Y6, Y3, Y3 // Y3 = val - 4 (if needed)
- // Convert to float (Y3 has 16 int16)
- // Split into low 8 (Y8) and high 8 (Y9) as int32
- VPMOVSXWD X3, Y8 // Low 8 words -> 8 int32
- VEXTRACTI128 $1, Y3, X3
- VPMOVSXWD X3, Y9 // High 8 words -> 8 int32
- VCVTDQ2PS Y8, Y8
- VCVTDQ2PS Y9, Y9
- VMULPS Y0, Y8, Y8
- VMULPS Y0, Y9, Y9
- // Store 16 floats
- VMOVUPS Y8, (DX)
- VMOVUPS Y9, 32(DX)
- VZEROUPPER
- RET
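- // A rough scalar Go equivalent of dequantQ3KInnerAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dequantQ3KInnerRef(q, hm []byte, out []float32, dl float32, m uint8, shift uint) {
- //       for i := 0; i < 16; i++ {
- //           v := int32((q[i] >> shift) & 3)
- //           if hm[i]&m == 0 {
- //               v -= 4
- //           }
- //           out[i] = dl * float32(v)
- //       }
- //   }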
- // ============================================================================
- // Q6_K Inner Loop - AVX2
- // func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)
- // Processes 128 elements (one half of a Q6_K block) as 4 groups of 32, using 8 int8 scales
- // ============================================================================
- TEXT ·dequantQ6KInnerAVX2(SB), NOSPLIT, $0-40
- MOVQ ql+0(FP), DI
- MOVQ qh+8(FP), SI
- MOVQ scales+16(FP), DX
- MOVQ out+24(FP), R8
-
- // Broadcast d (float32) -> Y0
- VBROADCASTSS d+32(FP), Y0
- // Y15 = 0x0F (mask)
- MOVL $0x0F, AX
- MOVD AX, X15
- VPBROADCASTB X15, Y15
- // Y14 = 0x03 (mask)
- MOVL $0x03, AX
- MOVD AX, X14
- VPBROADCASTB X14, Y14
- // Y13 = 32.0 (float)
- MOVL $0x42000000, AX // float 32.0
- MOVD AX, X13
- VBROADCASTSS X13, Y13
- // Registers:
- // R9: Loop counter (0, 16)
- MOVQ $0, R9
- loop_q6k:
- CMPQ R9, $32
- JGE done_q6k
- // Load qh chunk (16 bytes) -> X1
- VMOVDQU (SI)(R9*1), X1
-
- // Load ql chunk 1 (16 bytes) -> X2
- VMOVDQU (DI)(R9*1), X2
-
- // Load ql chunk 2 (16 bytes) -> X3
- VMOVDQU 32(DI)(R9*1), X3
- // Mask for bit shifting logic
- MOVL $0xF0, AX
- MOVD AX, X6
- VPBROADCASTB X6, X6
- // --- Q1 ---
- // (ql_c1 & 0xF) | ((qh & 3) << 4)
- VPAND X15, X2, X4
- VPAND X14, X1, X5
- VPSLLW $4, X5, X5
- VPAND X6, X5, X5
- VPOR X5, X4, X4 // X4 = Q1 values (16 bytes)
- // Scale s[0] or s[1] -> offset R9>>4
- MOVQ R9, R11
- SHRQ $4, R11 // 0 or 1
- MOVBQSX (DX)(R11*1), BX
-
- // Convert X4 -> Y4/Y5 (floats), scale by BX, d, sub 32, store
- // (Inline expansion)
- VCVTSI2SSQ BX, X10, X10
- VBROADCASTSS X10, Y10 // Scale
-
- VMOVDQA X4, X7 // Copy X4 to X7
- VPMOVZXBD X7, Y4 // Low 8 bytes from X7 -> Y4
- VPSRLDQ $8, X4, X8 // Shift X4 -> X8
- VPMOVZXBD X8, Y5 // Low 8 bytes from X8 -> Y5
-
- VCVTDQ2PS Y4, Y4
- VCVTDQ2PS Y5, Y5
- VSUBPS Y13, Y4, Y4
- VSUBPS Y13, Y5, Y5
- VMULPS Y10, Y4, Y4
- VMULPS Y0, Y4, Y4
- VMULPS Y10, Y5, Y5
- VMULPS Y0, Y5, Y5
-
- // Store to out + R9*4
- LEAQ (R8)(R9*4), R12
- VMOVUPS Y4, (R12)
- VMOVUPS Y5, 32(R12)
- // --- Q2 ---
- // (ql_c2 & 0xF) | (((qh >> 2) & 3) << 4)
- VPAND X15, X3, X4
- VPSRLW $2, X1, X5
- VPAND X14, X5, X5
- VPSLLW $4, X5, X5
- VPAND X6, X5, X5
- VPOR X5, X4, X4
- // Scale s[2] or s[3] -> offset R9>>4 + 2
- MOVBQSX 2(DX)(R11*1), BX
- VCVTSI2SSQ BX, X10, X10
- VBROADCASTSS X10, Y10
-
- VMOVDQA X4, X7
- VPMOVZXBD X7, Y4
- VPSRLDQ $8, X4, X8
- VPMOVZXBD X8, Y5
- VCVTDQ2PS Y4, Y4
- VCVTDQ2PS Y5, Y5
- VSUBPS Y13, Y4, Y4
- VSUBPS Y13, Y5, Y5
- VMULPS Y10, Y4, Y4
- VMULPS Y0, Y4, Y4
- VMULPS Y10, Y5, Y5
- VMULPS Y0, Y5, Y5
-
- // Store to out + 128 + R9*4
- LEAQ 128(R8)(R9*4), R12
- VMOVUPS Y4, (R12)
- VMOVUPS Y5, 32(R12)
- // --- Q3 ---
- // (ql_c1 >> 4) | (((qh >> 4) & 3) << 4)
- VPSRLW $4, X2, X4
- VPAND X15, X4, X4
- VPSRLW $4, X1, X5
- VPAND X14, X5, X5
- VPSLLW $4, X5, X5
- VPAND X6, X5, X5
- VPOR X5, X4, X4
-
- // Scale s[4] or s[5]
- MOVBQSX 4(DX)(R11*1), BX
- VCVTSI2SSQ BX, X10, X10
- VBROADCASTSS X10, Y10
-
- VMOVDQA X4, X7
- VPMOVZXBD X7, Y4
- VPSRLDQ $8, X4, X8
- VPMOVZXBD X8, Y5
- VCVTDQ2PS Y4, Y4
- VCVTDQ2PS Y5, Y5
- VSUBPS Y13, Y4, Y4
- VSUBPS Y13, Y5, Y5
- VMULPS Y10, Y4, Y4
- VMULPS Y0, Y4, Y4
- VMULPS Y10, Y5, Y5
- VMULPS Y0, Y5, Y5
-
- LEAQ 256(R8)(R9*4), R12
- VMOVUPS Y4, (R12)
- VMOVUPS Y5, 32(R12)
- // --- Q4 ---
- // (ql_c2 >> 4) | (((qh >> 6) & 3) << 4)
- VPSRLW $4, X3, X4
- VPAND X15, X4, X4
- VPSRLW $6, X1, X5
- VPAND X14, X5, X5
- VPSLLW $4, X5, X5
- VPAND X6, X5, X5
- VPOR X5, X4, X4
-
- // Scale s[6] or s[7]
- MOVBQSX 6(DX)(R11*1), BX
- VCVTSI2SSQ BX, X10, X10
- VBROADCASTSS X10, Y10
-
- VMOVDQA X4, X7
- VPMOVZXBD X7, Y4
- VPSRLDQ $8, X4, X8
- VPMOVZXBD X8, Y5
- VCVTDQ2PS Y4, Y4
- VCVTDQ2PS Y5, Y5
- VSUBPS Y13, Y4, Y4
- VSUBPS Y13, Y5, Y5
- VMULPS Y10, Y4, Y4
- VMULPS Y0, Y4, Y4
- VMULPS Y10, Y5, Y5
- VMULPS Y0, Y5, Y5
-
- LEAQ 384(R8)(R9*4), R12
- VMOVUPS Y4, (R12)
- VMOVUPS Y5, 32(R12)
- ADDQ $16, R9
- JMP loop_q6k
- done_q6k:
- VZEROUPPER
- RET
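- // A rough scalar Go equivalent of dequantQ6KInnerAVX2 above, for reference only
- // (hypothetical helper; slices stand in for the raw pointer arguments):
- //
- //   func dequantQ6KInnerRef(ql, qh []byte, scales []int8, out []float32, d float32) {
- //       for n := 0; n < 32; n += 16 {
- //           is := n / 16 // 0 or 1: selects the scale pair for this 16-byte chunk
- //           for i := 0; i < 16; i++ {
- //               q1 := int32(ql[n+i]&0x0F) | int32(qh[n+i]&3)<<4
- //               q2 := int32(ql[32+n+i]&0x0F) | int32((qh[n+i]>>2)&3)<<4
- //               q3 := int32(ql[n+i]>>4) | int32((qh[n+i]>>4)&3)<<4
- //               q4 := int32(ql[32+n+i]>>4) | int32((qh[n+i]>>6)&3)<<4
- //               out[n+i] = d * float32(scales[is]) * float32(q1-32)
- //               out[32+n+i] = d * float32(scales[is+2]) * float32(q2-32)
- //               out[64+n+i] = d * float32(scales[is+4]) * float32(q3-32)
- //               out[96+n+i] = d * float32(scales[is+6]) * float32(q4-32)
- //           }
- //       }
- //   }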
- // ============================================================================
- // Q6_K Fused Dot Inner Loop - AVX2
- // func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32
- // Processes 128 elements (one half of block), returns partial dot sum
- // scales: 8 precomputed float32 values (d*scale[0..7])
- // ============================================================================
- TEXT ·dotQ6KInnerAVX2(SB), NOSPLIT, $0-40
- MOVQ ql+0(FP), DI // QL pointer (64 bytes)
- MOVQ qh+8(FP), SI // QH pointer (32 bytes)
- MOVQ scales+16(FP), DX // Precomputed scales (8 floats)
- MOVQ x+24(FP), R8 // X pointer (128 floats)
- // Y11 = 0x0F as dwords (for masking after VPMOVZXBD)
- MOVL $0x0F, AX
- MOVD AX, X11
- VPBROADCASTD X11, Y11
- // Y10 = 0x03 as dwords
- MOVL $0x03, AX
- MOVD AX, X10
- VPBROADCASTD X10, Y10
- // Y9 = 32.0 (float bias)
- MOVL $0x42000000, AX
- MOVD AX, X9
- VBROADCASTSS X9, Y9
- // Y8 = accumulator for dot product
- VXORPS Y8, Y8, Y8
- // Process 8 positions per iteration (4 iterations over 32 positions, 4 outputs each = 128 elements)
- // Each iteration: load 8 QL bytes (plus 8 from QL+32) and 8 QH bytes, compute Q1-Q4 for those positions
- MOVQ $0, R9
- dotq6k_loop:
- CMPQ R9, $32
- JGE dotq6k_done
- // R11 = R9 >> 4 (0 or 1, for scale indexing)
- MOVQ R9, R11
- SHRQ $4, R11
- // Load 8 bytes of QL (for Q1/Q3) and 8 bytes of QH
- VPMOVZXBD (DI)(R9*1), Y0 // QL[R9..R9+7] -> 8 dwords
- VPMOVZXBD 32(DI)(R9*1), Y1 // QL[32+R9..32+R9+7] -> 8 dwords (for Q2/Q4)
- VPMOVZXBD (SI)(R9*1), Y2 // QH[R9..R9+7] -> 8 dwords
- // --- Q1: (ql & 0xF) | ((qh & 3) << 4) ---
- VPAND Y11, Y0, Y3 // ql & 0x0F
- VPAND Y10, Y2, Y4 // qh & 0x03
- VPSLLD $4, Y4, Y4 // << 4
- VPOR Y4, Y3, Y3 // combine
- VCVTDQ2PS Y3, Y3
- VSUBPS Y9, Y3, Y3 // q - 32
- VBROADCASTSS (DX)(R11*4), Y4 // scale s[0] or s[1]
- VMULPS Y4, Y3, Y3 // * scale
- LEAQ (R8)(R9*4), R12
- VMOVUPS (R12), Y5 // x[R9..R9+7]
- VFMADD231PS Y3, Y5, Y8 // acc += q * x
- // --- Q2: (ql32 & 0xF) | (((qh >> 2) & 3) << 4) ---
- VPAND Y11, Y1, Y3 // ql32 & 0x0F
- VPSRLD $2, Y2, Y4 // qh >> 2
- VPAND Y10, Y4, Y4 // & 0x03
- VPSLLD $4, Y4, Y4 // << 4
- VPOR Y4, Y3, Y3 // combine
- VCVTDQ2PS Y3, Y3
- VSUBPS Y9, Y3, Y3 // q - 32
- VBROADCASTSS 8(DX)(R11*4), Y4 // scale s[2] or s[3]
- VMULPS Y4, Y3, Y3 // * scale
- VMOVUPS 128(R12), Y5 // x[32+R9..32+R9+7]
- VFMADD231PS Y3, Y5, Y8 // acc += q * x
- // --- Q3: (ql >> 4) | (((qh >> 4) & 3) << 4) ---
- VPSRLD $4, Y0, Y3 // ql >> 4
- VPAND Y11, Y3, Y3 // & 0x0F
- VPSRLD $4, Y2, Y4 // qh >> 4
- VPAND Y10, Y4, Y4 // & 0x03
- VPSLLD $4, Y4, Y4 // << 4
- VPOR Y4, Y3, Y3 // combine
- VCVTDQ2PS Y3, Y3
- VSUBPS Y9, Y3, Y3 // q - 32
- VBROADCASTSS 16(DX)(R11*4), Y4 // scale s[4] or s[5]
- VMULPS Y4, Y3, Y3 // * scale
- VMOVUPS 256(R12), Y5 // x[64+R9..64+R9+7]
- VFMADD231PS Y3, Y5, Y8 // acc += q * x
- // --- Q4: (ql32 >> 4) | (((qh >> 6) & 3) << 4) ---
- VPSRLD $4, Y1, Y3 // ql32 >> 4
- VPAND Y11, Y3, Y3 // & 0x0F
- VPSRLD $6, Y2, Y4 // qh >> 6
- VPAND Y10, Y4, Y4 // & 0x03
- VPSLLD $4, Y4, Y4 // << 4
- VPOR Y4, Y3, Y3 // combine
- VCVTDQ2PS Y3, Y3
- VSUBPS Y9, Y3, Y3 // q - 32
- VBROADCASTSS 24(DX)(R11*4), Y4 // scale s[6] or s[7]
- VMULPS Y4, Y3, Y3 // * scale
- VMOVUPS 384(R12), Y5 // x[96+R9..96+R9+7]
- VFMADD231PS Y3, Y5, Y8 // acc += q * x
- ADDQ $8, R9
- JMP dotq6k_loop
- dotq6k_done:
- // Horizontal sum of Y8
- VEXTRACTF128 $1, Y8, X0
- VADDPS X8, X0, X0
- VHADDPS X0, X0, X0
- VHADDPS X0, X0, X0
- VMOVSS X0, ret+32(FP)
- VZEROUPPER
- RET
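- // A rough scalar Go equivalent of dotQ6KInnerAVX2 above, for reference only
- // (hypothetical helper; scales already hold d*scale as 8 float32 values):
- //
- //   func dotQ6KInnerRef(ql, qh []byte, scales, x []float32) float32 {
- //       var sum float32
- //       for i := 0; i < 32; i++ {
- //           is := i / 16 // 0 or 1: selects the scale pair
- //           q1 := int32(ql[i]&0x0F) | int32(qh[i]&3)<<4
- //           q2 := int32(ql[32+i]&0x0F) | int32((qh[i]>>2)&3)<<4
- //           q3 := int32(ql[i]>>4) | int32((qh[i]>>4)&3)<<4
- //           q4 := int32(ql[32+i]>>4) | int32((qh[i]>>6)&3)<<4
- //           sum += x[i]*scales[is]*float32(q1-32) + x[32+i]*scales[is+2]*float32(q2-32) +
- //               x[64+i]*scales[is+4]*float32(q3-32) + x[96+i]*scales[is+6]*float32(q4-32)
- //       }
- //       return sum
- //   }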