package model import ( "context" "makarna/pkg/tensor" ) // Config represents model configuration type Config struct { Architecture string `json:"architecture"` // e.g., "qwen3", "llama" VocabSize int `json:"vocab_size"` HiddenSize int `json:"hidden_size"` NumLayers int `json:"num_layers"` NumHeads int `json:"num_heads"` NumKVHeads int `json:"num_kv_heads"` HeadDim int `json:"head_dim"` Intermediate int `json:"intermediate_size"` RopeTheta float64 `json:"rope_theta"` RMSNormEps float64 `json:"rms_norm_eps"` Params map[string]any `json:"params,omitempty"` } // Model defines the interface that all model architectures must implement type Model interface { // Forward performs a forward pass // input: [batch_size, seq_len] tokens // returns: logits [batch_size, seq_len, vocab_size] Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache KVCache) (tensor.Tensor, error) // Config returns the model configuration Config() *Config // Close frees resources (e.g. CUDA memory) Close() error // SetTensor assigns a named tensor to the model parameters SetTensor(name string, t tensor.Tensor) error } // KVCache provides the minimal contract models rely on. Concrete cache // implementations expose richer APIs via type assertions inside model code. type KVCache interface { SeqLen() int Commit(newTokens int) } // BatchForwarder is an optional interface implemented by models that support // fused multi-sequence forward passes. type BatchForwarder interface { ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []KVCache) (tensor.Tensor, error) } // CacheType indicates what kind of KV cache a model requires. type CacheType int const ( // CacheTypePaged indicates standard paged KV cache (for transformer attention). CacheTypePaged CacheType = iota // CacheTypeRecurrent indicates recurrent state cache (for Mamba, KDA, linear attention). CacheTypeRecurrent ) // CacheFactory is an optional interface implemented by models that need // custom cache implementations (e.g., recurrent state models like Mamba/KDA). // Models that don't implement this use the default PagedKVCache. type CacheFactory interface { // CacheType returns the type of cache this model requires. CacheType() CacheType // CreateCache creates a new cache instance for a single request. // The returned cache must implement KVCache. CreateCache() (KVCache, error) } // Registry manages available model architectures var registry = make(map[string]func(*Config) (Model, error)) func Register(name string, factory func(*Config) (Model, error)) { registry[name] = factory } func New(name string, cfg *Config) (Model, error) { factory, ok := registry[name] if !ok { return nil, &UnknownArchitectureError{Name: name} } return factory(cfg) } type UnknownArchitectureError struct { Name string } func (e *UnknownArchitectureError) Error() string { return "unknown architecture: " + e.Name }