@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
 
 inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
     }
 }
 
+static void ggml_compute_forward_set_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset inbytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *) dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
 static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
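
Not part of the patch: the kernel above splits the rows of src1 evenly across the worker threads and then recovers the (i1, i2, i3) indices from the flat row counter ir. A minimal standalone sketch of the same arithmetic, with made-up sizes (ne11 = 3, ne12 = 2, ne13 = 1, nth = 4):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    // hypothetical sizes, not taken from the patch
    const int ne11 = 3, ne12 = 2, ne13 = 1;
    const int nr  = ne11*ne12*ne13;    // total rows, as ggml_nrows(src1) would report
    const int nth = 4;                 // number of worker threads

    const int dr = (nr + nth - 1)/nth; // rows per thread, rounded up

    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;            // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr); // one past the last row
        for (int ir = ir0; ir < ir1; ++ir) {
            // same index recovery as in ggml_compute_forward_set_i32
            const int i3 = ir/(ne12*ne11);
            const int i2 = (ir - i3*ne12*ne11)/ne11;
            const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
            printf("thread %d: row %d -> i1=%d i2=%d i3=%d\n", ith, ir, i1, i2, i3);
        }
    }
    return 0;
}

Because ir1 is clamped with MIN, the last thread simply gets fewer (or zero) rows when nr is not a multiple of nth.
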
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
             {
                 ggml_compute_forward_set_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
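
Not part of the patch: a minimal usage sketch that exercises the new I32 path end to end on the CPU backend. It assumes the public ggml API (ggml_set, ggml_set_i32, ggml_new_graph, ggml_build_forward_expand) plus the ggml_graph_compute_with_ctx helper, whose declaring header differs between ggml versions; all tensor sizes, values and the offset below are made up. ggml_set() packs nb1/nb2/nb3, the byte offset and the inplace flag into op_params, which is what ggml_compute_forward_set_i32 reads back above.

#include <stdint.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // 4x3 destination and a 2x2 patch, both I32 (hypothetical sizes)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 4, 3);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, 2);

    ggml_set_i32(a, 0); // fill the destination with zeros
    ggml_set_i32(b, 7); // fill the patch with sevens

    // write b into (a copy of) a, viewed with a's strides, starting one element in
    struct ggml_tensor * c = ggml_set(ctx, a, b,
            a->nb[1], a->nb[2], a->nb[3], 1*ggml_element_size(a));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 2); // dispatches to ggml_compute_forward_set_i32

    // expected result:
    // 0 7 7 0
    // 0 7 7 0
    // 0 0 0 0
    for (int i1 = 0; i1 < 3; ++i1) {
        for (int i0 = 0; i0 < 4; ++i0) {
            printf("%d ", ((int32_t *) c->data)[i1*4 + i0]);
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}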