|
@@ -10728,7 +10728,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
GGML_ASSERT(ggml_is_contiguous(src8));
|
|
GGML_ASSERT(ggml_is_contiguous(src8));
|
|
|
|
|
|
|
|
int64_t total_params = n_seqs * H_v * num_chunks;
|
|
int64_t total_params = n_seqs * H_v * num_chunks;
|
|
|
- int64_t per_thread = total_params / nth;
|
|
|
|
|
|
|
+ int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
|
|
|
|
|
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
for (int64_t head = 0; head < H_v; head++) {
|
|
for (int64_t head = 0; head < H_v; head++) {
|