| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- package model
- import (
- "context"
- "makarna/pkg/tensor"
- )
// Config carries the settings that describe a model. Each field is read from
// and written to JSON under the key given in its struct tag.
type Config struct {
	// Architecture names the model family, e.g. "qwen3" or "llama".
	Architecture string `json:"architecture"`

	// Integer size settings. How each one is interpreted is up to the
	// architecture that consumes the config.
	VocabSize    int `json:"vocab_size"`
	HiddenSize   int `json:"hidden_size"`
	NumLayers    int `json:"num_layers"`
	NumHeads     int `json:"num_heads"`
	NumKVHeads   int `json:"num_kv_heads"`
	HeadDim      int `json:"head_dim"`
	Intermediate int `json:"intermediate_size"`

	// Floating-point settings.
	// NOTE(review): presumably the RoPE base and the RMSNorm epsilon; confirm
	// against the architectures that read them.
	RopeTheta  float64 `json:"rope_theta"`
	RMSNormEps float64 `json:"rms_norm_eps"`

	// Params holds any further free-form settings. It is left out of the
	// JSON output when empty.
	Params map[string]any `json:"params,omitempty"`
}
- // Model defines the interface that all model architectures must implement
- type Model interface {
- // Forward performs a forward pass
- // input: [batch_size, seq_len] tokens
- // returns: logits [batch_size, seq_len, vocab_size]
- Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache KVCache) (tensor.Tensor, error)
-
- // Config returns the model configuration
- Config() *Config
-
- // Close frees resources (e.g. CUDA memory)
- Close() error
- // SetTensor assigns a named tensor to the model parameters
- SetTensor(name string, t tensor.Tensor) error
- }
// KVCache is the minimal cache contract that models depend on. Concrete
// cache implementations offer richer APIs, which model code reaches through
// type assertions.
type KVCache interface {
	// SeqLen reports a sequence length.
	// NOTE(review): presumably the number of tokens already held by the
	// cache; confirm against the concrete implementations.
	SeqLen() int

	// Commit is handed a count of new tokens.
	// NOTE(review): presumably marks that many tokens as stored; confirm
	// against the concrete implementations.
	Commit(newTokens int)
}
- // BatchForwarder is an optional interface implemented by models that support
- // fused multi-sequence forward passes.
- type BatchForwarder interface {
- ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []KVCache) (tensor.Tensor, error)
- }
// CacheType identifies the kind of KV cache a model requires.
type CacheType int

// The known cache kinds. CacheTypePaged is the zero value of CacheType.
const (
	// CacheTypePaged selects the standard paged KV cache used for
	// transformer attention.
	CacheTypePaged CacheType = iota

	// CacheTypeRecurrent selects a recurrent state cache, as used by Mamba,
	// KDA and linear attention.
	CacheTypeRecurrent
)
- // CacheFactory is an optional interface implemented by models that need
- // custom cache implementations (e.g., recurrent state models like Mamba/KDA).
- // Models that don't implement this use the default PagedKVCache.
- type CacheFactory interface {
- // CacheType returns the type of cache this model requires.
- CacheType() CacheType
- // CreateCache creates a new cache instance for a single request.
- // The returned cache must implement KVCache.
- CreateCache() (KVCache, error)
- }
- // Registry manages available model architectures
- var registry = make(map[string]func(*Config) (Model, error))
- func Register(name string, factory func(*Config) (Model, error)) {
- registry[name] = factory
- }
- func New(name string, cfg *Config) (Model, error) {
- factory, ok := registry[name]
- if !ok {
- return nil, &UnknownArchitectureError{Name: name}
- }
- return factory(cfg)
- }
// UnknownArchitectureError is the error New returns when no factory is
// registered for the requested architecture name.
type UnknownArchitectureError struct {
	// Name is the architecture name that was looked up.
	Name string
}

// Error implements the error interface. The message is a fixed prefix
// followed by the requested name.
func (e *UnknownArchitectureError) Error() string {
	const prefix = "unknown architecture: "
	return prefix + e.Name
}
|