Răsfoiți Sursa

ggml : quantization refactoring (#3833)

* ggml : factor all quantization code in ggml-quants

ggml-ci

* ggml-quants : fix Zig and Swift builds + quantize tool

ggml-ci

* quantize : --pure option for disabling k-quant mixtures

---------

Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Georgi Gerganov 2 ani în urmă
părinte
comite
d69d777c02
11 a modificat fișierele cu 4073 adăugiri și 4084 ștergeri
  1. 4 8
      CMakeLists.txt
  2. 6 12
      Makefile
  3. 1 2
      Package.swift
  4. 8 13
      build.zig
  5. 5 4
      examples/quantize/quantize.cpp
  6. 3269 1041
      ggml-quants.c
  7. 86 17
      ggml-quants.h
  8. 672 2967
      ggml.c
  9. 7 0
      ggml.h
  10. 14 20
      llama.cpp
  11. 1 0
      llama.h

+ 4 - 8
CMakeLists.txt

@@ -94,7 +94,6 @@ option(LLAMA_CLBLAST                         "llama: use CLBlast"
 option(LLAMA_METAL                           "llama: use Metal"                                 ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"                   OFF)
 option(LLAMA_MPI                             "llama: use MPI"                                   OFF)
-option(LLAMA_K_QUANTS                        "llama: use k-quants"                              ON)
 option(LLAMA_QKK_64                          "llama: use super-block size of 64 for k-quants"   OFF)
 
 option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE})
@@ -278,13 +277,8 @@ if (LLAMA_BLAS)
     endif()
 endif()
 
-if (LLAMA_K_QUANTS)
-    set(GGML_HEADERS_EXTRA k_quants.h)
-    set(GGML_SOURCES_EXTRA k_quants.c)
-    add_compile_definitions(GGML_USE_K_QUANTS)
-    if (LLAMA_QKK_64)
-        add_compile_definitions(GGML_QKK_64)
-    endif()
+if (LLAMA_QKK_64)
+    add_compile_definitions(GGML_QKK_64)
 endif()
 
 if (LLAMA_CUBLAS)
@@ -673,6 +667,8 @@ add_library(ggml OBJECT
             ggml-alloc.h
             ggml-backend.c
             ggml-backend.h
+            ggml-quants.c
+            ggml-quants.h
             ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
             ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}

+ 6 - 12
Makefile

@@ -342,13 +342,9 @@ else
 	MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
 endif
 
-ifndef LLAMA_NO_K_QUANTS
-	MK_CPPFLAGS += -DGGML_USE_K_QUANTS
-	OBJS     += k_quants.o
 ifdef LLAMA_QKK_64
 	MK_CPPFLAGS += -DGGML_QKK_64
 endif
-endif
 
 ifndef LLAMA_NO_ACCELERATE
 	# Mac OS - include Accelerate framework.
@@ -365,7 +361,7 @@ ifdef LLAMA_MPI
 	MK_CPPFLAGS += -DGGML_USE_MPI
 	MK_CFLAGS   += -Wno-cast-qual
 	MK_CXXFLAGS += -Wno-cast-qual
-	OBJS     += ggml-mpi.o
+	OBJS        += ggml-mpi.o
 endif # LLAMA_MPI
 
 ifdef LLAMA_OPENBLAS
@@ -382,7 +378,7 @@ endif # LLAMA_BLIS
 ifdef LLAMA_CUBLAS
 	MK_CPPFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
 	MK_LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
-	OBJS      += ggml-cuda.o
+	OBJS         += ggml-cuda.o
 	NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
 ifdef LLAMA_CUDA_NVCC
 	NVCC = $(LLAMA_CUDA_NVCC)
@@ -497,11 +493,6 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
 	$(CC) $(CFLAGS) -c $< -o $@
 endif # LLAMA_MPI
 
-ifndef LLAMA_NO_K_QUANTS
-k_quants.o: k_quants.c k_quants.h
-	$(CC) $(CFLAGS) -c $< -o $@
-endif # LLAMA_NO_K_QUANTS
-
 # combine build flags with cmdline overrides
 override CFLAGS        := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
 override CXXFLAGS      := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
@@ -542,7 +533,10 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
 ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 	$(CC)  $(CFLAGS)   -c $< -o $@
 
-OBJS += ggml-alloc.o ggml-backend.o
+ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
+	$(CC) $(CFLAGS)    -c $< -o $@
+
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o
 
 llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+ 1 - 2
Package.swift

@@ -42,13 +42,12 @@ let package = Package(
                 "llama.cpp",
                 "ggml-alloc.c",
                 "ggml-backend.c",
-                "k_quants.c",
+                "ggml-quants.c",
             ] + additionalSources,
             resources: resources,
             publicHeadersPath: "spm-headers",
             cSettings: [
                 .unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
-                .define("GGML_USE_K_QUANTS"),
                 .define("GGML_USE_ACCELERATE")
                 // NOTE: NEW_LAPACK will required iOS version 16.4+
                 // We should consider add this in the future when we drop support for iOS 14

+ 8 - 13
build.zig

@@ -116,15 +116,10 @@ pub fn build(b: *std.build.Builder) !void {
     var make = try Maker.init(b);
     make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
 
-    if (b.option(bool, "k-quants", "Enable K-quants, (default: true)") orelse true) {
-        try make.addFlag("-DGGML_USE_K_QUANTS");
-        const k_quants = make.obj("k_quants", "k_quants.c");
-        try make.objs.append(k_quants);
-    }
-
     const ggml = make.obj("ggml", "ggml.c");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
+    const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
     const llama = make.obj("llama", "llama.cpp");
     const common = make.obj("common", "common/common.cpp");
     const console = make.obj("console", "common/console.cpp");
@@ -133,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
     const train = make.obj("train", "common/train.cpp");
     const clip = make.obj("clip", "examples/llava/clip.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });
 
-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser, clip });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, grammar_parser, clip });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }

+ 5 - 4
examples/quantize/quantize.cpp

