|
|
@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
|
|
|
}
|
|
|
return 1/iscale;
|
|
|
}
|
|
|
+ bool return_early = false;
|
|
|
+ if (rmse_type < 0) {
|
|
|
+ rmse_type = -rmse_type;
|
|
|
+ return_early = true;
|
|
|
+ }
|
|
|
int weight_type = rmse_type%2;
|
|
|
float sumlx = 0;
|
|
|
float suml2 = 0;
|
|
|
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
|
|
|
suml2 += w*l*l;
|
|
|
}
|
|
|
float scale = sumlx/suml2;
|
|
|
+ if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
|
|
|
float best = scale * sumlx;
|
|
|
- for (int itry = 0; itry < 3; ++itry) {
|
|
|
- iscale = 1/scale;
|
|
|
- float slx = 0;
|
|
|
- float sl2 = 0;
|
|
|
- bool changed = false;
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
- int l = nearest_int(iscale * x[i]);
|
|
|
- l = MAX(-nmax, MIN(nmax-1, l));
|
|
|
- if (l + nmax != L[i]) { changed = true; }
|
|
|
- float w = weight_type == 1 ? x[i] * x[i] : 1.f;
|
|
|
- slx += w*x[i]*l;
|
|
|
- sl2 += w*l*l;
|
|
|
- }
|
|
|
- if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
- int l = nearest_int(iscale * x[i]);
|
|
|
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
|
|
|
- }
|
|
|
- sumlx = slx; suml2 = sl2;
|
|
|
- scale = sumlx/suml2;
|
|
|
- best = scale * sumlx;
|
|
|
- }
|
|
|
- for (int itry = 0; itry < 5; ++itry) {
|
|
|
- int n_changed = 0;
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
- float w = weight_type == 1 ? x[i]*x[i] : 1;
|
|
|
- int l = L[i] - nmax;
|
|
|
- float slx = sumlx - w*x[i]*l;
|
|
|
- if (slx > 0) {
|
|
|
- float sl2 = suml2 - w*l*l;
|
|
|
- int new_l = nearest_int(x[i] * sl2 / slx);
|
|
|
- new_l = MAX(-nmax, MIN(nmax-1, new_l));
|
|
|
- if (new_l != l) {
|
|
|
- slx += w*x[i]*new_l;
|
|
|
- sl2 += w*new_l*new_l;
|
|
|
- if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
|
|
|
- L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
|
|
|
- scale = sumlx / suml2; best = scale * sumlx;
|
|
|
- ++n_changed;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (!n_changed) { break; }
|
|
|
- }
|
|
|
- if (rmse_type < 3) {
|
|
|
- return scale;
|
|
|
- }
|
|
|
- for (int is = -4; is <= 4; ++is) {
|
|
|
+ for (int is = -9; is <= 9; ++is) {
|
|
|
if (is == 0) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
|
|
|
return 1/iscale;
|
|
|
}
|
|
|
|
|
|
-static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
|
|
|
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
|
|
|
+ int ntry, float alpha) {
|
|
|
float min = x[0];
|
|
|
float max = x[0];
|
|
|
+ float sum_x = 0;
|
|
|
+ float sum_x2 = 0;
|
|
|
for (int i = 1; i < n; ++i) {
|
|
|
if (x[i] < min) min = x[i];
|
|
|
if (x[i] > max) max = x[i];
|
|
|
+ sum_x += x[i];
|
|
|
+ sum_x2 += x[i]*x[i];
|
|
|
}
|
|
|
if (max == min) {
|
|
|
for (int i = 0; i < n; ++i) L[i] = 0;
|
|
|
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
sum += x[i] - scale*L[i];
|
|
|
}
|
|
|
- min = sum/n;
|
|
|
+ min = alpha*min + (1 - alpha)*sum/n;
|
|
|
if (min > 0) min = 0;
|
|
|
iscale = 1/scale;
|
|
|
if (!did_change) break;
|
|
|
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
|
|
return scale;
|
|
|
}
|
|
|
|
|
|
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, // Quantizes x[0..n-1] to integer levels in [0, nmax] so that x[i] ~ scale*L[i] + min; returns scale and stores -min in *the_min.
|
|
|
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, // L: output levels (n entries); Laux: caller-provided scratch for candidate levels (n entries)
|
|
|
+ float rmin, float rdelta, int nstep, bool use_mad) { // rmin/rdelta/nstep define the inverse-scale candidates tried below; use_mad selects absolute error instead of squared error
|
|
|
+ float min = x[0]; // x[0] and weights[0] are read unconditionally, so n must be at least 1
|
|
|
+ float max = x[0];
|
|
|
+ float sum_w = weights[0]; // running sum of the weights
|
|
|
+ float sum_x = sum_w * x[0]; // running weighted sum of x
|
|
|
+ for (int i = 1; i < n; ++i) {
|
|
|
+ if (x[i] < min) min = x[i];
|
|
|
+ if (x[i] > max) max = x[i];
|
|
|
+ float w = weights[i];
|
|
|
+ sum_w += w;
|
|
|
+ sum_x += w * x[i];
|
|
|
+ }
|
|
|
+ if (min > 0) min = 0; // the offset is never allowed to be positive
|
|
|
+ if (max == min) { // zero range: emit all-zero levels and a zero scale
|
|
|
+ for (int i = 0; i < n; ++i) L[i] = 0;
|
|
|
+ *the_min = -min;
|
|
|
+ return 0.f;
|
|
|
+ }
|
|
|
+ float iscale = nmax/(max - min); // initial guess: map [min, max] linearly onto [0, nmax]
|
|
|
+ float scale = 1/iscale;
|
|
|
+ float best_mad = 0; // weighted error of the best solution so far; despite the name it is a squared error when use_mad is false
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
+ int l = nearest_int(iscale*(x[i] - min));
|
|
|
+ L[i] = MAX(0, MIN(nmax, l)); // clamp the level into [0, nmax]
|
|
|
+ float diff = scale * L[i] + min - x[i]; // reconstruction error for element i
|
|
|
+ diff = use_mad ? fabsf(diff) : diff * diff;
|
|
|
+ float w = weights[i];
|
|
|
+ best_mad += w * diff;
|
|
|
+ }
|
|
|
+ if (nstep < 1) { // no search requested: keep the initial guess
|
|
|
+ *the_min = -min;
|
|
|
+ return scale;
|
|
|
+ }
|
|
|
+ for (int is = 0; is <= nstep; ++is) { // tries nstep + 1 candidate inverse scales
|
|
|
+ iscale = (rmin + rdelta*is + nmax)/(max - min); // note: min may already hold the offset of a previously accepted candidate (assigned at the end of this loop body)
|
|
|
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0; // weighted sums of l, l*l and l*x over the candidate levels
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
+ int l = nearest_int(iscale*(x[i] - min));
|
|
|
+ l = MAX(0, MIN(nmax, l));
|
|
|
+ Laux[i] = l; // candidate levels are kept in scratch until they prove better
|
|
|
+ float w = weights[i];
|
|
|
+ sum_l += w*l;
|
|
|
+ sum_l2 += w*l*l;
|
|
|
+ sum_xl += w*l*x[i];
|
|
|
+ }
|
|
|
+ float D = sum_w * sum_l2 - sum_l * sum_l; // denominator of the closed-form weighted least-squares fit x ~ this_scale*l + this_min
|
|
|
+ if (D > 0) { // candidates with D <= 0 are skipped
|
|
|
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
|
|
|
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
|
|
|
+ if (this_min > 0) { // a positive offset is not allowed: pin it to 0 and refit the scale alone
|
|
|
+ this_min = 0;
|
|
|
+ this_scale = sum_xl / sum_l2; // weighted least-squares scale with the offset fixed at 0
|
|
|
+ }
|
|
|
+ float mad = 0; // weighted error of this candidate, same metric as best_mad
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
+ float diff = this_scale * Laux[i] + this_min - x[i];
|
|
|
+ diff = use_mad ? fabsf(diff) : diff * diff;
|
|
|
+ float w = weights[i];
|
|
|
+ mad += w * diff;
|
|
|
+ }
|
|
|
+ if (mad < best_mad) { // strictly better: adopt the candidate levels, scale and offset
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
+ L[i] = Laux[i];
|
|
|
+ }
|
|
|
+ best_mad = mad;
|
|
|
+ scale = this_scale;
|
|
|
+ min = this_min;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ *the_min = -min; // the offset is reported negated
|
|
|
+ return scale;
|
|
|
+}
|
|
|
+
|
|
|
#if QK_K == 256
|
|
|
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
|
|
if (j < 4) {
|
|
|
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
|
|
const int nb = k / QK_K;
|
|
|
|
|
|
uint8_t L[QK_K];
|
|
|
+ uint8_t Laux[16];
|
|
|
+ float weights[16];
|
|
|
float mins[QK_K/16];
|
|
|
float scales[QK_K/16];
|
|
|
|
|
|
@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
|
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
|
|
float max_min = 0;
|
|
|
for (int j = 0; j < QK_K/16; ++j) {
|
|
|
- scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
|
|
|
+ for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
|
|
|
+ scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
|
|
|
float scale = scales[j];
|
|
|
if (scale > max_scale) {
|
|
|
max_scale = scale;
|
|
|
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|
|
const int nb = k / QK_K;
|
|
|
|
|
|
uint8_t L[QK_K];
|
|
|
+ uint8_t Laux[32];
|
|
|
+ float weights[32];
|
|
|
float mins[QK_K/32];
|
|
|
float scales[QK_K/32];
|
|
|
|
|
|
@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
|
|
float max_min = 0;
|
|
|
for (int j = 0; j < QK_K/32; ++j) {
|
|
|
- scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
|
|
|
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
|
|
+ float sum_x2 = 0;
|
|
|
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
|
|
|
+ float av_x = sqrtf(sum_x2/32);
|
|
|
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
|
|
+ scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
|
|
|
float scale = scales[j];
|
|
|
if (scale > max_scale) {
|
|
|
max_scale = scale;
|
|
|
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
|
|
uint8_t L[QK_K];
|
|
|
float mins[QK_K/32];
|
|
|
float scales[QK_K/32];
|
|
|
+ float weights[32];
|
|
|
+ uint8_t Laux[32];
|
|
|
#else
|
|
|
int8_t L[QK_K];
|
|
|
float scales[QK_K/16];
|
|
|
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
|
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
|
|
float max_min = 0;
|
|
|
for (int j = 0; j < QK_K/32; ++j) {
|
|
|
- scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
|
|
|
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
|
|
+ float sum_x2 = 0;
|
|
|
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
|
|
|
+ float av_x = sqrtf(sum_x2/32);
|
|
|
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
|
|
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
|
|
|
float scale = scales[j];
|
|
|
if (scale > max_scale) {
|
|
|
max_scale = scale;
|