model.go 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package model
  2. import (
  3. "context"
  4. "makarna/pkg/tensor"
  5. )
  // Config describes a model's architecture and hyperparameters.
// Every field carries a JSON tag, so a Config can be encoded to and decoded
// from JSON directly.
type Config struct {
	// Architecture names the model family, e.g. "qwen3" or "llama".
	Architecture string `json:"architecture"`

	// Core hyperparameters. Note that Intermediate is serialized under the
	// key "intermediate_size".
	VocabSize    int     `json:"vocab_size"`
	HiddenSize   int     `json:"hidden_size"`
	NumLayers    int     `json:"num_layers"`
	NumHeads     int     `json:"num_heads"`
	NumKVHeads   int     `json:"num_kv_heads"`
	HeadDim      int     `json:"head_dim"`
	Intermediate int     `json:"intermediate_size"`
	RopeTheta    float64 `json:"rope_theta"`
	RMSNormEps   float64 `json:"rms_norm_eps"`

	// Params holds additional free-form settings. Because of omitempty, it
	// is left out of the JSON output when nil or empty.
	Params map[string]any `json:"params,omitempty"`
}
  20. // Model defines the interface that all model architectures must implement
  21. type Model interface {
  22. // Forward performs a forward pass
  23. // input: [batch_size, seq_len] tokens
  24. // returns: logits [batch_size, seq_len, vocab_size]
  25. Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache KVCache) (tensor.Tensor, error)
  26. // Config returns the model configuration
  27. Config() *Config
  28. // Close frees resources (e.g. CUDA memory)
  29. Close() error
  30. // SetTensor assigns a named tensor to the model parameters
  31. SetTensor(name string, t tensor.Tensor) error
  32. }
  // KVCache is the smallest cache contract that model code depends on.
//
// Concrete cache implementations offer richer APIs. Model code reaches
// those through type assertions rather than through this interface.
type KVCache interface {
	// SeqLen reports the cache's current sequence length.
	// NOTE(review): presumably the number of committed tokens. Confirm.
	SeqLen() int

	// Commit records newTokens additional tokens.
	// NOTE(review): semantics inferred from the name. Verify against the
	// concrete cache implementations.
	Commit(newTokens int)
}
  39. // BatchForwarder is an optional interface implemented by models that support
  40. // fused multi-sequence forward passes.
  41. type BatchForwarder interface {
  42. ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []KVCache) (tensor.Tensor, error)
  43. }
  // CacheType enumerates the kinds of KV cache a model can require.
type CacheType int

// Supported cache kinds. Values start at zero, so the zero value of
// CacheType is CacheTypePaged.
const (
	// CacheTypePaged selects the standard paged KV cache used by
	// transformer attention.
	CacheTypePaged CacheType = iota

	// CacheTypeRecurrent selects a recurrent state cache, as used by
	// Mamba, KDA and linear-attention models.
	CacheTypeRecurrent
)
  52. // CacheFactory is an optional interface implemented by models that need
  53. // custom cache implementations (e.g., recurrent state models like Mamba/KDA).
  54. // Models that don't implement this use the default PagedKVCache.
  55. type CacheFactory interface {
  56. // CacheType returns the type of cache this model requires.
  57. CacheType() CacheType
  58. // CreateCache creates a new cache instance for a single request.
  59. // The returned cache must implement KVCache.
  60. CreateCache() (KVCache, error)
  61. }
  62. // Registry manages available model architectures
  63. var registry = make(map[string]func(*Config) (Model, error))
  64. func Register(name string, factory func(*Config) (Model, error)) {
  65. registry[name] = factory
  66. }
  67. func New(name string, cfg *Config) (Model, error) {
  68. factory, ok := registry[name]
  69. if !ok {
  70. return nil, &UnknownArchitectureError{Name: name}
  71. }
  72. return factory(cfg)
  73. }
  // UnknownArchitectureError is the error New reports when no factory is
// registered under the requested architecture name.
type UnknownArchitectureError struct {
	// Name is the architecture name that was looked up.
	Name string
}

// Error implements the error interface.
func (err *UnknownArchitectureError) Error() string {
	const prefix = "unknown architecture: "
	return prefix + err.Name
}