@@ -18,7 +18,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q4_1",   LLAMA_FTYPE_MOSTLY_Q4_1,   " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
     { "Q5_0",   LLAMA_FTYPE_MOSTLY_Q5_0,   " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
     { "Q5_1",   LLAMA_FTYPE_MOSTLY_Q5_1,   " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
-#ifdef GGML_USE_K_QUANTS
     { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q3_K",   LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
     { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
@@ -31,7 +30,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
     { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
     { "Q6_K",   LLAMA_FTYPE_MOSTLY_Q6_K,   " 5.15G, -0.0008 ppl @ LLaMA-v1-7B", },
-#endif
     { "Q8_0",   LLAMA_FTYPE_MOSTLY_Q8_0,   " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
     { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "13.00G              @ 7B", },
     { "F32",    LLAMA_FTYPE_ALL_F32,       "26.00G              @ 7B", },
@@ -70,13 +68,14 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 }
 
 // usage:
-//  ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
+//  ./quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
     printf("\nAllowed quantization types:\n");
     for (auto & it : QUANT_OPTIONS) {
         if (it.name != "COPY") {
@@ -103,6 +102,8 @@ int main(int argc, char ** argv) {
             params.quantize_output_tensor = false;
         } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
             params.allow_requantize = true;
+        } else if (strcmp(argv[arg_idx], "--pure") == 0) {
+            params.pure = true;
         } else {
             usage(argv[0]);
         }

+ 3269 - 1041
k_quants.c → ggml-quants.c

@@ -1,9 +1,10 @@
-#include "k_quants.h"
+#include "ggml-quants.h"
 #include "ggml.h"
 
 #include <math.h>
 #include <string.h>
 #include <assert.h>
+#include <float.h>
 
 #ifdef __ARM_NEON
 
@@ -65,1251 +66,3478 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
-//
-// 2-6 bit quantization in super-blocks
-//
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
+}
 
-//
-// ===================== Helper functions
-//
-static inline int nearest_int(float fval) {
-    assert(fval <= 4194303.f);
-    float val = fval + 12582912.f;
-    int i; memcpy(&i, &val, sizeof(int));
-    return (i & 0x007fffff) - 0x00400000;
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
 }
 
-static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
-    float max = 0;
-    float amax = 0;
-    for (int i = 0; i < n; ++i) {
-        float ax = fabsf(x[i]);
-        if (ax > amax) { amax = ax; max = x[i]; }
-    }
-    if (amax < 1e-30f) { // all zero
-        for (int i = 0; i < n; ++i) {
-            L[i] = 0;
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON)
+
+#if !defined(__aarch64__)
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+#endif
+#endif
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
         }
-        return 0.f;
-    }
-    float iscale = -nmax / max;
-    if (rmse_type == 0) {
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale * x[i]);
-            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+
+        const float d  = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
         }
-        return 1/iscale;
-    }
-    bool return_early = false;
-    if (rmse_type < 0) {
-        rmse_type = -rmse_type;
-        return_early = true;
-    }
-    int weight_type = rmse_type%2;
-    float sumlx = 0;
-    float suml2 = 0;
-    for (int i = 0; i < n; ++i) {
-        int l = nearest_int(iscale * x[i]);
-        l = MAX(-nmax, MIN(nmax-1, l));
-        L[i] = l + nmax;
-        float w = weight_type == 1 ? x[i] * x[i] : 1;
-        sumlx += w*x[i]*l;
-        suml2 += w*l*l;
     }
-    float scale = sumlx/suml2;
-    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
-    float best = scale * sumlx;
-    for (int is = -9; is <= 9; ++is) {
-        if (is == 0) {
-            continue;
+}
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_0_reference(x, y, k);
+}
+
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
         }
-        iscale = -(nmax + 0.1f*is) / max;
-        sumlx = suml2 = 0;
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale * x[i]);
-            l = MAX(-nmax, MIN(nmax-1, l));
-            float w = weight_type == 1 ? x[i] * x[i] : 1;
-            sumlx += w*x[i]*l;
-            suml2 += w*l*l;
+
+        const float d  = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+        y[i].m = ggml_fp32_to_fp16(min);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
         }
-        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
-            for (int i = 0; i < n; ++i) {
-                int l = nearest_int(iscale * x[i]);
-                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+    }
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_1_reference(x, y, k);
+}
+
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
             }
-            scale = sumlx/suml2; best = scale*sumlx;
         }
+
+        const float d  = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(qh));
     }
-    return scale;
 }
 
-static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
-    float max = 0;
-    float amax = 0;
-    for (int i = 0; i < n; ++i) {
-        float ax = fabsf(x[i]);
-        if (ax > amax) { amax = ax; max = x[i]; }
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_0_reference(x, y, k);
+}
+
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+        y[i].m = ggml_fp32_to_fp16(min);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
     }
-    if (!amax) { // all zero
-        for (int i = 0; i < n; ++i) { L[i] = 0; }
-        return 0.f;
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_1_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = x[i*QK8_0 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = x[i*QK8_0 + j]*id;
+
+            y[i].qs[j] = roundf(x0);
+        }
     }
-    float iscale = -nmax / max;
-    if (do_rmse) {
-        float sumlx = 0;
-        float suml2 = 0;
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale * x[i]);
-            l = MAX(-nmax, MIN(nmax-1, l));
-            L[i] = l;
-            float w = x[i]*x[i];
-            sumlx += w*x[i]*l;
-            suml2 += w*l*l;
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
         }
