1
0

token_embedding_nocuda.go 576 B

123456789101112131415161718192021222324
  1. //go:build !cuda
  2. package compute
  3. import (
  4. "makarna/pkg/backend/cpu"
  5. "makarna/pkg/backend/cpu/nn"
  6. "makarna/pkg/backend/device"
  7. "makarna/pkg/tensor"
  8. )
  9. func HybridTokenEmbedding(ids []int, tokenEmb tensor.Tensor, hiddenSize int, dispatcher *device.DeviceDispatcher) (*Activation, error) {
  10. _ = dispatcher
  11. embCPU, err := ToCPU(tokenEmb)
  12. if err != nil {
  13. return nil, err
  14. }
  15. hiddenCPU := cpu.NewTensor(tensor.Shape{len(ids), hiddenSize}, nil)
  16. if err := nn.Embedding(ids, embCPU, hiddenCPU); err != nil {
  17. return nil, err
  18. }
  19. return NewActivationFrom(hiddenCPU), nil
  20. }