|
|
@@ -10538,7 +10538,7 @@ static void delta_apply_triangular_updates_chunk_f32(float * attn,
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
for (int i = 0; i < num_chunks; i++) {
|
|
|
for (int64_t head = 0; head < H_v; head++) {
|
|
|
- float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
|
|
|
+ float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
|
|
|
|
|
|
// Apply triangular updates following the Python reference exactly:
|
|
|
// for i in range(1, chunk_size):
|
|
|
@@ -10591,7 +10591,7 @@ static void delta_add_identity_matrix_chunk_f32(float * matrix,
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
for (int i = 0; i < num_chunks; i++) {
|
|
|
for (int64_t head = 0; head < H_v; head++) {
|
|
|
- float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) +
|
|
|
+ float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * num_chunks * H_v) +
|
|
|
(head * num_chunks + i) * (chunk_size * chunk_size);
|
|
|
// Add identity matrix directly
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|