-        for (int itry = 0; itry < 5; ++itry) {
-            int n_changed = 0;
-            for (int i = 0; i < n; ++i) {
-                float w = x[i]*x[i];
-                float slx = sumlx - w*x[i]*L[i];
-                if (slx > 0) {
-                    float sl2 = suml2 - w*L[i]*L[i];
-                    int new_l = nearest_int(x[i] * sl2 / slx);
-                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
-                    if (new_l != L[i]) {
-                        slx += w*x[i]*new_l;
-                        sl2 += w*new_l*new_l;
-                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
-                            L[i] = new_l; sumlx = slx; suml2 = sl2;
-                            ++n_changed;
-                        }
-                    }
-                }
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = ggml_fp32_to_fp16(d);
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = ggml_fp32_to_fp16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
+#else
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(QK8_1 == 32);
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_1; j++) {
+            const float v = x[i*QK8_1 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int sum = 0;
+
+        for (int j = 0; j < QK8_1/2; ++j) {
+            const float v0 = x[i*QK8_1           + j]*id;
+            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+            y[i].qs[          j] = roundf(v0);
+            y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+            sum += y[i].qs[          j];
+            sum += y[i].qs[QK8_1/2 + j];
+        }
+
+        y[i].s = sum*d;
+    }
+}
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int32x4_t accv = vdupq_n_s32(0);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
+        }
+
+        y[i].s = d * vaddvq_s32(accv);
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        v128_t accv = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+            accv = wasm_i32x4_add(accv, vi);
+        }
+
+        y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+                      wasm_i32x4_extract_lane(accv, 1) +
+                      wasm_i32x4_extract_lane(accv, 2) +
+                      wasm_i32x4_extract_lane(accv, 3));
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
+#else
+    // scalar
+    quantize_row_q8_1_reference(x, y, k);
+#endif
+}
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = ggml_fp16_to_fp32(x[i].d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float m = ggml_fp16_to_fp32(x[i].m);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = ggml_fp16_to_fp32(x[i].d);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float m = ggml_fp16_to_fp32(x[i].m);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+            const int x1 = (x[i].qs[j] >>   4) | xh_1;
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK8_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = ggml_fp16_to_fp32(x[i].d);
+
+        for (int j = 0; j < qk; ++j) {
+            y[i*qk + j] = x[i].qs[j]*d;
+        }
+    }
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < 1e-30f) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    int weight_type = rmse_type%2;
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = weight_type == 1 ? x[i] * x[i] : 1;
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = sumlx/suml2;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = weight_type == 1 ? x[i] * x[i] : 1;
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (!amax) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+#if QK_K == 256
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+#endif
+
+//========================- 2-bit (de)-quantization
+
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float   weights[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+
+    const float q4scale = 15.f;
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        if (max_scale > 0) {
+            float iscale = q4scale/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*scales[j]);
+                y[i].scales[j] = l;
+            }
+            y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
+        } else {
+            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+            y[i].d = ggml_fp32_to_fp16(0.f);
+        }
+        if (max_min > 0) {
+            float iscale = q4scale/max_min;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*mins[j]);
+                y[i].scales[j] |= (l << 4);
+            }
+            y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
+        } else {
+            y[i].dmin = ggml_fp32_to_fp16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
+            if (!d) continue;
+            const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int((x[16*j + ii] + dm)/d);
+                l = MAX(0, MIN(3, l));
+                L[16*j + ii] = l;
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q2_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
+        quantize_row_q2_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q2_K));
+}
+
+//========================= 3-bit (de)-quantization
+
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
+            float scale = fabsf(scales[j]);
+            if (scale > amax) {
+                amax = scale; max_scale = scales[j];
+            }
+        }
+
+#if QK_K == 256
+        memset(y[i].scales, 0, 12);
+        if (max_scale) {
+            float iscale = -32.f/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int8_t l = nearest_int(iscale*scales[j]);
+                l = MAX(-32, MIN(31, l)) + 32;
+                if (j < 8) {
+                    y[i].scales[j] = l & 0xF;
+                } else {
+                    y[i].scales[j-8] |= ((l & 0xF) << 4);
+                }
+                l >>= 4;
+                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+            }
+            y[i].d = ggml_fp32_to_fp16(1/iscale);
+        } else {
+            y[i].d = ggml_fp32_to_fp16(0.f);
+        }
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = ggml_fp16_to_fp32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#else
+        if (max_scale) {
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
+            }
+            y[i].d = ggml_fp32_to_fp16(1/iscale);
+        } else {
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
+            }
+            y[i].d = ggml_fp32_to_fp16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#endif
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+    }
+}
+
+#if QK_K == 256
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[4];
+    const int8_t * scales = (const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        uint8_t m = 1;
+
+        memcpy(aux, x[i].scales, 12);
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+}
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    assert(QK_K == 64);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l=0; l<8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+}
+#endif
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q3_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
+        quantize_row_q3_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q3_K));
+}
+
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float   weights[32];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+#if QK_K == 256
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
+        y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = ggml_fp16_to_fp32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+#else
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
+            if (!d) continue;
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + m)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
+            }
+        }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
+#endif
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+
+        const float d   = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+#else
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
+
+    }
+}
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
+        quantize_row_q4_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q4_K));
+}
+
+// ====================== 5-bit (de)-quantization
+
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+#if QK_K == 256
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+#endif
+
+    for (int i = 0; i < nb; i++) {
+
+#if QK_K == 256
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
             }
-            if (!n_changed) {
-                break;
+        }
+        y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
+        y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = ggml_fp16_to_fp32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+#else
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
+        }
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+            if (!d) continue;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        for (int j = 0; j < 32; ++j) {
+            int jm = j%8;
+            int is = j/8;
+            int l1 = L[j];
+            if (l1 > 15) {
+                l1 -= 16; qh[jm] |= (1 << is);
+            }
+            int l2 = L[j + 32];
+            if (l2 > 15) {
+                l2 -= 16; qh[jm] |= (1 << (4 + is));
+            }
+            ql[j] = l1 | (l2 << 4);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * ql = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+#if QK_K == 256
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+#else
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    quantize_row_q5_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
+        quantize_row_q5_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q5_K));
+}
+
+// ====================== 6-bit (de)-quantization
+
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float   scales[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (!max_abs_scale) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = ggml_fp32_to_fp16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+#else
+        for (int l = 0; l < 32; ++l) {
+            const uint8_t q1 = L[l +  0] & 0xF;
+            const uint8_t q2 = L[l + 32] & 0xF;
+            ql[l] = q1 | (q2 << 4);
+        }
+        for (int l = 0; l < 16; ++l) {
+            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict ql = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict sc = x[i].scales;
+
+#if QK_K == 256
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
             }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
         }
-        for (int i = 0; i < n; ++i) {
-            L[i] += nmax;
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
         }
-        return sumlx / suml2;
-    }
-    for (int i = 0; i < n; ++i) {
-        int l = nearest_int(iscale * x[i]);
-        l = MAX(-nmax, MIN(nmax-1, l));
-        L[i] = l + nmax;
+        y  += 64;
+#endif
+
     }
-    return 1/iscale;
 }
 
-static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
-        int ntry, float alpha) {
-    float min = x[0];
-    float max = x[0];
-    for (int i = 1; i < n; ++i) {
-        if (x[i] < min) min = x[i];
-        if (x[i] > max) max = x[i];
-    }
-    if (max == min) {
-        for (int i = 0; i < n; ++i) L[i] = 0;
-        *the_min = 0;
-        return 0.f;
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q6_K * restrict y = vy;
+    quantize_row_q6_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
+        quantize_row_q6_K_reference(src + j, y, k);
     }
-    if (min > 0) min = 0;
-    float iscale = nmax/(max - min);
-    float scale = 1/iscale;
-    for (int itry = 0; itry < ntry; ++itry) {
-        float sumlx = 0; int suml2 = 0;
-        bool did_change = false;
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale*(x[i] - min));
-            l = MAX(0, MIN(nmax, l));
-            if (l != L[i]) {
-                L[i] = l;
-                did_change = true;
+    return (n/QK_K*sizeof(block_q6_K));
+}
+
+//===================================== Q8_K ==============================================
+
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        float max = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K; ++j) {
+            float ax = fabsf(x[j]);
+            if (ax > amax) {
+                amax = ax; max = x[j];
             }
-            sumlx += (x[i] - min)*l;
-            suml2 += l*l;
         }
