//go:build amd64
// +build amd64

#include "textflag.h"

// ============================================================================
// Q8_K Dequantization - AVX2
// BlockQ8_K layout:
//   offset 0: D (float32)
//   offset 4: QS[256] (int8)
// ============================================================================
// func dequantQ8KAVX2(b *BlockQ8_K, out *float32)
TEXT ·dequantQ8KAVX2(SB), NOSPLIT, $0-16
	MOVQ b+0(FP), DI
	MOVQ out+8(FP), SI

	// Broadcast scale d to Y0
	VBROADCASTSS (DI), Y0

	// QS pointer = b + 4
	LEAQ 4(DI), R8
	MOVQ SI, R9

	// Process 256 elements, 32 at a time (unrolled)
	MOVQ $0, CX

loop_q8:
	CMPQ CX, $256
	JGE  done_q8

	// Load 8 int8, sign-extend to 8 int32, convert to float, multiply by scale
	VPMOVSXBD (R8), Y1
	VCVTDQ2PS Y1, Y1
	VMULPS    Y0, Y1, Y1
	VMOVUPS   Y1, (R9)

	VPMOVSXBD 8(R8), Y2
	VCVTDQ2PS Y2, Y2
	VMULPS    Y0, Y2, Y2
	VMOVUPS   Y2, 32(R9)

	VPMOVSXBD 16(R8), Y3
	VCVTDQ2PS Y3, Y3
	VMULPS    Y0, Y3, Y3
	VMOVUPS   Y3, 64(R9)

	VPMOVSXBD 24(R8), Y4
	VCVTDQ2PS Y4, Y4
	VMULPS    Y0, Y4, Y4
	VMOVUPS   Y4, 96(R9)

	ADDQ $32, R8
	ADDQ $128, R9
	ADDQ $32, CX
	JMP  loop_q8

done_q8:
	VZEROUPPER
	RET
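
// For reference, a scalar Go sketch of what dequantQ8KAVX2 computes (the
// function name is hypothetical and not part of this package's API):
//
//	func dequantQ8KRef(d float32, qs []int8, out []float32) {
//		for i := 0; i < 256; i++ {
//			out[i] = d * float32(qs[i]) // out[i] = d * qs[i]
//		}
//	}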
// ============================================================================
// Q3_K Fused Dot Inner Loop - AVX2
// Computes: sum_i x[i] * (dl * qv_i) for 16 elements,
// where qv_i = ((q[i] >> shift) & 3), minus 4 when (hm[i] & m) == 0,
// i.e. a 3-bit quant whose high bit is stored in hm.
// ============================================================================
// func dotQ3KInnerAVX2Fused(q *byte, hm *byte, x *float32, dl float32, m uint8, shift uint) float32
TEXT ·dotQ3KInnerAVX2Fused(SB), NOSPLIT, $0-48
	MOVQ q+0(FP), DI
	MOVQ hm+8(FP), SI
	MOVQ x+16(FP), DX

	// Load dl (float32) -> Y0
	VBROADCASTSS dl+24(FP), Y0

	// Load shift -> X2
	MOVQ shift+32(FP), AX
	MOVD AX, X2

	// Load 16 bytes from q -> 16 words in Y3
	VPMOVZXBW (DI), Y3

	// Shift right by variable 'shift'
	VPSRLW X2, Y3, Y3

	// Mask with 3 (0x0003)
	MOVL $3, BX
	MOVD BX, X4
	VPBROADCASTW X4, Y4
	VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3

	// Handle HMask: subtract 4 where (hm & m) == 0
	MOVBLZX m+28(FP), BX
	MOVD BX, X5
	VPBROADCASTB X5, X5 // X5 = m repeated
	VMOVDQU (SI), X6
	VPAND X5, X6, X6     // X6 = hm & m
	VPXOR X7, X7, X7     // zero
	VPCMPEQB X7, X6, X6  // X6 = (hm&m == 0) ? FF : 00
	VPMOVSXBW X6, Y6     // expand byte mask to word mask (-1 or 0)
	MOVL $-4, BX
	MOVD BX, X7
	VPBROADCASTW X7, Y7
	VPAND Y7, Y6, Y6 // Y6 = -4 or 0
	VPADDW Y6, Y3, Y3

	// Split into low/high 8 words -> int32.
	// Low half of Y3 is already in X3.
	VPMOVSXWD X3, Y8
	VEXTRACTI128 $1, Y3, X4
	VPMOVSXWD X4, Y9

	// Convert to float and scale by dl
	VCVTDQ2PS Y8, Y8
	VCVTDQ2PS Y9, Y9
	VMULPS Y0, Y8, Y8
	VMULPS Y0, Y9, Y9

	// Load x and accumulate dot
	VMOVUPS (DX), Y10
	VMOVUPS 32(DX), Y11
	VMULPS Y10, Y8, Y8
	VMULPS Y11, Y9, Y9
	VADDPS Y9, Y8, Y8

	// Horizontal sum Y8 -> X8
	VEXTRACTF128 $1, Y8, X1
	VADDPS X1, X8, X8
	VHADDPS X8, X8, X8
	VHADDPS X8, X8, X8
	VMOVSS X8, ret+40(FP)
	VZEROUPPER
	RET

// ============================================================================
// Q2_K Fused Dot - AVX2
// Computes: sum_i x[i] * (dl*val_i - ml), for i in [0..15]
// where val_i = (q[i] >> shift) & 3
// ============================================================================
// func dotQ2KInnerAVX2Fused(q *byte, x *float32, dl, ml float32, shift uint) float32
TEXT ·dotQ2KInnerAVX2Fused(SB), NOSPLIT, $0-40
	MOVQ q+0(FP), DI
	MOVQ x+8(FP), SI
	VBROADCASTSS dl+16(FP), Y0
	VBROADCASTSS ml+20(FP), Y1
	MOVQ shift+24(FP), CX

	// Mask for 2 bits
	MOVL $0x03030303, AX
	MOVD AX, X7
	VPBROADCASTD X7, Y7

	// Shift amount
	MOVD CX, X6

	// Accumulator
	VXORPS Y15, Y15, Y15

	// Low 8 bytes
	VPMOVZXBD (DI), Y2
	VPSRLD X6, Y2, Y2
	VPAND Y7, Y2, Y2
	VCVTDQ2PS Y2, Y2
	VMULPS Y0, Y2, Y2
	VSUBPS Y1, Y2, Y2
	VMOVUPS (SI), Y4
	VFMADD231PS Y4, Y2, Y15

	// High 8 bytes
	VPMOVZXBD 8(DI), Y3
	VPSRLD X6, Y3, Y3
	VPAND Y7, Y3, Y3
	VCVTDQ2PS Y3, Y3
	VMULPS Y0, Y3, Y3
	VSUBPS Y1, Y3, Y3
	VMOVUPS 32(SI), Y5
	VFMADD231PS Y5, Y3, Y15

	// Reduce Y15 -> scalar
	VEXTRACTF128 $1, Y15, X1
	VADDPS X1, X15, X0
	VHADDPS X0, X0, X0
	VHADDPS X0, X0, X0
	MOVSS X0, ret+32(FP)
	VZEROUPPER
	RET
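
// A scalar Go sketch of the fused Q2_K dot above (hypothetical name, for
// illustration only): each 2-bit value is dequantized to dl*v - ml and
// multiplied into x on the fly, never materializing the floats.
//
//	func dotQ2KRef(q []byte, x []float32, dl, ml float32, shift uint) float32 {
//		var sum float32
//		for i := 0; i < 16; i++ {
//			v := float32((q[i] >> shift) & 3)
//			sum += x[i] * (dl*v - ml)
//		}
//		return sum
//	}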
// ============================================================================
// Q5_K Fused Dot - AVX2
// ============================================================================
// func dotQ5KInnerAVX2(qs *byte, qh *byte, x *float32, d1, m1, d2, m2 float32, u1, u2 uint) float32
TEXT ·dotQ5KInnerAVX2(SB), NOSPLIT, $0-64
	MOVQ qs+0(FP), DI
	MOVQ qh+8(FP), SI
	MOVQ x+16(FP), DX
	VBROADCASTSS d1+24(FP), Y0
	VBROADCASTSS m1+28(FP), Y1
	VBROADCASTSS d2+32(FP), Y2
	VBROADCASTSS m2+36(FP), Y3

	// low-nibble mask 0x0F
	MOVL $0x0F0F0F0F, AX
	MOVD AX, X6
	VPBROADCASTD X6, Y6

	// int32(16) constant
	MOVL $16, AX
	MOVD AX, X7
	VPBROADCASTD X7, Y7

	// mask1 = 1 << u1
	MOVQ u1+40(FP), CX
	MOVQ $1, AX
	SHLQ CL, AX
	MOVL AX, BX
	MOVD BX, X4
	VPBROADCASTD X4, Y4

	// mask2 = 1 << u2
	MOVQ u2+48(FP), CX
	MOVQ $1, AX
	SHLQ CL, AX
	MOVL AX, BX
	MOVD BX, X5
	VPBROADCASTD X5, Y5

	// zero and all-ones (int dwords)
	VXORPS Y8, Y8, Y8
	VPCMPEQD Y9, Y9, Y9

	// Accumulator
	VXORPS Y15, Y15, Y15

	MOVQ $0, CX

dot_q5k_loop:
	CMPQ CX, $32
	JGE dot_q5k_reduce

	// Load 8 qs bytes
	VPMOVZXBD (DI), Y10

	// low nibble -> int32
	VPAND Y6, Y10, Y11

	// high nibble -> int32
	VPSRLD $4, Y10, Y10

	// Load 8 qh bytes and copy
	VPMOVZXBD (SI), Y12
	VMOVAPS Y12, Y13

	// flag1 int32: (qh & mask1) ? 16 : 0
	VPAND Y4, Y12, Y12
	VPCMPEQD Y8, Y12, Y12
	VPXOR Y9, Y12, Y12
	VPAND Y7, Y12, Y12

	// low dequant: (float(lowNib + flag1) * d1) - m1
	VCVTDQ2PS Y11, Y11
	VCVTDQ2PS Y12, Y12
	VADDPS Y12, Y11, Y11
	VMULPS Y0, Y11, Y11
	VSUBPS Y1, Y11, Y11
	VMOVUPS (DX), Y14
	VFMADD231PS Y14, Y11, Y15

	// flag2 int32: (qh & mask2) ? 16 : 0
	VPAND Y5, Y13, Y13
	VPCMPEQD Y8, Y13, Y13
	VPXOR Y9, Y13, Y13
	VPAND Y7, Y13, Y13

	// high dequant: (float(highNib + flag2) * d2) - m2
	VCVTDQ2PS Y10, Y10
	VCVTDQ2PS Y13, Y13
	VADDPS Y13, Y10, Y10
	VMULPS Y2, Y10, Y10
	VSUBPS Y3, Y10, Y10
	VMOVUPS 128(DX), Y14
	VFMADD231PS Y14, Y10, Y15

	ADDQ $8, DI
	ADDQ $8, SI
	ADDQ $32, DX
	ADDQ $8, CX
	JMP dot_q5k_loop

dot_q5k_reduce:
	VEXTRACTF128 $1, Y15, X1
	VADDPS X1, X15, X0
	VHADDPS X0, X0, X0
	VHADDPS X0, X0, X0
	MOVSS X0, ret+56(FP)
	VZEROUPPER
	RET

// ============================================================================
// Q8_K Fused Dot - AVX2
// Computes: sum_i x[i] * (d * float32(qs[i]))
// ============================================================================
// func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
TEXT ·dotQ8KAVX2(SB), NOSPLIT, $0-24
	MOVQ b+0(FP), DI
	MOVQ x+8(FP), SI

	// Load scale d
	MOVSS (DI), X0
	VBROADCASTSS X0, Y0

	// QS pointer = b + 4
	LEAQ 4(DI), R8

	// Accumulator Y1
	VXORPS Y1, Y1, Y1

	MOVQ $0, CX

dot_q8_256_loop:
	CMPQ CX, $256
	JGE dot_q8_256_reduce

	// 8x int8 -> 8x int32 -> 8x float
	VPMOVSXBD (R8), Y2
	VCVTDQ2PS Y2, Y2
	VMULPS Y0, Y2, Y2

	// load 8 floats from x
	VMOVUPS (SI), Y3
	VFMADD231PS Y3, Y2, Y1

	ADDQ $8, R8
	ADDQ $32, SI
	ADDQ $8, CX
	JMP dot_q8_256_loop

dot_q8_256_reduce:
	// horizontal add Y1 -> scalar in X1
	VEXTRACTF128 $1, Y1, X2
	VADDPS X2, X1, X1
	VHADDPS X1, X1, X1
	VHADDPS X1, X1, X1
	MOVSS X1, ret+16(FP)
	VZEROUPPER
	RET

// ============================================================================
// Q4_K Inner Loop - AVX2 (Vectorized nibble extraction)
// Processes 32 bytes of QS (64 4-bit quants) with pre-computed scales.
// ============================================================================
// func dequantQ4KInnerAVX2(qs *byte, out *float32, d1, m1, d2, m2 float32)
TEXT ·dequantQ4KInnerAVX2(SB), NOSPLIT, $0-40
	MOVQ qs+0(FP), DI
	MOVQ out+8(FP), SI

	// Broadcast d1, m1, d2, m2
	VBROADCASTSS d1+16(FP), Y0 // d1
	VBROADCASTSS m1+20(FP), Y1 // m1
	VBROADCASTSS d2+24(FP), Y2 // d2
	VBROADCASTSS m2+28(FP), Y3 // m2

	// Mask for low nibble (0x0F repeated)
	MOVL $0x0F0F0F0F, AX
	MOVD AX, X7
	VPBROADCASTD X7, Y7

	// Process 32 bytes, 8 at a time
	MOVQ $0, CX

loop_q4k:
	CMPQ CX, $32
	JGE done_q4k

	// Load 8 bytes from QS as unsigned
	VPMOVZXBD (DI), Y4 // 8 bytes -> 8 uint32

	// Extract low nibbles: v1 = val & 0xF
	VPAND Y7, Y4, Y5
	VCVTDQ2PS Y5, Y5
	VFMSUB132PS Y0, Y1, Y5 // out[i] = v1*d1 - m1
	VMOVUPS Y5, (SI)

	// Extract high nibbles: v2 = val >> 4
	VPSRLD $4, Y4, Y4
	VCVTDQ2PS Y4, Y4
	VFMSUB132PS Y2, Y3, Y4 // out[i+32] = v2*d2 - m2
	VMOVUPS Y4, 128(SI)    // 32 * 4 bytes offset

	ADDQ $8, DI
	ADDQ $32, SI
	ADDQ $8, CX
	JMP loop_q4k

done_q4k:
	VZEROUPPER
	RET
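
// A scalar Go sketch of the nibble extraction above (hypothetical name):
// each source byte carries two 4-bit quants, dequantized with separate
// scale/min pairs and written 32 floats apart.
//
//	func dequantQ4KRef(qs []byte, out []float32, d1, m1, d2, m2 float32) {
//		for i := 0; i < 32; i++ {
//			out[i] = float32(qs[i]&0x0F)*d1 - m1  // low nibble
//			out[i+32] = float32(qs[i]>>4)*d2 - m2 // high nibble
//		}
//	}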
// func dotQ4KInnerAVX2(qs *byte, x *float32, d1, m1, d2, m2 float32) float32
TEXT ·dotQ4KInnerAVX2(SB), NOSPLIT, $0-40
	MOVQ qs+0(FP), DI
	MOVQ x+8(FP), SI

	// Broadcast d1, m1, d2, m2
	VBROADCASTSS d1+16(FP), Y0
	VBROADCASTSS m1+20(FP), Y1
	VBROADCASTSS d2+24(FP), Y2
	VBROADCASTSS m2+28(FP), Y3

	// Mask for low nibble (0x0F repeated)
	MOVL $0x0F0F0F0F, AX
	MOVD AX, X7
	VPBROADCASTD X7, Y7

	// Accumulators:
	//   Y12 = sum(x_low * v1)
	//   Y13 = sum(x_low)
	//   Y14 = sum(x_high * v2)
	//   Y15 = sum(x_high)
	VXORPS Y12, Y12, Y12
	VXORPS Y13, Y13, Y13
	VXORPS Y14, Y14, Y14
	VXORPS Y15, Y15, Y15

	// Process 32 bytes as 4x8-byte chunks
	MOVQ $0, CX

dot_q4k_loop:
	CMPQ CX, $32
	JGE dot_q4k_reduce

	// Load 8 bytes from QS as unsigned dwords
	VPMOVZXBD (DI), Y8

	// Low nibble values -> float
	VPAND Y7, Y8, Y9
	VCVTDQ2PS Y9, Y9

	// x low: 8 floats
	VMOVUPS (SI), Y10
	VADDPS Y10, Y13, Y13
	VFMADD231PS Y10, Y9, Y12

	// High nibble values -> float
	VPSRLD $4, Y8, Y8
	VCVTDQ2PS Y8, Y8

	// x high: offset by 32 floats (128 bytes)
	VMOVUPS 128(SI), Y11
	VADDPS Y11, Y15, Y15
	VFMADD231PS Y11, Y8, Y14

	ADDQ $8, DI
	ADDQ $32, SI
	ADDQ $8, CX
	JMP dot_q4k_loop

dot_q4k_reduce:
	// Since dequant(v) = d*v - m, the dot product factorizes:
	// result = d1*sum(x1*v1) - m1*sum(x1) + d2*sum(x2*v2) - m2*sum(x2)
	VMULPS Y0, Y12, Y12
	VMULPS Y1, Y13, Y13
	VSUBPS Y13, Y12, Y12
	VMULPS Y2, Y14, Y14
	VMULPS Y3, Y15, Y15
	VSUBPS Y15, Y14, Y14
	VADDPS Y14, Y12, Y12

	// Horizontal add Y12 -> scalar in X0
	VEXTRACTF128 $1, Y12, X1
	VADDPS X1, X12, X0
	VHADDPS X0, X0, X0
	VHADDPS X0, X0, X0
	MOVSS X0, ret+32(FP)
	VZEROUPPER
	RET

// ============================================================================
// Q2_K Inner Loop - AVX2
// Processes 16 2-bit values with scale and min applied
// ============================================================================
// func dequantQ2KInnerAVX2(q *byte, out *float32, dl, ml float32, shift uint)
TEXT ·dequantQ2KInnerAVX2(SB), NOSPLIT, $0-32
	MOVQ q+0(FP), DI
	MOVQ out+8(FP), SI
	VBROADCASTSS dl+16(FP), Y0
	VBROADCASTSS ml+20(FP), Y1
	MOVQ shift+24(FP), CX

	// Mask for 2 bits
	MOVL $0x03030303, AX
	MOVD AX, X7
	VPBROADCASTD X7, Y7

	// Load 16 bytes, extract 2-bit values
	VPMOVZXBD (DI), Y2  // First 8 bytes -> 8 int32
	VPMOVZXBD 8(DI), Y3 // Next 8 bytes -> 8 int32

	// Shift right by 'shift' and mask
	MOVD CX, X6
	VPSRLD X6, Y2, Y2
	VPAND Y7, Y2, Y2
	VPSRLD X6, Y3, Y3
	VPAND Y7, Y3, Y3

	// Convert to float and compute: dl*val - ml
	VCVTDQ2PS Y2, Y2
	VFMSUB132PS Y0, Y1, Y2
	VMOVUPS Y2, (SI)
	VCVTDQ2PS Y3, Y3
	VFMSUB132PS Y0, Y1, Y3
	VMOVUPS Y3, 32(SI)
	VZEROUPPER
	RET

// ============================================================================
// Q3_K Inner Loop - AVX2
// Processes 16 output elements (consuming 16 bytes from q)
// ============================================================================
// func dequantQ3KInnerAVX2(q *byte, hm *byte, out *float32, dl float32, m uint8, shift uint)
TEXT ·dequantQ3KInnerAVX2(SB), NOSPLIT, $0-40
	MOVQ q+0(FP), DI
	MOVQ hm+8(FP), SI
	MOVQ out+16(FP), DX

	// Load dl (float32) -> Y0
	VBROADCASTSS dl+24(FP), Y0

	// Load shift -> X2
	MOVQ shift+32(FP), AX
	MOVD AX, X2

	// Load 16 bytes from q -> 16 words in Y3
	VPMOVZXBW (DI), Y3

	// Shift right by variable 'shift'
	VPSRLW X2, Y3, Y3

	// Mask with 3 (0x0003)
	MOVL $3, BX
	MOVD BX, X4
	VPBROADCASTW X4, Y4
	VPAND Y4, Y3, Y3 // Y3 = (q >> shift) & 3

	// Handle HMask
	// Load `m` (byte)
	MOVBLZX m+28(FP), BX
	MOVD BX, X5
	VPBROADCASTB X5, X5 // X5 = m repeated

	// Load 16 bytes hm
	VMOVDQU (SI), X6

	// Check (hm & m) == 0
	VPAND X5, X6, X6    // X6 = hm & m
	VPXOR X7, X7, X7    // Zero
	VPCMPEQB X7, X6, X6 // X6 = (hm&m == 0) ? FF : 00

	// Expand byte mask to word mask (-1 or 0)
	VPMOVSXBW X6, Y6 // Y6 = -1 or 0

	// We want to subtract 4 if mask is -1.
	// Add (mask & -4).
	MOVL $-4, BX // 0xFFFFFFFC
	MOVD BX, X7
	VPBROADCASTW X7, Y7 // Y7 = -4 repeated
	VPAND Y7, Y6, Y6    // Y6 = -4 or 0
	VPADDW Y6, Y3, Y3   // Y3 = val - 4 (if needed)

	// Convert to float (Y3 has 16 int16)
	// Split into low 8 (Y8) and high 8 (Y9) as int32
	VPMOVSXWD X3, Y8 // Low 8 words -> 8 int32
	VEXTRACTI128 $1, Y3, X3
	VPMOVSXWD X3, Y9 // High 8 words -> 8 int32
	VCVTDQ2PS Y8, Y8
	VCVTDQ2PS Y9, Y9
	VMULPS Y0, Y8, Y8
	VMULPS Y0, Y9, Y9

	// Store 16 floats
	VMOVUPS Y8, (DX)
	VMOVUPS Y9, 32(DX)
	VZEROUPPER
	RET
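
// A scalar Go sketch of the Q3_K reconstruction above (hypothetical name):
// the low 2 bits come from q, and 4 is subtracted unless the high bit is set
// in the hmask, which is what the compare/mask/add sequence vectorizes.
//
//	func dequantQ3KRef(q, hm []byte, out []float32, dl float32, m uint8, shift uint) {
//		for i := 0; i < 16; i++ {
//			v := int32((q[i] >> shift) & 3)
//			if hm[i]&m == 0 {
//				v -= 4
//			}
//			out[i] = dl * float32(v)
//		}
//	}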
// ============================================================================
// Q6_K Inner Loop - AVX2
// Processes 128 elements (one half of a block: quadrants Q1..Q4)
// ============================================================================
// func dequantQ6KInnerAVX2(ql *byte, qh *byte, scales *int8, out *float32, d float32)
TEXT ·dequantQ6KInnerAVX2(SB), NOSPLIT, $0-40
	MOVQ ql+0(FP), DI
	MOVQ qh+8(FP), SI
	MOVQ scales+16(FP), DX
	MOVQ out+24(FP), R8

	// Broadcast d (float32) -> Y0
	VBROADCASTSS d+32(FP), Y0

	// Y15 = 0x0F (mask)
	MOVL $0x0F, AX
	MOVD AX, X15
	VPBROADCASTB X15, Y15

	// Y14 = 0x03 (mask)
	MOVL $0x03, AX
	MOVD AX, X14
	VPBROADCASTB X14, Y14

	// Y13 = 32.0 (float)
	MOVL $0x42000000, AX // float 32.0
	MOVD AX, X13
	VBROADCASTSS X13, Y13

	// R9: loop counter (takes the values 0 and 16)
	MOVQ $0, R9

loop_q6k:
	CMPQ R9, $32
	JGE done_q6k

	// Load qh chunk (16 bytes) -> X1
	VMOVDQU (SI)(R9*1), X1

	// Load ql chunk 1 (16 bytes) -> X2
	VMOVDQU (DI)(R9*1), X2

	// Load ql chunk 2 (16 bytes) -> X3
	VMOVDQU 32(DI)(R9*1), X3

	// Mask for bit shifting logic
	MOVL $0xF0, AX
	MOVD AX, X6
	VPBROADCASTB X6, X6

	// --- Q1 ---
	// (ql_c1 & 0xF) | ((qh & 3) << 4)
	VPAND X15, X2, X4
	VPAND X14, X1, X5
	VPSLLW $4, X5, X5
	VPAND X6, X5, X5
	VPOR X5, X4, X4 // X4 = Q1 values (16 bytes)

	// Scale s[0] or s[1] -> offset R9>>4
	MOVQ R9, R11
	SHRQ $4, R11 // 0 or 1
	MOVBQSX (DX)(R11*1), BX

	// Convert X4 -> Y4/Y5 (floats), subtract 32, scale by BX and d, store
	// (inline expansion)
	VCVTSI2SSQ BX, X10, X10
	VBROADCASTSS X10, Y10 // Scale
	VMOVDQA X4, X7        // Copy X4 to X7
	VPMOVZXBD X7, Y4      // Low 8 bytes from X7 -> Y4
	VPSRLDQ $8, X4, X8    // Shift X4 -> X8
	VPMOVZXBD X8, Y5      // Low 8 bytes from X8 -> Y5
	VCVTDQ2PS Y4, Y4
	VCVTDQ2PS Y5, Y5
	VSUBPS Y13, Y4, Y4
	VSUBPS Y13, Y5, Y5
	VMULPS Y10, Y4, Y4
	VMULPS Y0, Y4, Y4
	VMULPS Y10, Y5, Y5
	VMULPS Y0, Y5, Y5

	// Store to out + R9*4
	LEAQ (R8)(R9*4), R12
	VMOVUPS Y4, (R12)
	VMOVUPS Y5, 32(R12)

	// --- Q2 ---
	// (ql_c2 & 0xF) | (((qh >> 2) & 3) << 4)
	VPAND X15, X3, X4
	VPSRLW $2, X1, X5
	VPAND X14, X5, X5
	VPSLLW $4, X5, X5
	VPAND X6, X5, X5
	VPOR X5, X4, X4

	// Scale s[2] or s[3] -> offset R9>>4 + 2
	MOVBQSX 2(DX)(R11*1), BX
	VCVTSI2SSQ BX, X10, X10
	VBROADCASTSS X10, Y10
	VMOVDQA X4, X7
	VPMOVZXBD X7, Y4
	VPSRLDQ $8, X4, X8
	VPMOVZXBD X8, Y5
	VCVTDQ2PS Y4, Y4
	VCVTDQ2PS Y5, Y5
	VSUBPS Y13, Y4, Y4
	VSUBPS Y13, Y5, Y5
	VMULPS Y10, Y4, Y4
	VMULPS Y0, Y4, Y4
	VMULPS Y10, Y5, Y5
	VMULPS Y0, Y5, Y5

	// Store to out + 128 + R9*4
	LEAQ 128(R8)(R9*4), R12
	VMOVUPS Y4, (R12)
	VMOVUPS Y5, 32(R12)

	// --- Q3 ---
	// (ql_c1 >> 4) | (((qh >> 4) & 3) << 4)
	VPSRLW $4, X2, X4
	VPAND X15, X4, X4
	VPSRLW $4, X1, X5
	VPAND X14, X5, X5
	VPSLLW $4, X5, X5
	VPAND X6, X5, X5
	VPOR X5, X4, X4

	// Scale s[4] or s[5]
	MOVBQSX 4(DX)(R11*1), BX
	VCVTSI2SSQ BX, X10, X10
	VBROADCASTSS X10, Y10
	VMOVDQA X4, X7
	VPMOVZXBD X7, Y4
	VPSRLDQ $8, X4, X8
	VPMOVZXBD X8, Y5
	VCVTDQ2PS Y4, Y4
	VCVTDQ2PS Y5, Y5
	VSUBPS Y13, Y4, Y4
	VSUBPS Y13, Y5, Y5
	VMULPS Y10, Y4, Y4
	VMULPS Y0, Y4, Y4
	VMULPS Y10, Y5, Y5
	VMULPS Y0, Y5, Y5
	LEAQ 256(R8)(R9*4), R12
	VMOVUPS Y4, (R12)
	VMOVUPS Y5, 32(R12)

	// --- Q4 ---
	// (ql_c2 >> 4) | (((qh >> 6) & 3) << 4)
	VPSRLW $4, X3, X4
	VPAND X15, X4, X4
	VPSRLW $6, X1, X5
	VPAND X14, X5, X5
	VPSLLW $4, X5, X5
	VPAND X6, X5, X5
	VPOR X5, X4, X4

	// Scale s[6] or s[7]
	MOVBQSX 6(DX)(R11*1), BX
	VCVTSI2SSQ BX, X10, X10
	VBROADCASTSS X10, Y10
	VMOVDQA X4, X7
	VPMOVZXBD X7, Y4
	VPSRLDQ $8, X4, X8
	VPMOVZXBD X8, Y5
	VCVTDQ2PS Y4, Y4
	VCVTDQ2PS Y5, Y5
	VSUBPS Y13, Y4, Y4
	VSUBPS Y13, Y5, Y5
	VMULPS Y10, Y4, Y4
	VMULPS Y0, Y4, Y4
	VMULPS Y10, Y5, Y5
	VMULPS Y0, Y5, Y5
	LEAQ 384(R8)(R9*4), R12
	VMOVUPS Y4, (R12)
	VMOVUPS Y5, 32(R12)

	ADDQ $16, R9
	JMP loop_q6k

done_q6k:
	VZEROUPPER
	RET
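
// A scalar Go sketch of the Q6_K reconstruction above (hypothetical name):
// each 6-bit quant combines a ql nibble with 2 bits from qh, biased by 32.
//
//	func dequantQ6KRef(ql, qh []byte, scales []int8, out []float32, d float32) {
//		for i := 0; i < 32; i++ {
//			s := i / 16
//			q1 := int32((ql[i]&0x0F)|((qh[i]&3)<<4)) - 32
//			q2 := int32((ql[i+32]&0x0F)|(((qh[i]>>2)&3)<<4)) - 32
//			q3 := int32((ql[i]>>4)|(((qh[i]>>4)&3)<<4)) - 32
//			q4 := int32((ql[i+32]>>4)|(((qh[i]>>6)&3)<<4)) - 32
//			out[i] = d * float32(scales[s]) * float32(q1)
//			out[i+32] = d * float32(scales[s+2]) * float32(q2)
//			out[i+64] = d * float32(scales[s+4]) * float32(q3)
//			out[i+96] = d * float32(scales[s+6]) * float32(q4)
//		}
//	}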
// ============================================================================
// Q6_K Fused Dot Inner Loop - AVX2
// func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32
// Processes 128 elements (one half of block), returns the partial dot sum.
// scales: 8 precomputed float32 values (d*scale[0..7])
// ============================================================================
TEXT ·dotQ6KInnerAVX2(SB), NOSPLIT, $0-40
	MOVQ ql+0(FP), DI      // QL pointer (64 bytes)
	MOVQ qh+8(FP), SI      // QH pointer (32 bytes)
	MOVQ scales+16(FP), DX // Precomputed scales (8 floats)
	MOVQ x+24(FP), R8      // X pointer (128 floats)

	// Y11 = 0x0F as dwords (for masking after VPMOVZXBD)
	MOVL $0x0F, AX
	MOVD AX, X11
	VPBROADCASTD X11, Y11

	// Y10 = 0x03 as dwords
	MOVL $0x03, AX
	MOVD AX, X10
	VPBROADCASTD X10, Y10

	// Y9 = 32.0 (float bias)
	MOVL $0x42000000, AX
	MOVD AX, X9
	VBROADCASTSS X9, Y9

	// Y8 = accumulator for dot product
	VXORPS Y8, Y8, Y8

	// Process 32 elements per iteration (8 per quadrant), 4 iterations total.
	// Each iteration: load 8 QL bytes, 8 QH bytes, compute Q1-Q4 for 8 elements each.
	MOVQ $0, R9

dotq6k_loop:
	CMPQ R9, $32
	JGE dotq6k_done

	// R11 = R9 >> 4 (0 or 1, for scale indexing)
	MOVQ R9, R11
	SHRQ $4, R11

	// Load 8 bytes of QL (for Q1/Q3), 8 bytes of QL+32 (for Q2/Q4), 8 bytes of QH
	VPMOVZXBD (DI)(R9*1), Y0   // QL[R9..R9+7] -> 8 dwords
	VPMOVZXBD 32(DI)(R9*1), Y1 // QL[32+R9..32+R9+7] -> 8 dwords (for Q2/Q4)
	VPMOVZXBD (SI)(R9*1), Y2   // QH[R9..R9+7] -> 8 dwords

	// --- Q1: (ql & 0xF) | ((qh & 3) << 4) ---
	VPAND Y11, Y0, Y3 // ql & 0x0F
	VPAND Y10, Y2, Y4 // qh & 0x03
	VPSLLD $4, Y4, Y4 // << 4
	VPOR Y4, Y3, Y3   // combine
	VCVTDQ2PS Y3, Y3
	VSUBPS Y9, Y3, Y3            // q - 32
	VBROADCASTSS (DX)(R11*4), Y4 // scale s[0] or s[1]
	VMULPS Y4, Y3, Y3            // * scale
	LEAQ (R8)(R9*4), R12
	VMOVUPS (R12), Y5      // x[R9..R9+7]
	VFMADD231PS Y3, Y5, Y8 // acc += q * x

	// --- Q2: (ql32 & 0xF) | (((qh >> 2) & 3) << 4) ---
	VPAND Y11, Y1, Y3 // ql32 & 0x0F
	VPSRLD $2, Y2, Y4 // qh >> 2
	VPAND Y10, Y4, Y4 // & 0x03
	VPSLLD $4, Y4, Y4 // << 4
	VPOR Y4, Y3, Y3   // combine
	VCVTDQ2PS Y3, Y3
	VSUBPS Y9, Y3, Y3             // q - 32
	VBROADCASTSS 8(DX)(R11*4), Y4 // scale s[2] or s[3]
	VMULPS Y4, Y3, Y3             // * scale
	VMOVUPS 128(R12), Y5          // x[32+R9..32+R9+7]
	VFMADD231PS Y3, Y5, Y8        // acc += q * x

	// --- Q3: (ql >> 4) | (((qh >> 4) & 3) << 4) ---
	VPSRLD $4, Y0, Y3 // ql >> 4
	VPAND Y11, Y3, Y3 // & 0x0F
	VPSRLD $4, Y2, Y4 // qh >> 4
	VPAND Y10, Y4, Y4 // & 0x03
	VPSLLD $4, Y4, Y4 // << 4
	VPOR Y4, Y3, Y3   // combine
	VCVTDQ2PS Y3, Y3
	VSUBPS Y9, Y3, Y3              // q - 32
	VBROADCASTSS 16(DX)(R11*4), Y4 // scale s[4] or s[5]
	VMULPS Y4, Y3, Y3              // * scale
	VMOVUPS 256(R12), Y5           // x[64+R9..64+R9+7]
	VFMADD231PS Y3, Y5, Y8         // acc += q * x

	// --- Q4: (ql32 >> 4) | (((qh >> 6) & 3) << 4) ---
	VPSRLD $4, Y1, Y3 // ql32 >> 4
	VPAND Y11, Y3, Y3 // & 0x0F
	VPSRLD $6, Y2, Y4 // qh >> 6
	VPAND Y10, Y4, Y4 // & 0x03
	VPSLLD $4, Y4, Y4 // << 4
	VPOR Y4, Y3, Y3   // combine
	VCVTDQ2PS Y3, Y3
	VSUBPS Y9, Y3, Y3              // q - 32
	VBROADCASTSS 24(DX)(R11*4), Y4 // scale s[6] or s[7]
	VMULPS Y4, Y3, Y3              // * scale
	VMOVUPS 384(R12), Y5           // x[96+R9..96+R9+7]
	VFMADD231PS Y3, Y5, Y8         // acc += q * x

	ADDQ $8, R9
	JMP dotq6k_loop

dotq6k_done:
	// Horizontal sum of Y8
	VEXTRACTF128 $1, Y8, X0
	VADDPS X8, X0, X0
	VHADDPS X0, X0, X0
	VHADDPS X0, X0, X0
	VMOVSS X0, ret+32(FP)
	VZEROUPPER
	RET
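
// A scalar Go sketch of the fused Q6_K dot above (hypothetical name); scales
// is assumed to already hold the 8 precomputed d*scale values:
//
//	func dotQ6KRef(ql, qh []byte, scales, x []float32) float32 {
//		var sum float32
//		for i := 0; i < 32; i++ {
//			s := i / 16
//			q1 := float32((ql[i]&0x0F)|((qh[i]&3)<<4)) - 32
//			q2 := float32((ql[i+32]&0x0F)|(((qh[i]>>2)&3)<<4)) - 32
//			q3 := float32((ql[i]>>4)|(((qh[i]>>4)&3)<<4)) - 32
//			q4 := float32((ql[i+32]>>4)|(((qh[i]>>6)&3)<<4)) - 32
//			sum += x[i]*scales[s]*q1 + x[i+32]*scales[s+2]*q2 +
//				x[i+64]*scales[s+4]*q3 + x[i+96]*scales[s+6]*q4
//		}
//		return sum
//	}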
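
// On the Go side these routines would be declared in a companion file; a
// minimal sketch (the actual declarations and the BlockQ8_K type must exist
// elsewhere in this package for the TEXT symbols above to link):
//
//	//go:noescape
//	func dequantQ8KAVX2(b *BlockQ8_K, out *float32)
//
//	//go:noescape
//	func dotQ8KAVX2(b *BlockQ8_K, x *float32) float32
//
//	//go:noescape
//	func dotQ6KInnerAVX2(ql *byte, qh *byte, scales *float32, x *float32) float32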