|
|
@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
#if QK_K == 256
|
|
|
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
|
|
if (j < 4) {
|
|
|
- d = q[j] & 63; m = q[j + 4] & 63;
|
|
|
+ d = q[j] & 63;
|
|
|
+ m = q[j + 4] & 63;
|
|
|
} else {
|
|
|
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
|
|
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
|
|
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
|
|
|
|
|
|
template<typename dst_t>
|
|
|
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
|
|
- const sycl::nd_item<3> &item_ct1) {
|
|
|
+ uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
|
|
const block_q4_K * x = (const block_q4_K *) vx;
|
|
|
|
|
|
const int i = item_ct1.get_group(2);
|
|
|
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
|
|
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
|
|
|
|
|
- const float dall = x[i].dm[0];
|
|
|
- const float dmin = x[i].dm[1];
|
|
|
+ const sycl::half2 dm = x[i].dm;
|
|
|
+ const float dall = dm[0];
|
|
|
+ const float dmin = dm[1];
|
|
|
|
|
|
- const uint8_t * q = x[i].qs + 32*il + n*ir;
|
|
|
+ if (tid < 12)
|
|
|
+ scales_local[tid] = x[i].scales[tid];
|
|
|
+ item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
|
|
|
|
uint8_t sc, m;
|
|
|
- get_scale_min_k4(is + 0, x[i].scales, sc, m);
|
|
|
- const float d1 = dall * sc; const float m1 = dmin * m;
|
|
|
- get_scale_min_k4(is + 1, x[i].scales, sc, m);
|
|
|
- const float d2 = dall * sc; const float m2 = dmin * m;
|
|
|
+ get_scale_min_k4(is + 0, scales_local, sc, m);
|
|
|
+ const float d1 = dall * sc;
|
|
|
+ const float m1 = dmin * m;
|
|
|
+ get_scale_min_k4(is + 1, scales_local, sc, m);
|
|
|
+ const float d2 = dall * sc;
|
|
|
+ const float m2 = dmin * m;
|
|
|
+
|
|
|
+ sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
|
|
|
for (int l = 0; l < n; ++l) {
|
|
|
- y[l + 0] = d1 * (q[l] & 0xF) - m1;
|
|
|
- y[l +32] = d2 * (q[l] >> 4) - m2;
|
|
|
+ y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
|
|
+ y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
|
|
}
|
|
|
#else
|
|
|
const int tid = item_ct1.get_local_id(2);
|