| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- package cpu
- import "unsafe"
- // DotFloat32 computes the dot product between two float32 slices.
- // Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
- func DotFloat32(a, b []float32) float32 {
- if len(a) != len(b) {
- panic("DotFloat32: mismatched slice lengths")
- }
- if len(a) == 0 {
- return 0
- }
- if hasAVX512Kernel && SupportsAVX512() && len(a) >= 16 {
- return dotFloat32AVX512(a, b)
- }
- if hasAVX2Kernel && SupportsAVX2() && len(a) >= 8 {
- return dotFloat32AVX2(a, b)
- }
- return dotFloat32Scalar(a, b)
- }
- func DotFloat32Ptr(a, b *float32, n int) float32 {
- if n <= 0 {
- return 0
- }
- if a == nil || b == nil {
- panic("DotFloat32Ptr: nil pointer")
- }
- if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
- return dotAVX512(a, b, n)
- }
- if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
- return dotAVX2(a, b, n)
- }
- aa := unsafe.Slice(a, n)
- bb := unsafe.Slice(b, n)
- return dotFloat32Scalar(aa, bb)
- }
// dotFloat32Scalar is the portable fallback: it accumulates a[i]*b[i] in
// index order over the length of a and returns the float32 total. b is
// indexed with the same positions as a, so it must be at least as long as a;
// extra elements of b are ignored.
func dotFloat32Scalar(a, b []float32) float32 {
	var acc float32
	for i, av := range a {
		acc += av * b[i]
	}
	return acc
}
- // Axpy performs y += alpha * x for float32 slices of equal length.
- // Intended for small vector adds in attention/value accumulation.
- func Axpy(alpha float32, x, y []float32) {
- if len(x) != len(y) {
- panic("Axpy: mismatched slice lengths")
- }
- if len(x) == 0 {
- return
- }
- if hasAVX512Kernel && SupportsAVX512() && len(x) >= 16 {
- axpyFloat32AVX512(alpha, x, y)
- return
- }
- if hasAVX2Kernel && SupportsAVX2() && len(x) >= 8 {
- axpyFloat32AVX2(alpha, x, y)
- return
- }
- for i := 0; i < len(x); i++ {
- y[i] += alpha * x[i]
- }
- }
- func AxpyPtr(alpha float32, x, y *float32, n int) {
- if n <= 0 {
- return
- }
- if x == nil || y == nil {
- panic("AxpyPtr: nil pointer")
- }
- if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
- axpyAVX512(alpha, x, y, n)
- return
- }
- if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
- axpyAVX2(alpha, x, y, n)
- return
- }
- xs := unsafe.Slice(x, n)
- ys := unsafe.Slice(y, n)
- for i := 0; i < n; i++ {
- ys[i] += alpha * xs[i]
- }
- }
|