-        scale = sumlx/suml2;
-        float sum = 0;
-        for (int i = 0; i < n; ++i) {
-            sum += x[i] - scale*L[i];
+        if (!amax) {
+            y[i].d = 0;
+            memset(y[i].qs, 0, QK_K);
+            x += QK_K;
+            continue;
         }
-        min = alpha*min + (1 - alpha)*sum/n;
-        if (min > 0) min = 0;
-        iscale = 1/scale;
-        if (!did_change) break;
-    }
-    *the_min = -min;
-    return scale;
-}
-
-static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
-        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
-        float rmin, float rdelta, int nstep, bool use_mad) {
-    float min = x[0];
-    float max = x[0];
-    float sum_w = weights[0];
-    float sum_x = sum_w * x[0];
-    for (int i = 1; i < n; ++i) {
-        if (x[i] < min) min = x[i];
-        if (x[i] > max) max = x[i];
-        float w = weights[i];
-        sum_w += w;
-        sum_x += w * x[i];
-    }
-    if (min > 0) min = 0;
-    if (max == min) {
-        for (int i = 0; i < n; ++i) L[i] = 0;
-        *the_min = -min;
-        return 0.f;
-    }
-    float iscale = nmax/(max - min);
-    float scale = 1/iscale;
-    float best_mad = 0;
-    for (int i = 0; i < n; ++i) {
-        int l = nearest_int(iscale*(x[i] - min));
-        L[i] = MAX(0, MIN(nmax, l));
-        float diff = scale * L[i] + min - x[i];
-        diff = use_mad ? fabsf(diff) : diff * diff;
-        float w = weights[i];
-        best_mad += w * diff;
-    }
-    if (nstep < 1) {
-        *the_min = -min;
-        return scale;
-    }
-    for (int is = 0; is <= nstep; ++is) {
-        iscale = (rmin + rdelta*is + nmax)/(max - min);
-        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale*(x[i] - min));
-            l = MAX(0, MIN(nmax, l));
-            Laux[i] = l;
-            float w = weights[i];
-            sum_l += w*l;
-            sum_l2 += w*l*l;
-            sum_xl += w*l*x[i];
+        const float iscale = -128.f/max;
+        for (int j = 0; j < QK_K; ++j) {
+            int v = nearest_int(iscale*x[j]);
+            y[i].qs[j] = MIN(127, v);
         }
-        float D = sum_w * sum_l2 - sum_l * sum_l;
-        if (D > 0) {
-            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
-            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
-            if (this_min > 0) {
-                this_min = 0;
-                this_scale = sum_xl / sum_l2;
-            }
-            float mad = 0;
-            for (int i = 0; i < n; ++i) {
-                float diff = this_scale * Laux[i] + this_min - x[i];
-                diff = use_mad ? fabsf(diff) : diff * diff;
-                float w = weights[i];
-                mad += w * diff;
-            }
-            if (mad < best_mad) {
-                for (int i = 0; i < n; ++i) {
-                    L[i] = Laux[i];
-                }
-                best_mad = mad;
-                scale = this_scale;
-                min = this_min;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int sum = 0;
+            for (int ii = 0; ii < 16; ++ii) {
+                sum += y[i].qs[j*16 + ii];
             }
+            y[i].bsums[j] = sum;
         }
+        y[i].d = 1/iscale;
+        x += QK_K;
     }
-    *the_min = -min;
-    return scale;
 }
 
-#if QK_K == 256
-static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
-    if (j < 4) {
-        *d = q[j] & 63; *m = q[j + 4] & 63;
-    } else {
-        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < QK_K; ++j) {
+            *y++ = x[i].d * x[i].qs[j];
+        }
     }
 }
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q8_K_reference(x, y, k);
+}
+
+//===================================== Dot ptoducts =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
 #endif
 
-//========================- 2-bit (de)-quantization
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
 
-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    uint8_t L[QK_K];
-    uint8_t Laux[16];
-    float   weights[16];
-    float mins[QK_K/16];
-    float scales[QK_K/16];
+    assert(nb % 2 == 0); // TODO: handle odd nb
 
-    const float q4scale = 15.f;
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-    for (int i = 0; i < nb; i++) {
-        float max_scale = 0; // as we are deducting the min, scales are always positive
-        float max_min = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
-            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
-            float scale = scales[j];
-            if (scale > max_scale) {
-                max_scale = scale;
-            }
-            float min = mins[j];
-            if (min > max_min) {
-                max_min = min;
-            }
-        }
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-        if (max_scale > 0) {
-            float iscale = q4scale/max_scale;
-            for (int j = 0; j < QK_K/16; ++j) {
-                int l = nearest_int(iscale*scales[j]);
-                y[i].scales[j] = l;
-            }
-            y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
-        } else {
-            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
-            y[i].d = ggml_fp32_to_fp16(0.f);
-        }
-        if (max_min > 0) {
-            float iscale = q4scale/max_min;
-            for (int j = 0; j < QK_K/16; ++j) {
-                int l = nearest_int(iscale*mins[j]);
-                y[i].scales[j] |= (l << 4);
-            }
-            y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
-        } else {
-            y[i].dmin = ggml_fp32_to_fp16(0.f);
-        }
-        for (int j = 0; j < QK_K/16; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
-            if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
-            for (int ii = 0; ii < 16; ++ii) {
-                int l = nearest_int((x[16*j + ii] + dm)/d);
-                l = MAX(0, MIN(3, l));
-                L[16*j + ii] = l;
-            }
-        }
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
 
-#if QK_K == 256
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
-            }
-        }
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
 #else
-        for (int l = 0; l < 16; ++l) {
-            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
-        }
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
 #endif
+    }
 
-        x += QK_K;
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
     }
-}
 
-void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-    for (int i = 0; i < nb; i++) {
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
+        const __m128i lowMask = _mm_set1_epi8(0xF);
+        const __m128i off = _mm_set1_epi8(8);
 
-        const uint8_t * q = x[i].qs;
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
 
-#if QK_K == 256
-        int is = 0;
-        float dl, ml;
-        for (int n = 0; n < QK_K; n += 128) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
+        __m128i bx = _mm_and_si128(lowMask, tmp);
+        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
 
-                uint8_t sc = x[i].scales[is++];
-                dl = d * (sc & 0xF); ml = min * (sc >> 4);
-                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
-                sc = x[i].scales[is++];
-                dl = d * (sc & 0xF); ml = min * (sc >> 4);
-                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
-                shift += 2;
-            }
-            q += 32;
-        }
-#else
-        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
-        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
-        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
-        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
-        for (int l = 0; l < 16; ++l) {
-            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
-            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
-            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
-            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
-        }
-        y += QK_K;
-#endif
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }
-}
 
