//go:build cuda package cuda /* #cgo CFLAGS: -I${SRCDIR} #cgo LDFLAGS: -L${SRCDIR}/../../..//build/cuda -Wl,-Bstatic -lmakarna_cuda -Wl,-Bdynamic #cgo LDFLAGS: -L/usr/local/cuda/lib64 -lcudart -lstdc++ -lm #cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../..//build/cuda -Wl,-rpath,/usr/local/cuda/lib64 #include "kernels.h" */ import "C" import ( "errors" "fmt" "runtime" "time" "unsafe" "makarna/pkg/profile" "makarna/pkg/tensor" ) func syncIfProfiling(gpu int) error { if !profile.Enabled() { return nil } return Synchronize(gpu) } // Ensure interface compliance at compile time. var _ tensor.Tensor = (*Tensor)(nil) // Storage holds the underlying GPU memory with reference counting. // Multiple Tensors can share the same Storage (e.g., views, reshapes). // Memory is freed when all references are gone (via the Storage finalizer) // or when the owning Tensor releases it explicitly with Free. type Storage struct { ptr unsafe.Pointer gpu int // Note: We rely on Go's GC and SetFinalizer for ref counting. // Each Tensor that shares this storage keeps a reference to it. // When the last Tensor is GC'd, the Storage becomes unreachable, // and its finalizer frees the GPU memory. } // newStorage creates a new Storage and sets up its finalizer func newStorage(ptr unsafe.Pointer, gpu int) *Storage { s := &Storage{ptr: ptr, gpu: gpu} runtime.SetFinalizer(s, func(st *Storage) { _ = C.cuda_set_device(C.int(st.gpu)) C.cuda_free(st.ptr) }) return s } type Tensor struct { shape tensor.Shape dtype tensor.DType storage *Storage // Shared storage with ref counting ptr unsafe.Pointer // Pointer into storage (may be offset for slices) gpu int // ownsStorage indicates whether this Tensor is responsible for explicitly // freeing the underlying CUDA allocation. // Views/reshapes never own the storage and must not free it: other tensors // sharing it (e.g. scratch-buffer views) may still be live and may even outlive the base tensor. ownsStorage bool } // NewTensor allocates memory on the GPU func NewTensor(shape tensor.Shape, dtype tensor.DType, gpu int) (*Tensor, error) { if dtype != tensor.Float32 && dtype != tensor.Float16 && dtype != tensor.BFloat16 { return nil, errors.New("unsupported dtype on CUDA") } if gpu < 0 { gpu = 0 } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := shape.NumElements() * dtype.Size() ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed") } storage := newStorage(ptr, gpu) t := &Tensor{ shape: shape, dtype: dtype, storage: storage, ptr: ptr, gpu: gpu, ownsStorage: true, } return t, nil } func (t *Tensor) Shape() tensor.Shape { return t.shape } func (t *Tensor) DType() tensor.DType { return t.dtype } func (t *Tensor) Device() tensor.DeviceType { return tensor.CUDA } // GPU returns the device ordinal. func (t *Tensor) GPU() int { return t.gpu } func (t *Tensor) Placement() tensor.DevicePlacement { return tensor.DevicePlacement{Type: tensor.CUDA, GPU: t.gpu} } func (t *Tensor) Data() interface{} { return t.ptr } // Free explicitly frees the GPU memory associated with the tensor. // Use this for temporary tensors to avoid OOM due to delayed GC. func (t *Tensor) Free() { if t == nil { return } // Only the allocating tensor should explicitly free the CUDA allocation. // Views/reshapes share storage and must not free it.
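// Clearing the Storage finalizer before freeing (below) prevents a double
// free when the Storage is later collected; because storage and ptr are
// nilled afterwards, calling Free more than once is a safe no-op.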
if t.storage != nil && t.ownsStorage { // Clear finalizer so it doesn't run later runtime.SetFinalizer(t.storage, nil) _ = C.cuda_set_device(C.int(t.gpu)) C.cuda_free(t.storage.ptr) } t.storage = nil t.ptr = nil } func (t *Tensor) Add(other tensor.Tensor) error { o, ok := other.(*Tensor) if !ok { return errors.New("other must be CUDA tensor") } if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 { return errors.New("Add only supports Float32") } if t.shape.NumElements() != o.shape.NumElements() { return errors.New("shape mismatch") } // Calls in-place add: t += o if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_add_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements())) if ret != 0 { return errors.New("cuda add failed") } return nil } func PagedAttentionBatch(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_batch_f32( (*C.float)(Q), (**C.float)(kBlocksFlatDev), (**C.float)(vBlocksFlatDev), (*C.int)(blockOffsetsDev), (*C.int)(kvLensDev), (*C.int)(queryPosDev), (*C.float)(out), C.int(numTokens), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.int(blockSize), C.float(scale), C.int(maxKvLen), ) if ret != 0 { return errors.New("cuda paged attention batch failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func (t *Tensor) Mul(other tensor.Tensor) error { o, ok := other.(*Tensor) if !ok { return errors.New("other must be CUDA tensor") } if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 { return errors.New("Mul only supports Float32") } if t.shape.NumElements() != o.shape.NumElements() { return errors.New("shape mismatch") } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_mul_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements())) if ret != 0 { return errors.New("cuda mul failed") } return nil } func (t *Tensor) MatMul(other tensor.Tensor, out tensor.Tensor) error { B, ok := other.(*Tensor) if !ok { return errors.New("other must be CUDA tensor") } C_out, ok := out.(*Tensor) if !ok { return errors.New("out must be CUDA tensor") } if t.dtype != tensor.Float32 || B.dtype != tensor.Float32 || C_out.dtype != tensor.Float32 { return errors.New("MatMul only supports Float32") } if len(t.shape) != 2 || len(B.shape) != 2 || len(C_out.shape) != 2 { return errors.New("only 2D matmul") } M := t.shape[0] K := t.shape[1] // We use NT matmul (A @ B^T), so B is expected to be [N, K] N := B.shape[0] K2 := B.shape[1] if K != K2 { return fmt.Errorf("k dimension mismatch: A[%d,%d] vs B[%d,%d]", M, K, N, K2) } if C_out.shape[0] != M || C_out.shape[1] != N { return fmt.Errorf("out shape mismatch: expected [%d,%d], got [%d,%d]", M, N, C_out.shape[0], C_out.shape[1]) } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f32_nt( (*C.float)(t.ptr), (*C.float)(B.ptr), (*C.float)(C_out.ptr), C.int(M), C.int(K), C.int(N), ) if ret != 0 { return errors.New("cuda matmul failed") } return nil } // Reshape creates a view (shared storage) with new shape. 
// The new tensor shares the same underlying Storage, so memory // is only freed when all tensors sharing this storage are GC'd. func (t *Tensor) Reshape(shape tensor.Shape) (tensor.Tensor, error) { if shape.NumElements() != t.shape.NumElements() { return nil, errors.New("num elements mismatch") } // Share the same storage - Go's GC handles ref counting for us return &Tensor{ shape: shape, dtype: t.dtype, storage: t.storage, // Shared reference ptr: t.ptr, gpu: t.gpu, ownsStorage: false, }, nil } // ViewAt returns a view into the tensor starting at the given byte offset. // The returned tensor shares storage and does not allocate. func (t *Tensor) ViewAt(shape tensor.Shape, offsetBytes uintptr) (*Tensor, error) { if t == nil { return nil, errors.New("nil tensor") } if offsetBytes%uintptr(t.dtype.Size()) != 0 { return nil, fmt.Errorf("offset %d not aligned to dtype size %d", offsetBytes, t.dtype.Size()) } newPtr := unsafe.Pointer(uintptr(t.ptr) + offsetBytes) return &Tensor{ shape: shape, dtype: t.dtype, storage: t.storage, ptr: newPtr, gpu: t.gpu, ownsStorage: false, }, nil } func (t *Tensor) View(shape tensor.Shape) (tensor.Tensor, error) { return t.Reshape(shape) } func (t *Tensor) ToDevice(device tensor.DeviceType) (tensor.Tensor, error) { if device == tensor.CUDA { return t, nil } // TODO: support CUDA -> CPU if device == tensor.CPU { // We need to copy data back // 1. Create CPU tensor // 2. Memcpy D2H // 3. Return CPU tensor // This requires importing "makarna/pkg/backend/cpu". Circular dependency risk? // No, `cpu` imports `tensor`, `cuda` imports `tensor`. // But `cuda` cannot import `cpu` easily if `cpu` is intended to be the default. // Actually it's fine if `cuda` imports `cpu`. return nil, errors.New("ToDevice(CPU) not implemented here yet, use helper") } return nil, errors.New("unknown device") } func (t *Tensor) CopyFrom(data interface{}) error { if t.dtype != tensor.Float32 { return errors.New("CopyFrom only supports Float32") } // Assuming data is []float32 on Host src, ok := data.([]float32) if !ok { return errors.New("data must be []float32") } size := len(src) * 4 if size != t.shape.NumElements()*t.dtype.Size() { return errors.New("size mismatch") } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } start := time.Now() ret := C.cuda_memcpy_h2d(t.ptr, unsafe.Pointer(&src[0]), C.size_t(size)) if ret != 0 { runtime.KeepAlive(src) runtime.KeepAlive(t) return errors.New("cuda memcpy failed") } profile.RecordTransfer("CopyFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu) runtime.KeepAlive(src) runtime.KeepAlive(t) return nil } // Helper to copy back to host func (t *Tensor) CopyToHost(dst []float32) error { if t.dtype != tensor.Float32 { return errors.New("CopyToHost only supports Float32") } size := len(dst) * 4 if size != t.shape.NumElements()*4 { return errors.New("size mismatch") } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } start := time.Now() ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size)) if ret != 0 { runtime.KeepAlive(dst) runtime.KeepAlive(t) return errors.New("cuda memcpy d2h failed") } profile.RecordTransfer("CopyToHost/D2H", profile.EventD2H, int64(size), time.Since(start), t.gpu) runtime.KeepAlive(dst) runtime.KeepAlive(t) return nil } func (t *Tensor) CopyToInt32(dst []int32) error { if t.dtype != tensor.Int32 { return errors.New("CopyToInt32 only supports Int32") } size := len(dst) * 4 if size != 
t.shape.NumElements()*4 { return errors.New("size mismatch") } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size)) if ret != 0 { return errors.New("cuda memcpy d2h failed") } runtime.KeepAlive(dst) runtime.KeepAlive(t) return nil } // CopyPartialFrom copies a portion of host data to the tensor at a given offset. // dstOffset: offset in float32 elements from the start of the tensor // src: source data to copy from host func (t *Tensor) CopyPartialFrom(dstOffset int, src []float32) error { if t.dtype != tensor.Float32 { return errors.New("CopyPartialFrom only supports Float32") } if dstOffset+len(src) > t.shape.NumElements() { return errors.New("partial copy would exceed tensor bounds") } if len(src) == 0 { return nil } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } // Calculate destination pointer with offset dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*4)) size := len(src) * 4 start := time.Now() ret := C.cuda_memcpy_h2d(dstPtr, unsafe.Pointer(&src[0]), C.size_t(size)) if ret != 0 { runtime.KeepAlive(src) runtime.KeepAlive(t) return errors.New("cuda memcpy partial failed") } profile.RecordTransfer("CopyPartialFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu) runtime.KeepAlive(src) runtime.KeepAlive(t) return nil } // CopyPartialFromDevice copies a portion from another CUDA tensor into this tensor. // Offsets and length are in float32 elements. func (t *Tensor) CopyPartialFromDevice(dstOffset int, src *Tensor, srcOffset int, length int) error { if t.dtype != src.dtype { return errors.New("dtype mismatch") } if dstOffset+length > t.shape.NumElements() { return errors.New("dst offset/length exceed tensor bounds") } if srcOffset+length > src.shape.NumElements() { return errors.New("src offset/length exceed tensor bounds") } if length == 0 { return nil } if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 { return errors.New("failed to set cuda device") } start := time.Now() eltSize := t.dtype.Size() dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*eltSize)) srcPtr := unsafe.Pointer(uintptr(src.ptr) + uintptr(srcOffset*eltSize)) size := C.size_t(length * eltSize) if ret := C.cuda_memcpy_d2d(dstPtr, srcPtr, size); ret != 0 { runtime.KeepAlive(src) runtime.KeepAlive(t) return errors.New("cuda memcpy d2d failed") } profile.RecordTransfer("CopyPartialFromDevice/D2D", profile.EventD2D, int64(length*eltSize), time.Since(start), t.gpu) runtime.KeepAlive(src) runtime.KeepAlive(t) return nil } func CastF32ToF16(srcF32, dstF16 unsafe.Pointer, n int, gpu int) error { if n <= 0 { return nil } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } if ret := C.cuda_cast_f32_to_f16((*C.float)(srcF32), (*C.ushort)(dstF16), C.int(n)); ret != 0 { return errors.New("cuda cast f32->f16 failed") } return nil } func PagedAttentionF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_f32_f16kv( (*C.float)(Q), (**C.ushort)(kBlocksDev), (**C.ushort)(vBlocksDev), (*C.float)(out), C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim), 
C.int(blockSize), C.float(scale), C.int(startPos), ) if ret != 0 { return errors.New("cuda paged attention f16kv failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func PagedAttentionBatchF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_batch_f32_f16kv( (*C.float)(Q), (**C.ushort)(kBlocksFlatDev), (**C.ushort)(vBlocksFlatDev), (*C.int)(blockOffsetsDev), (*C.int)(kvLensDev), (*C.int)(queryPosDev), (*C.float)(out), C.int(numTokens), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.int(blockSize), C.float(scale), C.int(maxKvLen), ) if ret != 0 { return errors.New("cuda paged attention batch f16kv failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // PagedAttentionRoPEF32F16KV runs paged attention with fused RoPE inside the kernel. // Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged. func PagedAttentionRoPEF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, theta float32, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_rope_f32_f16kv( (*C.float)(Q), (**C.ushort)(kBlocksDev), (**C.ushort)(vBlocksDev), (*C.float)(out), C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.int(blockSize), C.float(scale), C.int(startPos), C.float(theta), ) if ret != 0 { return errors.New("cuda paged attention rope f16kv failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // PagedAttentionBatchRoPEF32F16KV runs batched paged attention with fused RoPE inside the kernel. // Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged. func PagedAttentionBatchRoPEF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, theta float32, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_rope_batch_f32_f16kv( (*C.float)(Q), (**C.ushort)(kBlocksFlatDev), (**C.ushort)(vBlocksFlatDev), (*C.int)(blockOffsetsDev), (*C.int)(kvLensDev), (*C.int)(queryPosDev), (*C.float)(out), C.int(numTokens), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.int(blockSize), C.float(scale), C.int(maxKvLen), C.float(theta), ) if ret != 0 { return errors.New("cuda paged attention batch rope f16kv failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // Available returns whether CUDA is available func Available() bool { return true } // MemoryInfo returns (total, free) bytes for the current CUDA device. 
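//
// A hedged usage sketch (needBytes is illustrative, not part of this API):
// callers typically check free memory before a large allocation.
//
//	total, free, err := MemoryInfo()
//	if err == nil && free < needBytes {
//		// fall back to CPU or shrink the allocation
//	}
//	_ = total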
func MemoryInfo() (total uint64, free uint64, err error) { var cFree, cTotal C.size_t ret := C.cuda_mem_info(&cFree, &cTotal) if ret != 0 { return 0, 0, errors.New("cuda_mem_info failed") } return uint64(cTotal), uint64(cFree), nil } // MemoryInfoDevice returns (total, free) bytes for the given CUDA device. func MemoryInfoDevice(gpu int) (total uint64, free uint64, err error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return 0, 0, errors.New("failed to set cuda device") } var cFree, cTotal C.size_t ret := C.cuda_mem_info(&cFree, &cTotal) if ret != 0 { return 0, 0, errors.New("cuda_mem_info failed") } return uint64(cTotal), uint64(cFree), nil } // DeviceCount returns the number of visible CUDA devices. func DeviceCount() (int, error) { var cCount C.int ret := C.cuda_device_count(&cCount) if ret != 0 { return 0, errors.New("cuda_device_count failed") } if cCount < 0 { return 0, errors.New("cuda_device_count returned negative") } return int(cCount), nil } // Synchronize waits for all queued work on the given GPU. // Use when explicit host/device coordination is required. func Synchronize(gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } if ret := C.cuda_synchronize(); ret != 0 { return errors.New("cuda synchronize failed") } return nil } // ============================================================ // Neural Network Operations // ============================================================ // RMSNorm applies RMS normalization in-place on GPU // x: [seqLen, dim] device pointer, w: [dim] device pointer func RMSNorm(x, w unsafe.Pointer, seqLen, dim int, eps float32, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_rmsnorm_f32((*C.float)(x), (*C.float)(w), C.int(seqLen), C.int(dim), C.float(eps)) if ret != 0 { return errors.New("cuda rmsnorm failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // RoPE applies rotary positional embeddings in-place // x: [seqLen, numHeads * headDim] device pointer // positions: [seqLen] device pointer (int32) func RoPE(x, positions unsafe.Pointer, seqLen, numHeads, headDim int, theta float32, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_rope_f32((*C.float)(x), (*C.int)(positions), C.int(seqLen), C.int(numHeads), C.int(headDim), C.float(theta)) if ret != 0 { return errors.New("cuda rope failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // RoPESingle runs RoPE for a single token at a specific position. 
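// Layout note (inferred from the batched RoPE above with seqLen == 1, not
// separately verified against the kernel): x points at one token's
// activations, i.e. numHeads*headDim float32 values, and pos is that token's
// absolute position in the sequence.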
func RoPESingle(x unsafe.Pointer, pos, numHeads, headDim int, theta float32, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_rope_f32_single((*C.float)(x), C.int(pos), C.int(numHeads), C.int(headDim), C.float(theta)) if ret != 0 { return errors.New("cuda rope single failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // Softmax applies softmax along last dimension in-place func Softmax(x unsafe.Pointer, rows, cols int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_softmax_f32((*C.float)(x), C.int(rows), C.int(cols)) if ret != 0 { return errors.New("cuda softmax failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // SiLU applies SiLU activation in-place: x = x * sigmoid(x) func SiLU(x unsafe.Pointer, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_silu_f32((*C.float)(x), C.size_t(n)) if ret != 0 { return errors.New("cuda silu failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // MulInplace performs element-wise a = a * b func MulInplace(a, b unsafe.Pointer, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_mul_inplace_f32((*C.float)(a), (*C.float)(b), C.size_t(n)) if ret != 0 { return errors.New("cuda mul inplace failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // Copy copies GPU memory: dst = src func Copy(dst, src unsafe.Pointer, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_copy_f32((*C.float)(dst), (*C.float)(src), C.size_t(n)) if ret != 0 { return errors.New("cuda copy failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func KDACausalShortConv1D(x, state, w unsafe.Pointer, tokens, projSize, kernel int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_kda_causal_short_conv1d_f32( (*C.float)(x), (*C.float)(state), (*C.float)(w), C.int(tokens), C.int(projSize), C.int(kernel), ) if ret != 0 { return errors.New("cuda kda causal short conv1d failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func L2NormHeads(q, k unsafe.Pointer, tokens, numHeads, headDim int, eps float32, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_l2norm_heads_f32((*C.float)(q), (*C.float)(k), C.int(tokens), C.int(numHeads), C.int(headDim), C.float(eps)) if ret != 0 { return errors.New("cuda l2norm heads failed") } return nil } func KDAGate(g, aLog, dtBias, out unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_kda_gate_f32((*C.float)(g), (*C.float)(aLog), 
(*C.float)(dtBias), (*C.float)(out), C.int(tokens), C.int(numHeads), C.int(headDim)) if ret != 0 { return errors.New("cuda kda gate failed") } return nil } func KDARecurrent(q, k, v, g, beta, state unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_kda_recurrent_f32((*C.float)(q), (*C.float)(k), (*C.float)(v), (*C.float)(g), (*C.float)(beta), (*C.float)(state), C.int(tokens), C.int(numHeads), C.int(headDim)) if ret != 0 { return errors.New("cuda kda recurrent failed") } return nil } func RMSNormGated(out, g, weight unsafe.Pointer, n, headDim int, eps float32, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_rmsnorm_gated_f32((*C.float)(out), (*C.float)(g), (*C.float)(weight), C.int(n), C.int(headDim), C.float(eps)) if ret != 0 { return errors.New("cuda rmsnorm gated failed") } return nil } func Sigmoid(x unsafe.Pointer, n int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_sigmoid_f32((*C.float)(x), C.int(n)) if ret != 0 { return errors.New("cuda sigmoid failed") } return nil } func SoftmaxRows(x unsafe.Pointer, rows, cols int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_softmax_rows_f32((*C.float)(x), C.int(rows), C.int(cols)) if ret != 0 { return errors.New("cuda softmax rows failed") } return nil } func TopKPerRow(scores unsafe.Pointer, indices unsafe.Pointer, values unsafe.Pointer, rows, cols, k int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_topk_per_row_f32((*C.float)(scores), (*C.int)(indices), (*C.float)(values), C.int(rows), C.int(cols), C.int(k)) if ret != 0 { return errors.New("cuda topk per row failed") } return nil } // Attention computes full causal attention on GPU // Q: [seqLen, numHeads * headDim] // K, V: [kvLen, numKVHeads * headDim] // out: [seqLen, numHeads * headDim] func Attention(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_attention_f32( (*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out), C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.float(scale), C.int(startPos), ) if ret != 0 { return errors.New("cuda attention failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func PagedAttention(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_paged_attention_f32( (*C.float)(Q), (**C.float)(kBlocksDev), (**C.float)(vBlocksDev), (*C.float)(out), C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.int(blockSize), C.float(scale), C.int(startPos), ) if ret != 0 { return errors.New("cuda paged attention failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // 
AttentionTimed runs attention and returns kernel time in milliseconds. // Intended for profiling/debugging only (it synchronizes internally). func AttentionTimed(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) (float32, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return 0, errors.New("failed to set cuda device") } var ms C.float ret := C.cuda_attention_f32_timed( (*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out), C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim), C.float(scale), C.int(startPos), &ms, ) if ret != 0 { return 0, errors.New("cuda attention timed failed") } return float32(ms), nil } // AddInplace performs element-wise a = a + b func AddInplace(a, b unsafe.Pointer, n int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_add_f32((*C.float)(a), (*C.float)(b), C.size_t(n)) if ret != 0 { return errors.New("cuda add failed") } return nil } // ============================================================ // Dequantization Operations // ============================================================ // DequantQ8K dequantizes Q8_K blocks on GPU // blocks: device pointer to Q8_K data // out: device pointer to output float32 (numBlocks * 256 elements) func DequantQ8K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q8k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q8k failed") } return nil } func DequantQ4K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q4k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q4k failed") } return nil } func DequantQ5K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q5k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q5k failed") } return nil } // DequantQ6K dequantizes Q6_K blocks on GPU func DequantQ6K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q6k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q6k failed") } return nil } func DequantQ3K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q3k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q3k failed") } return nil } func DequantQ2K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_dequant_q2k(blocks, (*C.float)(out), C.int(numBlocks)) if ret != 0 { return errors.New("cuda dequant q2k failed") } return nil } // MatMulQ8K performs C = A @ dequant(B) where B is Q8_K quantized. // The output parameter is named Cptr so it does not shadow the cgo "C" package. func MatMulQ8K(A unsafe.Pointer, B unsafe.Pointer, Cptr
unsafe.Pointer, M, K, N, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f32_q8k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul q8k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulQ5K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f32_q5k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul q5k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulQ4K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f32_q4k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul q4k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulQ2K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } if k%256 != 0 { return fmt.Errorf("MatMulQ2K: K must be multiple of 256, got %d", k) } ret := C.cuda_matmul_f32_q2k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n)) if ret != 0 { return errors.New("cuda matmul q2k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulQ3K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } if k%256 != 0 { return fmt.Errorf("MatMulQ3K: K must be multiple of 256, got %d", k) } ret := C.cuda_matmul_f32_q3k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n)) if ret != 0 { return errors.New("cuda matmul q3k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulQ6K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } if k%256 != 0 { return fmt.Errorf("MatMulQ6K: K must be multiple of 256, got %d", k) } ret := C.cuda_matmul_f32_q6k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n)) if ret != 0 { return errors.New("cuda matmul q6k failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } func MatMulF32(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f32_nt( (*C.float)(A), (*C.float)(B), (*C.float)(Cptr), C.int(M), C.int(K), C.int(N), ) if ret != 0 { return errors.New("cuda matmul f32 failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // MatMulF16 performs C = A @ B^T
where A and B are float16 (stored as uint16), // and C is float32 output. func MatMulF16(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if err := syncIfProfiling(gpu); err != nil { return err } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_nt( (*C.ushort)(A), (*C.ushort)(B), (*C.float)(Cptr), C.int(M), C.int(K), C.int(N), ) if ret != 0 { return errors.New("cuda matmul f16 failed") } if err := syncIfProfiling(gpu); err != nil { return err } return nil } // FP16 Input MatMul variants - 2x memory bandwidth for activations // A is FP16, B is quantized, C is FP32 output func MatMulF16Q8K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q8k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q8k failed") } return nil } func MatMulF16Q4K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q4k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q4k failed") } return nil } func MatMulF16Q5K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q5k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q5k failed") } return nil } func MatMulF16Q2K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q2k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q2k failed") } return nil } func MatMulF16Q3K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q3k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q3k failed") } return nil } func MatMulF16Q6K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_matmul_f16_q6k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N)) if ret != 0 { return errors.New("cuda matmul f16 q6k failed") } return nil } // UploadQ8K uploads Q8_K blocks from host to GPU func UploadQ8K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q8K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q8K") } return ptr, nil } func AllocAndCopyPtrTable(ptrs []uintptr, gpu int) (unsafe.Pointer, error) { if len(ptrs) == 0 { return nil, errors.New("empty ptr table") } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(ptrs) * int(unsafe.Sizeof(uintptr(0))) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { 
return nil, errors.New("cuda malloc failed for ptr table") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&ptrs[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for ptr table") } return ptr, nil } func UploadQ5K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q5K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q5K") } return ptr, nil } // UploadQ4K uploads Q4_K blocks from host to GPU func UploadQ4K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q4K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q4K") } return ptr, nil } // UploadQ2K uploads Q2_K blocks from host to GPU func UploadQ2K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q2K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q2K") } return ptr, nil } // UploadQ3K uploads Q3_K blocks from host to GPU func UploadQ3K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q3K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q3K") } return ptr, nil } // UploadQ6K uploads Q6_K blocks from host to GPU func UploadQ6K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(hostData) ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for Q6K") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for Q6K") } return ptr, nil } // MemcpyH2D copies data from host to device pointer. // dst: device pointer // src: host data (unsafe.Pointer to first element) // size: number of bytes // gpu: device id (must be active or will be set) func MemcpyH2D(dst, src unsafe.Pointer, size uintptr, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_memcpy_h2d(dst, src, C.size_t(size)) if ret != 0 { return errors.New("cuda memcpy h2d failed") } return nil } // MemcpyD2H copies data from device pointer to host pointer. 
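//
// A hedged example of the intended call pattern, mirroring CopyToHost above
// (devPtr and n are illustrative names, not part of this API):
//
//	host := make([]float32, n)
//	err := MemcpyD2H(unsafe.Pointer(&host[0]), devPtr, uintptr(len(host)*4), gpu)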
func MemcpyD2H(dst, src unsafe.Pointer, size uintptr, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_memcpy_d2h(dst, src, C.size_t(size)) if ret != 0 { return errors.New("cuda memcpy d2h failed") } return nil } func MemcpyD2D(dst, src unsafe.Pointer, size uintptr, gpu int) error { if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return errors.New("failed to set cuda device") } ret := C.cuda_memcpy_d2d(dst, src, C.size_t(size)) if ret != 0 { return errors.New("cuda memcpy d2d failed") } return nil } // TopKLogitsF32 computes per-block top-k on GPU (with repetition penalty applied) // and returns the concatenated candidate list on host (caller does final global top-k). func TopKLogitsF32(logits unsafe.Pointer, vocab int, repIDs []int32, repPenalty float32, k int, gpu int) ([]int32, []float32, int, error) { if k <= 0 { return nil, nil, 0, nil } if k > 64 { return nil, nil, 0, fmt.Errorf("TopKLogitsF32: k too large: %d", k) } blocks := (vocab + 2048 - 1) / 2048 if blocks <= 0 { blocks = 1 } count := blocks * k // Select the device before any allocations so the buffers land on the requested GPU. if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, nil, 0, errors.New("failed to set cuda device") } var repPtr unsafe.Pointer if len(repIDs) > 0 { p, err := AllocAndCopyInt32(repIDs, gpu) if err != nil { return nil, nil, 0, err } repPtr = p defer FreeDevicePtr(repPtr) } // Device outputs outIDsPtr := C.cuda_malloc(C.size_t(count * 4)) if outIDsPtr == nil { return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outIDs") } defer C.cuda_free(outIDsPtr) outScoresPtr := C.cuda_malloc(C.size_t(count * 4)) if outScoresPtr == nil { return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outScores") } defer C.cuda_free(outScoresPtr) ret := C.cuda_topk_logits_f32( (*C.float)(logits), C.int(vocab), (*C.int)(repPtr), C.int(len(repIDs)), C.float(repPenalty), C.int(k), (*C.int)(outIDsPtr), (*C.float)(outScoresPtr), ) if ret != 0 { return nil, nil, 0, errors.New("cuda topk logits failed") } ids := make([]int32, count) scores := make([]float32, count) if err := MemcpyD2H(unsafe.Pointer(&ids[0]), unsafe.Pointer(outIDsPtr), uintptr(count*4), gpu); err != nil { return nil, nil, 0, err } if err := MemcpyD2H(unsafe.Pointer(&scores[0]), unsafe.Pointer(outScoresPtr), uintptr(count*4), gpu); err != nil { return nil, nil, 0, err } return ids, scores, blocks, nil } // FreeDevicePtr frees a device pointer func FreeDevicePtr(ptr unsafe.Pointer) { if ptr != nil { C.cuda_free(ptr) } } // Free is an alias for FreeDevicePtr for convenience func Free(ptr unsafe.Pointer) { FreeDevicePtr(ptr) } // AllocAndCopyInt32 allocates GPU memory and copies int32 data to it // Returns raw device pointer (caller must Free it) func AllocAndCopyInt32(data []int32, gpu int) (unsafe.Pointer, error) { if len(data) == 0 { return nil, errors.New("empty data") } if ret := C.cuda_set_device(C.int(gpu)); ret != 0 { return nil, errors.New("failed to set cuda device") } size := len(data) * 4 // 4 bytes per int32 ptr := C.cuda_malloc(C.size_t(size)) if ptr == nil { return nil, errors.New("cuda malloc failed for int32 data") } ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&data[0]), C.size_t(size)) if ret != 0 { C.cuda_free(ptr) return nil, errors.New("cuda memcpy h2d failed for int32 data") } return ptr, nil }
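// exampleMatMulF32 is a minimal usage sketch, not called anywhere in the
// package and kept unexported so it does not extend the public API. It shows
// the intended lifecycle for temporary tensors: allocate, upload, run the NT
// matmul documented on (*Tensor).MatMul (A is [M,K], B is [N,K], out is
// [M,N]), download, and Free explicitly rather than waiting on the GC. The
// shapes and data passed in are assumed to be consistent with each other.
func exampleMatMulF32(aShape, bShape, outShape tensor.Shape, aData, bData []float32, gpu int) ([]float32, error) {
	A, err := NewTensor(aShape, tensor.Float32, gpu)
	if err != nil {
		return nil, err
	}
	defer A.Free()

	B, err := NewTensor(bShape, tensor.Float32, gpu)
	if err != nil {
		return nil, err
	}
	defer B.Free()

	out, err := NewTensor(outShape, tensor.Float32, gpu)
	if err != nil {
		return nil, err
	}
	defer out.Free()

	// Host -> device uploads (lengths must match the shapes).
	if err := A.CopyFrom(aData); err != nil {
		return nil, err
	}
	if err := B.CopyFrom(bData); err != nil {
		return nil, err
	}

	// out = A @ B^T on the GPU.
	if err := A.MatMul(B, out); err != nil {
		return nil, err
	}

	// Device -> host download of the result.
	host := make([]float32, outShape.NumElements())
	if err := out.CopyToHost(host); err != nil {
		return nil, err
	}
	return host, nil
}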