|
|
@@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
|
|
|
const int ith = params->ith;
|
|
|
const int nth = params->nth;
|
|
|
|
|
|
+ GGML_ASSERT(ne0 == ne00);
|
|
|
+ GGML_ASSERT(ne1 == ne10);
|
|
|
+ GGML_ASSERT(ne2 == ne02);
|
|
|
GGML_ASSERT(ne02 == ne12);
|
|
|
- GGML_ASSERT(ne03 == ne13);
|
|
|
- GGML_ASSERT(ne2 == ne12);
|
|
|
GGML_ASSERT(ne3 == ne13);
|
|
|
+ GGML_ASSERT(ne03 == ne13);
|
|
|
|
|
|
// we don't support permuted src0 or src1
|
|
|
GGML_ASSERT(nb00 == sizeof(float));
|
|
|
@@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
|
|
|
// GGML_ASSERT(nb1 <= nb2);
|
|
|
// GGML_ASSERT(nb2 <= nb3);
|
|
|
|
|
|
- GGML_ASSERT(ne0 == ne00);
|
|
|
- GGML_ASSERT(ne1 == ne10);
|
|
|
- GGML_ASSERT(ne2 == ne02);
|
|
|
- GGML_ASSERT(ne3 == ne03);
|
|
|
-
|
|
|
// nb01 >= nb00 - src0 is not transposed
|
|
|
// compute by src0 rows
|
|
|
|
|
|
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
|
|
|
- // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
|
|
+ // TODO: #if defined(GGML_USE_CLBLAST)
|
|
|
+
|
|
|
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
+ bool use_blas = ggml_is_matrix(src0) &&
|
|
|
+ ggml_is_matrix(src1) &&
|
|
|
+ ggml_is_contiguous(src0) &&
|
|
|
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
|
|
|
+#endif
|
|
|
|
|
|
if (params->type == GGML_TASK_INIT) {
|
|
|
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
|
|
|
+ if (use_blas) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+#endif
|
|
|
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
|
|
|
return;
|
|
|
}
|
|
|
@@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
+ if (use_blas) {
|
|
|
+ if (params->ith != 0) { // All threads other than the first do no work.
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
|
|
|
+ // src0: (k,n)
|
|
|
+ // src1: (k,m)
|
|
|
+ // dst: (m,n)
|
|
|
+ //
|
|
|
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
|
|
|
+ // Also expressed as (major,minor)
|
|
|
+ // a: (m,k): so src1 transposed
|
|
|
+ // b: (k,n): so src0
|
|
|
+ // c: (m,n)
|
|
|
+ //
|
|
|
+ // However, if ggml_is_transposed(src1) is true, then
|
|
|
+ // src1->data already contains a transposed version, so sgemm mustn't
|
|
|
+ // transpose it further.
|
|
|
+
|
|
|
+ int n = src0->ne[0];
|
|
|
+ int k = src0->ne[1];
|
|
|
+ int m = src1->ne[0];
|
|
|
+
|
|
|
+ int transposeA, lda;
|
|
|
+
|
|
|
+ if (!ggml_is_transposed(src1)) {
|
|
|
+ transposeA = CblasTrans;
|
|
|
+ lda = m;
|
|
|
+ } else {
|
|
|
+ transposeA = CblasNoTrans;
|
|
|
+ lda = k;
|
|
|
+ }
|
|
|
+
|
|
|
+ float * a = (float *) ((char *) src1->data);
|
|
|
+ float * b = (float *) ((char *) src0->data);
|
|
|
+ float * c = (float *) ((char *) dst->data);
|
|
|
+
|
|
|
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
|
|
|
+
|
|
|
+ return;
|
|
|
+ }
|
|
|
+#endif
|
|
|
+
|
|
|
// dst[:,:,:,:] = 0
|
|
|
// for i2,i3:
|
|
|
// for i1:
|