Răsfoiți Sursa

finetune : add -ngl parameter (#3762)

* Add '-ngl' support to finetune.cpp

* Add fprintf in ggml_cuda_op_add

When I tried CUDA offloading during finetuning following the readme, I got an assert here.
This probably isn't an important case because inference later gives a warning saying you should use f16 or f32 instead when using lora

* Add 'finetune.sh', which currently fails when using GPU

"error: operator (): Finetuning on tensors with type 'f16' is not yet supported"

* tweak finetune.sh

* Suppress some warnings in ggml.c

* Add f16 implementation to ggml_compute_forward_add_f16_f32

* Add an f16 case to ggml_add_cast_impl and llama_build_lora_finetune_graphs

* finetune.sh: Edit comments

* Add "add_f16_f32_f32_cuda"

* Tweak an error message

* finetune.sh: Add an optional LLAMA_MODEL_DIR variable

* finetune.sh: Add an optional LLAMA_TRAINING_DIR variable

* train : minor

* tabs to spaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Andrew Godfrey 2 ani în urmă
părinte
comite
73bdcb395e
8 a modificat fișierele cu 108 adăugiri și 17 ștergeri
  1. 2 0
      common/train.cpp
  2. 1 0
      common/train.h
  3. 13 1
      examples/finetune/finetune.cpp
  4. 34 0
      examples/finetune/finetune.sh
  5. 17 0
      ggml-cuda.cu
  6. 2 0
      ggml-quants.c
  7. 38 15
      ggml.c
  8. 1 1
      llama.cpp

+ 2 - 0
common/train.cpp

@@ -1045,6 +1045,7 @@ struct train_params_common get_default_train_params_common() {
     params.n_batch    =    8;
     params.n_gradient_accumulation = 1;
     params.n_epochs   = -1;
+    params.n_gpu_layers = 0;
 
     params.custom_n_ctx = false;
 
@@ -1080,6 +1081,7 @@ struct train_params_common get_default_train_params_common() {
     params.adam_beta2          = 0.999f;
     params.adam_gclip          = 1.0f;
     params.adam_eps_f          = 0.0f;
+
     return params;
 }
 

+ 1 - 0
common/train.h

@@ -44,6 +44,7 @@ struct train_params_common {
     int n_batch;
     int n_gradient_accumulation;
     int n_epochs;
+    int n_gpu_layers;
 
     bool custom_n_ctx;
 

+ 13 - 1
examples/finetune/finetune.cpp

@@ -652,7 +652,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
 
     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type)) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);
@@ -1459,6 +1459,17 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
             }
             params->n_rank_w3 = std::stoi(argv[i]);
             params->custom_n_rank_w3 = true;
+        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+            params->common.n_gpu_layers = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             train_print_usage(argc, argv, &default_params);
@@ -1545,6 +1556,7 @@ int main(int argc, char ** argv) {
     srand(params.common.seed);
 
     struct llama_model_params llama_mparams = llama_model_default_params();
+    llama_mparams.n_gpu_layers = params.common.n_gpu_layers;
     llama_mparams.vocab_only = false;
 
     printf("%s: model base = '%s'\n", __func__, params.fn_model_base);

+ 34 - 0
examples/finetune/finetune.sh

@@ -0,0 +1,34 @@
+#!/bin/bash
+cd `dirname $0`
+cd ../..
+
+EXE="./finetune"
+
+if [[ ! $LLAMA_MODEL_DIR ]]; then LLAMA_MODEL_DIR="./models"; fi
+if [[ ! $LLAMA_TRAINING_DIR ]]; then LLAMA_TRAINING_DIR="."; fi
+
+# MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2-q8_0.gguf" # This is the model the readme uses.
+MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note in this case with "-g", you get an f32-format .BIN file that isn't yet supported if you use it with "main --lora" with GPU inferencing.
+
+while getopts "dg" opt; do
+  case $opt in
+    d)
+      DEBUGGER="gdb --args"
+      ;;
+    g)
+      EXE="./build/bin/Release/finetune"
+      GPUARG="--gpu-layers 25"
+      ;;
+  esac
+done
+
+$DEBUGGER $EXE \
+        --model-base $MODEL \
+        $GPUARG \
+        --checkpoint-in  chk-ol3b-shakespeare-LATEST.gguf \
+        --checkpoint-out chk-ol3b-shakespeare-ITERATION.gguf \
+        --lora-out lora-ol3b-shakespeare-ITERATION.bin \
+        --train-data "$LLAMA_TRAINING_DIR\shakespeare.txt" \
+        --save-every 10 \
+        --threads 10 --adam-iter 30 --batch 4 --ctx 64 \
+        --use-checkpointing

+ 17 - 0
ggml-cuda.cu

@@ -513,6 +513,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
     dst[i] = __hadd(x[i], __float2half(y[i]));
 }
 
+static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __half2float(x[i]) + y[i];
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -4693,6 +4702,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5996,7 +6010,10 @@ inline void ggml_cuda_op_add(
         add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
     } else {
+        fprintf(stderr, "src0->type: %d  dst->type: %d\n", src0->type, dst->type);
         GGML_ASSERT(false);
     }
 

+ 2 - 0
ggml-quants.c

@@ -716,6 +716,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
         __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
     }
 #else
+    UNUSED(nb);
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
@@ -969,6 +970,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
         y[i].s = sum*d;
     }
 #else
+    UNUSED(nb);
     // scalar
     quantize_row_q8_1_reference(x, y, k);
 #endif

+ 38 - 15
ggml.c

@@ -3153,7 +3153,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
     // TODO: support less-strict constraint
     //       GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
-    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+    GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
 
     bool is_node = false;
 
@@ -6927,9 +6927,15 @@ static void ggml_compute_forward_add_f16_f32(
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // rows per thread
@@ -6940,18 +6946,35 @@ static void ggml_compute_forward_add_f16_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
             }
         }
     }

+ 1 - 1
llama.cpp

@@ -8003,7 +8003,7 @@ static int llama_apply_lora_from_file_internal(
             if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
                 if (dest_t->type != GGML_TYPE_F16) {
                     throw std::runtime_error(format(
-                        "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
+                        "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models. dest_t->type: %d", __func__, dest_t->type));
                 }
                 offload_func = ggml_cuda_assign_buffers;
                 offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;