// ggml-delta.h
  1. #pragma once
  2. #include "ggml-backend.h"
  3. #include "ggml.h"
  4. #ifdef __cplusplus
  5. extern "C" {
  6. #endif
  7. // Delta-Net linear layer activation
  8. // Implements the complete Delta-Net gated linear attention mechanism
  9. // This includes causal convolution preprocessing and gated delta rule computation
  10. // k, v, q, g: [S, H, n_tokens, n_seqs] - key, value, query, gate tensors
  11. // conv_weight: [conv_dim, 1, conv_kernel_size] - convolution kernel weights
  12. // conv_bias: [conv_dim] - convolution bias (optional, can be NULL)
  13. // beta: [H, n_tokens, n_seqs] - beta parameter for delta rule
  14. // state: [S, S, H, n_seqs] - recurrent state tensor
  15. // chunk_size: chunk size for chunked computation (0 for recurrent mode)
  16. // use_qk_l2norm: whether to apply L2 normalization to query and key
  17. // scale: attention scaling factor
  18. GGML_API struct ggml_tensor * ggml_delta_net(struct ggml_context * ctx,
  19. struct ggml_tensor * k,
  20. struct ggml_tensor * v,
  21. struct ggml_tensor * q,
  22. struct ggml_tensor * g,
  23. struct ggml_tensor * beta,
  24. struct ggml_tensor * state,
  25. bool use_qk_l2norm,
  26. float scale);
  27. GGML_API struct ggml_tensor * ggml_delta_net_op(struct ggml_context * ctx,
  28. struct ggml_tensor * q,
  29. struct ggml_tensor * k,
  30. struct ggml_tensor * v,
  31. struct ggml_tensor * g,
  32. struct ggml_tensor * beta,
  33. struct ggml_tensor * state,
  34. bool use_qk_l2norm,
  35. float scale);
  36. #ifdef __cplusplus
  37. }
  38. #endif