-void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
-    quantize_row_q2_K_reference(x, vy, k);
-}
+    *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( ggml_fp16_to_fp32(x[0].d) * ggml_fp16_to_fp32(y[0].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( ggml_fp16_to_fp32(x[1].d) * ggml_fp16_to_fp32(y[1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
 
-size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-    (void)hist; // TODO: collect histograms
+    assert(nb % 2 == 0); // TODO: handle odd nb
 
-    for (int j = 0; j < n; j += k) {
-        block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
-        quantize_row_q2_K_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_q2_K));
-}
+    // Main loop
+    for (int i = 2; i < nb; i+=2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
 
-//========================= 3-bit (de)-quantization
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
 
-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
 
-    int8_t L[QK_K];
-    float scales[QK_K / 16];
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( ggml_fp16_to_fp32(x[i + 1].d) * ggml_fp16_to_fp32(y[i + 1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Acummulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        float max_scale = 0;
-        float amax = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
-            float scale = fabsf(scales[j]);
-            if (scale > amax) {
-                amax = scale; max_scale = scales[j];
-            }
-        }
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-#if QK_K == 256
-        memset(y[i].scales, 0, 12);
-        if (max_scale) {
-            float iscale = -32.f/max_scale;
-            for (int j = 0; j < QK_K/16; ++j) {
-                int8_t l = nearest_int(iscale*scales[j]);
-                l = MAX(-32, MIN(31, l)) + 32;
-                if (j < 8) {
-                    y[i].scales[j] = l & 0xF;
-                } else {
-                    y[i].scales[j-8] |= ((l & 0xF) << 4);
-                }
-                l >>= 4;
-                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
-            }
-            y[i].d = ggml_fp32_to_fp16(1/iscale);
-        } else {
-            y[i].d = ggml_fp32_to_fp16(0.f);
-        }
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        int8_t sc;
-        for (int j = 0; j < QK_K/16; ++j) {
-            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
-            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
-            float d = ggml_fp16_to_fp32(y[i].d) * sc;
-            if (!d) {
-                continue;
-            }
-            for (int ii = 0; ii < 16; ++ii) {
-                int l = nearest_int(x[16*j + ii]/d);
-                l = MAX(-4, MIN(3, l));
-                L[16*j + ii] = l + 4;
-            }
-        }
-#else
-        if (max_scale) {
-            float iscale = -8.f/max_scale;
-            for (int j = 0; j < QK_K/16; j+=2) {
-                int l1 = nearest_int(iscale*scales[j]);
-                l1 = 8 + MAX(-8, MIN(7, l1));
-                int l2 = nearest_int(iscale*scales[j+1]);
-                l2 = 8 + MAX(-8, MIN(7, l2));
-                y[i].scales[j/2] = l1 | (l2 << 4);
-            }
-            y[i].d = ggml_fp32_to_fp16(1/iscale);
-        } else {
-            for (int j = 0; j < QK_K/16; j+=2) {
-                y[i].scales[j/2] = 0;
-            }
-            y[i].d = ggml_fp32_to_fp16(0.f);
-        }
-        for (int j = 0; j < QK_K/16; ++j) {
-            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
-            float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
-            if (!d) {
-                continue;
-            }
-            for (int ii = 0; ii < 16; ++ii) {
-                int l = nearest_int(x[16*j + ii]/d);
-                l = MAX(-4, MIN(3, l));
-                L[16*j + ii] = l + 4;
-            }
-        }
-#endif
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        memset(y[i].hmask, 0, QK_K/8);
-        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
-        int m = 0;
-        uint8_t hm = 1;
-        for (int j = 0; j < QK_K; ++j) {
-            if (L[j] > 3) {
-                y[i].hmask[m] |= hm;
-                L[j] -= 4;
-            }
-            if (++m == QK_K/8) {
-                m = 0; hm <<= 1;
-            }
-        }
-#if QK_K == 256
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
-            }
-        }
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d);
+    }
+
+    *s = sumf;
 #else
-        for (int l = 0; l < 16; ++l) {
-            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F) - 8;
+            const int v1 = (x[i].qs[j] >>   4) - 8;
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
-#endif
 
-        x += QK_K;
+        sumf += sumi*ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d);
     }
+
+    *s = sumf;
+#endif
 }
 
-#if QK_K == 256
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
 
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
+    assert(n % qk == 0);
 
-    uint32_t aux[4];
-    const int8_t * scales = (const int8_t*)aux;
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
 
-    for (int i = 0; i < nb; i++) {
+    // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-        const float d_all = ggml_fp16_to_fp32(x[i].d);
+    float summs = 0;
 
-        const uint8_t * restrict q = x[i].qs;
-        const uint8_t * restrict hm = x[i].hmask;
-        uint8_t m = 1;
+    assert(nb % 2 == 0); // TODO: handle odd nb
 
-        memcpy(aux, x[i].scales, 12);
-        uint32_t tmp = aux[2];
-        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i + 0];
+        const block_q8_1 * restrict y1 = &y[i + 1];
 
-        int is = 0;
-        float dl;
-        for (int n = 0; n < QK_K; n += 128) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
+        summs += ggml_fp16_to_fp32(x0->m) * y0->s + ggml_fp16_to_fp32(x1->m) * y1->s;
 
-                dl = d_all * (scales[is++] - 32);
-                for (int l = 0; l < 16; ++l) {
-                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
-                }
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
-                dl = d_all * (scales[is++] - 32);
-                for (int l = 0; l < 16; ++l) {
-                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
-                }
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
 
-                shift += 2;
-                m <<= 1;
-            }
-            q += 32;
-        }
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), ggml_fp16_to_fp32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), ggml_fp16_to_fp32(x1->d)*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*y1->d);
+#endif
     }
-}
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float d0 = ggml_fp16_to_fp32(x[i].d);
+        const float d1 = y[i].d;
+
+        summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
+
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
+
+        // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
 #else
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    assert(QK_K == 64);
-    const int nb = k / QK_K;
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        const float d_all = ggml_fp16_to_fp32(x[i].d);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        const uint8_t * restrict q = x[i].qs;
-        const uint8_t * restrict hm = x[i].hmask;
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
-        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        for (int l=0; l<8; ++l) {
-            uint8_t h = hm[l];
-            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
-            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
-            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
-            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
-            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
-            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
-            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
-            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F);
+            const int v1 = (x[i].qs[j] >>   4);
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
-        y += QK_K;
+
+        sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
     }
+
+    *s = sumf;
+#endif
 }
+
+void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q5_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
 #endif
+    }
 
-void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
-    quantize_row_q3_K_reference(x, vy, k);
-}
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
 
-size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-    (void)hist; // TODO: collect histograms
+    uint32_t qh;
+    uint64_t tmp[4];
 
-    for (int j = 0; j < n; j += k) {
-        block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
-        quantize_row_q3_K_reference(src + j, y, k);
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(ggml_fp16_to_fp32(x0->d) * ggml_fp16_to_fp32(y0->d))));
     }
-    return (n/QK_K*sizeof(block_q3_K));
-}
 
-// ====================== 4-bit (de)-quantization
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
 
-    uint8_t L[QK_K];
-    uint8_t Laux[32];
-    float   weights[32];
-    float mins[QK_K/32];
-    float scales[QK_K/32];
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
 
+    // Main loop
     for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
 
-        float max_scale = 0; // as we are deducting the min, scales are always positive
-        float max_min = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
-            float sum_x2 = 0;
-            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
-            float av_x = sqrtf(sum_x2/32);
-            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
-            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
-            float scale = scales[j];
-            if (scale > max_scale) {
-                max_scale = scale;
-            }
-            float min = mins[j];
-            if (min > max_min) {
-                max_min = min;
-            }
-        }
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
 
-#if QK_K == 256
-        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
-        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
-        for (int j = 0; j < QK_K/32; ++j) {
-            uint8_t ls = nearest_int(inv_scale*scales[j]);
-            uint8_t lm = nearest_int(inv_min*mins[j]);
-            ls = MIN(63, ls);
-            lm = MIN(63, lm);
-            if (j < 4) {
-                y[i].scales[j] = ls;
-                y[i].scales[j+4] = lm;
-            } else {
-                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
-                y[i].scales[j-4] |= ((ls >> 4) << 6);
-                y[i].scales[j-0] |= ((lm >> 4) << 6);
-            }
-        }
-        y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
-        y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
+    uint32_t qh;
 
