Просмотр исходного кода

ggml: backward pass for split swiglu (#14483)

Johannes Gäßler 6 месяцев назад
Родитель
Сommit
c8c4495b8d
2 измененных файлов с 21 добавлено и 2 удалено
  1. 17 2
      ggml/src/ggml.c
  2. 4 0
      tests/test-backend-ops.cpp

+ 17 - 2
ggml/src/ggml.c

@@ -6050,13 +6050,28 @@ static void ggml_compute_backward(
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case GGML_OP_GLU: {
+            switch (ggml_get_glu_op(tensor)) {
+                case GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case GGML_OP_NONE: {
             // noop
         } break;
         case GGML_OP_COUNT:
         default: {
-            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
-            GGML_ABORT("fatal error");
+            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
         } //break;
     }
 

+ 4 - 0
tests/test-backend-ops.cpp

@@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
             ggml_set_name(a, "a");
 
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
             ggml_set_name(a, "view_of_a");
 
             b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
             ggml_set_name(b, "b");
 
             b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
             ggml_set_name(a, "view_of_b");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
             ggml_set_name(a, "a");
 
             b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
             ggml_set_name(b, "b");
         }