simd.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package cpu
  2. import "unsafe"
  3. // DotFloat32 computes the dot product between two float32 slices.
  4. // Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
  5. func DotFloat32(a, b []float32) float32 {
  6. if len(a) != len(b) {
  7. panic("DotFloat32: mismatched slice lengths")
  8. }
  9. if len(a) == 0 {
  10. return 0
  11. }
  12. if hasAVX512Kernel && SupportsAVX512() && len(a) >= 16 {
  13. return dotFloat32AVX512(a, b)
  14. }
  15. if hasAVX2Kernel && SupportsAVX2() && len(a) >= 8 {
  16. return dotFloat32AVX2(a, b)
  17. }
  18. return dotFloat32Scalar(a, b)
  19. }
  20. func DotFloat32Ptr(a, b *float32, n int) float32 {
  21. if n <= 0 {
  22. return 0
  23. }
  24. if a == nil || b == nil {
  25. panic("DotFloat32Ptr: nil pointer")
  26. }
  27. if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
  28. return dotAVX512(a, b, n)
  29. }
  30. if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
  31. return dotAVX2(a, b, n)
  32. }
  33. aa := unsafe.Slice(a, n)
  34. bb := unsafe.Slice(b, n)
  35. return dotFloat32Scalar(aa, bb)
  36. }
  37. func dotFloat32Scalar(a, b []float32) float32 {
  38. var sum float32
  39. for i := 0; i < len(a); i++ {
  40. sum += a[i] * b[i]
  41. }
  42. return sum
  43. }
  44. // Axpy performs y += alpha * x for float32 slices of equal length.
  45. // Intended for small vector adds in attention/value accumulation.
  46. func Axpy(alpha float32, x, y []float32) {
  47. if len(x) != len(y) {
  48. panic("Axpy: mismatched slice lengths")
  49. }
  50. if len(x) == 0 {
  51. return
  52. }
  53. if hasAVX512Kernel && SupportsAVX512() && len(x) >= 16 {
  54. axpyFloat32AVX512(alpha, x, y)
  55. return
  56. }
  57. if hasAVX2Kernel && SupportsAVX2() && len(x) >= 8 {
  58. axpyFloat32AVX2(alpha, x, y)
  59. return
  60. }
  61. for i := 0; i < len(x); i++ {
  62. y[i] += alpha * x[i]
  63. }
  64. }
  65. func AxpyPtr(alpha float32, x, y *float32, n int) {
  66. if n <= 0 {
  67. return
  68. }
  69. if x == nil || y == nil {
  70. panic("AxpyPtr: nil pointer")
  71. }
  72. if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
  73. axpyAVX512(alpha, x, y, n)
  74. return
  75. }
  76. if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
  77. axpyAVX2(alpha, x, y, n)
  78. return
  79. }
  80. xs := unsafe.Slice(x, n)
  81. ys := unsafe.Slice(y, n)
  82. for i := 0; i < n; i++ {
  83. ys[i] += alpha * xs[i]
  84. }
  85. }