-        uint8_t sc, m;
-        for (int j = 0; j < QK_K/32; ++j) {
-            get_scale_min_k4(j, y[i].scales, &sc, &m);
-            const float d = ggml_fp16_to_fp32(y[i].d) * sc;
-            if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(15, l));
-                L[32*j + ii] = l;
-            }
-        }
-#else
-        const float s_factor = 15.f;
-        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
-        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
-        int d1 = nearest_int(inv_scale*scales[0]);
-        int m1 = nearest_int(inv_min*mins[0]);
-        int d2 = nearest_int(inv_scale*scales[1]);
-        int m2 = nearest_int(inv_min*mins[1]);
-        y[i].scales[0] = d1 | (m1 << 4);
-        y[i].scales[1] = d2 | (m2 << 4);
-        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
-        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
-        float sumlx = 0;
-        int   suml2 = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            const uint8_t sd = y[i].scales[j] & 0xF;
-            const uint8_t sm = y[i].scales[j] >>  4;
-            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
-            if (!d) continue;
-            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + m)/d);
-                l = MAX(0, MIN(15, l));
-                L[32*j + ii] = l;
-                sumlx += (x[32*j + ii] + m)*l*sd;
-                suml2 += l*l*sd*sd;
-            }
-        }
-        if (suml2) {
-            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
-        }
-#endif
-        uint8_t * q = y[i].qs;
-        for (int j = 0; j < QK_K; j += 64) {
-            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
-            q += 32;
-        }
+    // These tempory registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        x += QK_K;
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d)) * sumi;
     }
-}
 
-void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
 
     for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
 
-        const uint8_t * q = x[i].qs;
+        int sumi = 0;
 
-#if QK_K == 256
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
 
-        const float d   = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
 
-        int is = 0;
-        uint8_t sc, m;
-        for (int j = 0; j < QK_K; j += 64) {
-            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
-            const float d1 = d * sc; const float m1 = min * m;
-            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
-            const float d2 = d * sc; const float m2 = min * m;
-            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
-            q += 32; is += 2;
-        }
-#else
-        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
-        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
-        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
-        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
-        for (int l = 0; l < 32; ++l) {
-            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
-            y[l+32] = d2 * (q[l] >>  4) - m2;
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
-        y += QK_K;
-#endif
 
+        sumf += (ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d)) * sumi;
     }
-}
 
-void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_q4_K * restrict y = vy;
-    quantize_row_q4_K_reference(x, y, k);
+    *s = sumf;
+#endif
 }
 
-size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
+void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
 
-    for (int j = 0; j < n; j += k) {
-        block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
-        quantize_row_q4_K_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_q4_K));
-}
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
 
-// ====================== 5-bit (de)-quantization
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
 
-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-#if QK_K == 256
-    uint8_t L[QK_K];
-    float mins[QK_K/32];
-    float scales[QK_K/32];
-    float weights[32];
-    uint8_t Laux[32];
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q5_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        summs0 += ggml_fp16_to_fp32(x0->m) * y0->s;
+        summs1 += ggml_fp16_to_fp32(x1->m) * y1->s;
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), ggml_fp16_to_fp32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), ggml_fp16_to_fp32(x1->d)*y1->d);
 #else
-    int8_t L[QK_K];
-    float scales[QK_K/16];
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*y1->d);
 #endif
+    }
 
-    for (int i = 0; i < nb; i++) {
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
 
-#if QK_K == 256
+    float summs = 0.0f;
 
-        float max_scale = 0; // as we are deducting the min, scales are always positive
-        float max_min = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
-            float sum_x2 = 0;
-            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
-            float av_x = sqrtf(sum_x2/32);
-            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
-            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
-            float scale = scales[j];
-            if (scale > max_scale) {
-                max_scale = scale;
-            }
-            float min = mins[j];
-            if (min > max_min) {
-                max_min = min;
-            }
-        }
+    uint32_t qh;
+    uint64_t tmp[4];
 
-        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
-        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
-        for (int j = 0; j < QK_K/32; ++j) {
-            uint8_t ls = nearest_int(inv_scale*scales[j]);
-            uint8_t lm = nearest_int(inv_min*mins[j]);
-            ls = MIN(63, ls);
-            lm = MIN(63, lm);
-            if (j < 4) {
-                y[i].scales[j] = ls;
-                y[i].scales[j+4] = lm;
-            } else {
-                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
-                y[i].scales[j-4] |= ((ls >> 4) << 6);
-                y[i].scales[j-0] |= ((lm >> 4) << 6);
-            }
-        }
-        y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
-        y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
 
-        uint8_t sc, m;
-        for (int j = 0; j < QK_K/32; ++j) {
-            get_scale_min_k4(j, y[i].scales, &sc, &m);
-            const float d = ggml_fp16_to_fp32(y[i].d) * sc;
-            if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(31, l));
-                L[32*j + ii] = l;
-            }
-        }
+        summs += ggml_fp16_to_fp32(x0->m) * y0->s;
 
-        uint8_t * restrict qh = y[i].qh;
-        uint8_t * restrict ql = y[i].qs;
-        memset(qh, 0, QK_K/8);
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
 
-        uint8_t m1 = 1, m2 = 2;
-        for (int n = 0; n < QK_K; n += 64) {
-            for (int j = 0; j < 32; ++j) {
-                int l1 = L[n + j];
-                if (l1 > 15) {
-                    l1 -= 16; qh[j] |= m1;
-                }
-                int l2 = L[n + j + 32];
-                if (l2 > 15) {
-                    l2 -= 16; qh[j] |= m2;
-                }
-                ql[j] = l1 | (l2 << 4);
-            }
-            m1 <<= 2; m2 <<= 2;
-            ql += 32;
-        }
-#else
-        float max_scale = 0, amax = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
-            float abs_scale = fabsf(scales[j]);
-            if (abs_scale > amax) {
-                amax = abs_scale;
-                max_scale = scales[j];
-            }
-        }
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
 
-        float iscale = -128.f/max_scale;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int l = nearest_int(iscale*scales[j]);
-            y[i].scales[j] = MAX(-128, MIN(127, l));
-        }
-        y[i].d = ggml_fp32_to_fp16(1/iscale);
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
 
-        for (int j = 0; j < QK_K/16; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
-            if (!d) continue;
-            for (int ii = 0; ii < 16; ++ii) {
-                int l = nearest_int(x[16*j + ii]/d);
-                l = MAX(-16, MIN(15, l));
-                L[16*j + ii] = l + 16;
-            }
-        }
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
 
-        uint8_t * restrict qh = y[i].qh;
-        uint8_t * restrict ql = y[i].qs;
-        memset(qh, 0, QK_K/8);
+        const v128_t v0 = wasm_v128_load(x0->qs);
 
-        for (int j = 0; j < 32; ++j) {
-            int jm = j%8;
-            int is = j/8;
-            int l1 = L[j];
-            if (l1 > 15) {
-                l1 -= 16; qh[jm] |= (1 << is);
-            }
-            int l2 = L[j + 32];
-            if (l2 > 15) {
-                l2 -= 16; qh[jm] |= (1 << (4 + is));
-            }
-            ql[j] = l1 | (l2 << 4);
-        }
-#endif
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
 
-        x += QK_K;
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0l, qhl);
+        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
 
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(ggml_fp16_to_fp32(x0->d) * y0->d)));
     }
-}
 
-void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.0f;
 
+    // Main loop
     for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d));
 
-        const uint8_t * ql = x[i].qs;
-        const uint8_t * qh = x[i].qh;
+        summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
 
-#if QK_K == 256
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        int is = 0;
-        uint8_t sc, m;
-        uint8_t u1 = 1, u2 = 2;
-        for (int j = 0; j < QK_K; j += 64) {
-            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
-            const float d1 = d * sc; const float m1 = min * m;
-            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
-            const float d2 = d * sc; const float m2 = min * m;
-            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
-            ql += 32; is += 2;
-            u1 <<= 2; u2 <<= 2;
-        }
-#else
-        float d = ggml_fp16_to_fp32(x[i].d);
-        const int8_t * restrict s = x[i].scales;
-        for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
-            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
-            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
-            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
-            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
-            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
-            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
-            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
-        }
-        y += QK_K;
-#endif
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
-}
 
-void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_q5_K * restrict y = vy;
-    quantize_row_q5_K_reference(x, y, k);
-}
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
 
-size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
+    float summs = 0.0f;
 
-    for (int j = 0; j < n; j += k) {
-        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
-        quantize_row_q5_K_reference(src + j, y, k);
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d));
+
+        summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
     }
-    return (n/QK_K*sizeof(block_q5_K));
-}
 
-// ====================== 6-bit (de)-quantization
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
 
-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    uint32_t qh;
 
-    int8_t L[QK_K];
-    float   scales[QK_K/16];
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
 
     for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
 
-        float max_scale = 0;
-        float max_abs_scale = 0;
+        // load qh
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
 
-        for (int ib = 0; ib < QK_K/16; ++ib) {
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
 
-            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
-            scales[ib] = scale;
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
 
-            const float abs_scale = fabsf(scale);
-            if (abs_scale > max_abs_scale) {
-                max_abs_scale = abs_scale;
-                max_scale = scale;
-            }
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
 
-        }
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
-        if (!max_abs_scale) {
-            memset(&y[i], 0, sizeof(block_q6_K));
-            y[i].d = ggml_fp32_to_fp16(0.f);
-            x += QK_K;
-            continue;
-        }
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        float iscale = -128.f/max_scale;
-        y[i].d = ggml_fp32_to_fp16(1/iscale);
-        for (int ib = 0; ib < QK_K/16; ++ib) {
-            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
-        }
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        for (int j = 0; j < QK_K/16; ++j) {
-            float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
-            if (!d) {
-                continue;
-            }
-            for (int ii = 0; ii < 16; ++ii) {
-                int l = nearest_int(x[16*j + ii]/d);
-                l = MAX(-32, MIN(31, l));
-                L[16*j + ii] = l + 32;
-            }
-        }
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        uint8_t * restrict ql = y[i].ql;
-        uint8_t * restrict qh = y[i].qh;
-#if QK_K == 256
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                const uint8_t q1 = L[j + l +  0] & 0xF;
-                const uint8_t q2 = L[j + l + 32] & 0xF;
-                const uint8_t q3 = L[j + l + 64] & 0xF;
-                const uint8_t q4 = L[j + l + 96] & 0xF;
-                ql[l+ 0] = q1 | (q3 << 4);
-                ql[l+32] = q2 | (q4 << 4);
-                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
-            }
-            ql += 64;
-            qh += 32;
-        }
-#else
-        for (int l = 0; l < 32; ++l) {
-            const uint8_t q1 = L[l +  0] & 0xF;
-            const uint8_t q2 = L[l + 32] & 0xF;
-            ql[l] = q1 | (q2 << 4);
-        }
-        for (int l = 0; l < 16; ++l) {
-            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
-        }
-#endif
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
 
-        x += QK_K;
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
     }
-}
 
-void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
 
     for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
+        int sumi = 0;
 
-        const uint8_t * restrict ql = x[i].ql;
-        const uint8_t * restrict qh = x[i].qh;
-        const int8_t  * restrict sc = x[i].scales;
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 
-#if QK_K == 256
-        for (int n = 0; n < QK_K; n += 128) {
-            for (int l = 0; l < 32; ++l) {
-                int is = l/16;
-                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-                y[l +  0] = d * sc[is + 0] * q1;
-                y[l + 32] = d * sc[is + 2] * q2;
-                y[l + 64] = d * sc[is + 4] * q3;
-                y[l + 96] = d * sc[is + 6] * q4;
-            }
-            y  += 128;
-            ql += 64;
-            qh += 32;
-            sc += 8;
-        }
-#else
-        for (int l = 0; l < 16; ++l) {
-            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            y[l+ 0] = d * sc[0] * q1;
-            y[l+16] = d * sc[1] * q2;
-            y[l+32] = d * sc[2] * q3;
-            y[l+48] = d * sc[3] * q4;
+            const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
+            const int32_t x1 = (x[i].qs[j] >>  4) | xh_1;
+
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
-        y  += 64;
-#endif
 
+        sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
     }
-}
 
-void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_q6_K * restrict y = vy;
-    quantize_row_q6_K_reference(x, y, k);
+    *s = sumf;
+#endif
 }
 
-size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
+void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
 
-    for (int j = 0; j < n; j += k) {
-        block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
-        quantize_row_q6_K_reference(src + j, y, k);
+    assert(n % qk == 0);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+
+#else
+        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#endif
     }
-    return (n/QK_K*sizeof(block_q6_K));
-}
 
-//===================================== Q8_K ==============================================
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-    for (int i = 0; i < nb; i++) {
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        float max = 0;
-        float amax = 0;
-        for (int j = 0; j < QK_K; ++j) {
-            float ax = fabsf(x[j]);
-            if (ax > amax) {
-                amax = ax; max = x[j];
-            }
-        }
-        if (!amax) {
-            y[i].d = 0;
-            memset(y[i].qs, 0, QK_K);
-            x += QK_K;
-            continue;
-        }
-        const float iscale = -128.f/max;
-        for (int j = 0; j < QK_K; ++j) {
-            int v = nearest_int(iscale*x[j]);
-            y[i].qs[j] = MIN(127, v);
-        }
-        for (int j = 0; j < QK_K/16; ++j) {
-            int sum = 0;
-            for (int ii = 0; ii < 16; ++ii) {
-                sum += y[i].qs[j*16 + ii];
-            }
-            y[i].bsums[j] = sum;
-        }
-        y[i].d = 1/iscale;
-        x += QK_K;
+        // Multiply q with scale and accumulate
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
     }
-}
 
-void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+    size_t vl = __riscv_vsetvl_e8m1(qk);
 
     for (int i = 0; i < nb; i++) {
-        for (int j = 0; j < QK_K; ++j) {
-            *y++ = x[i].d * x[i].qs[j];
-        }
+        // load elements
+        vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
+        vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d));
     }
-}
 
-void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
-    quantize_row_q8_K_reference(x, y, k);
-}
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
 
-//===================================== Dot ptoducts =================================
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
 
-//
-// Helper functions
-//
-#if __AVX__ || __AVX2__ || __AVX512F__
+        for (int j = 0; j < qk; j++) {
+            sumi += x[i].qs[j]*y[i].qs[j];
+        }
 
-// horizontally add 8 floats
-static inline float hsum_float_8(const __m256 x) {
-    __m128 res = _mm256_extractf128_ps(x, 1);
-    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
-    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
-    res = _mm_add_ss(res, _mm_movehdup_ps(res));
-    return _mm_cvtss_f32(res);
-}
+        sumf += sumi*(ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d));
+    }
 
-// shuffles to pick the required scales in dot products
-static inline __m256i get_scale_shuffle_q3k(int i) {
-    static const uint8_t k_shuffle[128] = {
-         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
-         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
-         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
-        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
-    };
-    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
-}
-static inline __m256i get_scale_shuffle_k4(int i) {
-    static const uint8_t k_shuffle[256] = {
-         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
-         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
-         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
-         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
-         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
-        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
-        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
-        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
-    };
-    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
-}
-static inline __m128i get_scale_shuffle(int i) {
-    static const uint8_t k_shuffle[128] = {
-         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
-         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
-         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
-         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
-         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
-        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
-        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
-        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
-    };
-    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
-}
+    *s = sumf;
 #endif
+}
 
 #if QK_K == 256
 void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {

+ 86 - 17
k_quants.h → ggml-quants.h

@@ -1,20 +1,14 @@
 #pragma once
 
+// This is a private API for quantization and dequantization
+// Should not be used directly, use ggml.h instead
+
 #include "ggml.h"
 
 #include <stdint.h>
 #include <assert.h>
 #include <stddef.h>
 
-// Super-block size
-#ifdef GGML_QKK_64
-#define QK_K 64
-#define K_SCALE_SIZE 4
-#else
-#define QK_K 256
-#define K_SCALE_SIZE 12
-#endif
-
 #ifndef static_assert
 #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
 #define static_assert(cond, msg) _Static_assert(cond, msg)
@@ -23,10 +17,66 @@
 #endif
 #endif
 
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
+
 //
 // Super-block quantization structures
 //
 
+// Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
 // 2-bit quantization
 // weight is represented as x = a * q + b
 // 16 blocks of 16 elements each
@@ -127,6 +177,13 @@ static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_
 
 
 // Quantization
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
+
 void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
 void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
 void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
@@ -134,6 +191,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
 
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
+
 void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
@@ -142,6 +206,13 @@ void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
 
 // Dequantization
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
+//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
+
 void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
 void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
@@ -150,16 +221,14 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
 void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
 
 // Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-
-// Quantization with histogram collection
-size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
-

Fișier diff suprimat deoarece este prea mare
+ 672 - 2967
ggml.c


+ 7 - 0
ggml.h

@@ -1930,12 +1930,19 @@ extern "C" {
     // quantization
     //
 
+    // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
     //

+ 14 - 20
llama.cpp

@@ -19,13 +19,11 @@
 #ifdef GGML_USE_MPI
 #  include "ggml-mpi.h"
 #endif
-#ifdef GGML_USE_K_QUANTS
-#  ifndef QK_K
-#    ifdef GGML_QKK_64
-#      define QK_K 64
-#    else
-#      define QK_K 256
-#    endif
+#ifndef QK_K
+#  ifdef GGML_QKK_64
+#    define QK_K 64
+#  else
+#    define QK_K 256
 #  endif
 #endif
 
@@ -8052,7 +8050,7 @@ struct no_init {
 struct quantize_state_internal {
     const llama_model                 & model;
     const llama_model_quantize_params * params;
-#ifdef GGML_USE_K_QUANTS
+
     int n_attention_wv    = 0;
     int n_feed_forward_w2 = 0;
     int i_attention_wv    = 0;
@@ -8060,7 +8058,7 @@ struct quantize_state_internal {
 
     int n_k_quantized     = 0;
     int n_fallback        = 0;
-#endif
+
     quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
         : model(model)
         , params(params)
@@ -8125,7 +8123,6 @@ static void llama_convert_tensor_internal(
     workers.clear();
 }
 
-#ifdef GGML_USE_K_QUANTS
 static ggml_type get_k_quant_type(
     quantize_state_internal & qs,
     ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
@@ -8237,7 +8234,6 @@ static ggml_type get_k_quant_type(
 
     return new_type;
 }
-#endif
 
 static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type quantized_type;
@@ -8252,7 +8248,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_F16:  quantized_type = GGML_TYPE_F16;  break;
         case LLAMA_FTYPE_ALL_F32:     quantized_type = GGML_TYPE_F32;  break;
 
-#ifdef GGML_USE_K_QUANTS
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K:   quantized_type = GGML_TYPE_Q2_K; break;
         case LLAMA_FTYPE_MOSTLY_Q3_K_S:
@@ -8263,7 +8258,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:
         case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
         case LLAMA_FTYPE_MOSTLY_Q6_K:   quantized_type = GGML_TYPE_Q6_K; break;
-#endif
+
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
 
@@ -8304,7 +8299,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
     gguf_set_val_u32(ctx_out, "general.file_type", ftype);
 
-#ifdef GGML_USE_K_QUANTS
     for (int i = 0; i < ml.n_tensors; ++i) {
         struct ggml_tensor * meta = ml.get_tensor_meta(i);
 
@@ -8322,7 +8316,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
                 __func__, qs.n_attention_wv, qs.n_feed_forward_w2, model.hparams.n_layer);
     }
-#endif
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -8387,9 +8380,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         if (quantize) {
             new_type = quantized_type;
-#ifdef GGML_USE_K_QUANTS
-            new_type = get_k_quant_type(qs, new_type, tensor, ftype);
-#endif
+            if (!params->pure) {
+                new_type = get_k_quant_type(qs, new_type, tensor, ftype);
+            }
+
             // If we've decided to quantize to the same type the tensor is already
             // in then there's nothing to do.
             quantize = tensor->type != new_type;
@@ -8514,12 +8508,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             LLAMA_LOG_INFO("\n");
         }
     }
-#ifdef GGML_USE_K_QUANTS
+
     if (qs.n_fallback > 0) {
         LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) incompatible with k-quants and required fallback quantization\n",
                 __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
     }
-#endif
 }
 
 static int llama_apply_lora_from_file_internal(
@@ -8844,6 +8837,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.allow_requantize            =*/ false,
         /*.quantize_output_tensor      =*/ true,
         /*.only_copy                   =*/ false,
+        /*.pure                        =*/ false,
     };
 
     return result;

+ 1 - 0
llama.h

@@ -191,6 +191,7 @@ extern "C" {
         bool allow_requantize;       // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor; // quantize output.weight
         bool only_copy;              // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+        bool pure;                   // disable k-quant mixtures and quantize all tensors to the same type
     } llama_model_quantize_params;
 
     // grammar types

Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff