profiler.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. // Package profile provides performance profiling for the Makarna inference engine.
  2. // Tracks operation timings, CPU-GPU transfers, memory allocations, and more.
  3. //
  4. // Usage:
  5. // profile.Enable()
  6. // profile.Start("LinearQ8K")
  7. // // ... operation ...
  8. // profile.End("LinearQ8K")
  9. // profile.Report()
  10. package profile
  import (
	"fmt"
	"io"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)
  // EventType identifies the kind of activity an Event describes. The zero
// value is EventUnknown.
type EventType int

// Event kinds understood by the profiler.
const (
	EventUnknown EventType = iota // unclassified
	EventOp                       // neural network operation (Linear, RMSNorm, etc.)
	EventH2D                      // host to device transfer
	EventD2H                      // device to host transfer
	EventD2D                      // device to device transfer
	EventAlloc                    // memory allocation
	EventFree                     // memory free
	EventSync                     // device synchronization
	EventKernel                   // CUDA kernel launch
)

// eventTypeNames holds the short label for every defined EventType after
// EventUnknown; String handles everything outside that range.
var eventTypeNames = [...]string{
	EventOp:     "OP",
	EventH2D:    "H2D",
	EventD2H:    "D2H",
	EventD2D:    "D2D",
	EventAlloc:  "ALLOC",
	EventFree:   "FREE",
	EventSync:   "SYNC",
	EventKernel: "KERNEL",
}

// String returns the short label used in realtime output and reports, or
// "UNKNOWN" for EventUnknown and any value outside the defined range.
func (e EventType) String() string {
	if e <= EventUnknown || int(e) >= len(eventTypeNames) {
		return "UNKNOWN"
	}
	return eventTypeNames[e]
}
  56. // Event represents a single profiled event
  57. type Event struct {
  58. Name string
  59. Type EventType
  60. StartTime time.Time
  61. Duration time.Duration
  62. Bytes int64 // For transfers/allocations
  63. GPU int
  64. Extra string // Additional context
  65. }
  66. // Stats holds aggregated statistics for an operation
  67. type Stats struct {
  68. Name string
  69. Type EventType
  70. Count int64
  71. TotalTime time.Duration
  72. MinTime time.Duration
  73. MaxTime time.Duration
  74. TotalBytes int64
  75. }
  76. func (s *Stats) AvgTime() time.Duration {
  77. if s.Count == 0 {
  78. return 0
  79. }
  80. return s.TotalTime / time.Duration(s.Count)
  81. }
  82. // Profiler is the main profiling controller
  83. type Profiler struct {
  84. mu sync.Mutex
  85. enabled atomic.Bool
  86. output io.Writer
  87. logFile *os.File
  88. realtime bool // Print each event as it happens
  89. // Active spans for Start/End pattern
  90. activeSpans sync.Map // map[string]time.Time
  91. // Collected events
  92. events []Event
  93. // Aggregated stats by operation name
  94. stats map[string]*Stats
  95. // Per-token stats
  96. tokenEvents [][]Event
  97. tokenStart time.Time
  98. inToken bool
  99. curToken []Event
  100. tokenDurs []time.Duration
  101. // Counters
  102. totalH2DBytes atomic.Int64
  103. totalD2HBytes atomic.Int64
  104. totalH2DCount atomic.Int64
  105. totalD2HCount atomic.Int64
  106. }
  107. var globalProfiler = &Profiler{
  108. output: os.Stderr,
  109. stats: make(map[string]*Stats),
  110. }
  111. // Enable turns on profiling
  112. func Enable() {
  113. globalProfiler.enabled.Store(true)
  114. }
  115. // Disable turns off profiling
  116. func Disable() {
  117. globalProfiler.enabled.Store(false)
  118. }
  119. // Enabled returns whether profiling is active
  120. func Enabled() bool {
  121. return globalProfiler.enabled.Load()
  122. }
  123. // SetOutput sets the output writer for profile logs
  124. func SetOutput(w io.Writer) {
  125. globalProfiler.mu.Lock()
  126. globalProfiler.output = w
  127. globalProfiler.mu.Unlock()
  128. }
  129. // SetLogFile sets a file path for profile output. Pass "" to close.
  130. func SetLogFile(path string) error {
  131. globalProfiler.mu.Lock()
  132. defer globalProfiler.mu.Unlock()
  133. // Close existing
  134. if globalProfiler.logFile != nil {
  135. globalProfiler.logFile.Close()
  136. globalProfiler.logFile = nil
  137. }
  138. if path == "" {
  139. globalProfiler.output = os.Stderr
  140. return nil
  141. }
  142. f, err := os.Create(path)
  143. if err != nil {
  144. return err
  145. }
  146. globalProfiler.logFile = f
  147. globalProfiler.output = f
  148. return nil
  149. }
  150. // SetRealtime enables/disables printing each event as it happens
  151. func SetRealtime(enabled bool) {
  152. globalProfiler.mu.Lock()
  153. globalProfiler.realtime = enabled
  154. globalProfiler.mu.Unlock()
  155. }
  156. // Start begins timing an operation
  157. func Start(name string) {
  158. if !globalProfiler.enabled.Load() {
  159. return
  160. }
  161. globalProfiler.activeSpans.Store(name, time.Now())
  162. }
  163. // End finishes timing and records the event
  164. func End(name string) time.Duration {
  165. if !globalProfiler.enabled.Load() {
  166. return 0
  167. }
  168. now := time.Now()
  169. v, ok := globalProfiler.activeSpans.LoadAndDelete(name)
  170. if !ok {
  171. return 0
  172. }
  173. start := v.(time.Time)
  174. dur := now.Sub(start)
  175. e := Event{
  176. Name: name,
  177. Type: EventOp,
  178. StartTime: start,
  179. Duration: dur,
  180. }
  181. recordEvent(e)
  182. return dur
  183. }
  184. // StartOp starts a named operation span with type
  185. func StartOp(name string, etype EventType) {
  186. if !globalProfiler.enabled.Load() {
  187. return
  188. }
  189. key := fmt.Sprintf("%s:%d", name, etype)
  190. globalProfiler.activeSpans.Store(key, time.Now())
  191. }
  192. // EndOp ends a typed operation span
  193. func EndOp(name string, etype EventType) time.Duration {
  194. if !globalProfiler.enabled.Load() {
  195. return 0
  196. }
  197. now := time.Now()
  198. key := fmt.Sprintf("%s:%d", name, etype)
  199. v, ok := globalProfiler.activeSpans.LoadAndDelete(key)
  200. if !ok {
  201. return 0
  202. }
  203. start := v.(time.Time)
  204. dur := now.Sub(start)
  205. e := Event{
  206. Name: name,
  207. Type: etype,
  208. StartTime: start,
  209. Duration: dur,
  210. }
  211. recordEvent(e)
  212. return dur
  213. }
  214. // RecordTransfer records a memory transfer event
  215. func RecordTransfer(name string, etype EventType, bytes int64, duration time.Duration, gpu int) {
  216. if !globalProfiler.enabled.Load() {
  217. return
  218. }
  219. e := Event{
  220. Name: name,
  221. Type: etype,
  222. StartTime: time.Now().Add(-duration),
  223. Duration: duration,
  224. Bytes: bytes,
  225. GPU: gpu,
  226. }
  227. recordEvent(e)
  228. // Update counters
  229. switch etype {
  230. case EventH2D:
  231. globalProfiler.totalH2DBytes.Add(bytes)
  232. globalProfiler.totalH2DCount.Add(1)
  233. case EventD2H:
  234. globalProfiler.totalD2HBytes.Add(bytes)
  235. globalProfiler.totalD2HCount.Add(1)
  236. }
  237. }
  238. // RecordAlloc records a memory allocation
  239. func RecordAlloc(name string, bytes int64, duration time.Duration, gpu int) {
  240. if !globalProfiler.enabled.Load() {
  241. return
  242. }
  243. e := Event{
  244. Name: name,
  245. Type: EventAlloc,
  246. StartTime: time.Now().Add(-duration),
  247. Duration: duration,
  248. Bytes: bytes,
  249. GPU: gpu,
  250. }
  251. recordEvent(e)
  252. }
  253. // Record records a complete event
  254. func Record(e Event) {
  255. if !globalProfiler.enabled.Load() {
  256. return
  257. }
  258. recordEvent(e)
  259. }
  260. // Instant records a point-in-time event (for logging)
  261. func Instant(name string, etype EventType, extra string) {
  262. if !globalProfiler.enabled.Load() {
  263. return
  264. }
  265. e := Event{
  266. Name: name,
  267. Type: etype,
  268. StartTime: time.Now(),
  269. Extra: extra,
  270. }
  271. recordEvent(e)
  272. }
  273. // TokenStart marks the beginning of a new token generation
  274. func TokenStart() {
  275. if !globalProfiler.enabled.Load() {
  276. return
  277. }
  278. globalProfiler.mu.Lock()
  279. globalProfiler.tokenStart = time.Now()
  280. globalProfiler.inToken = true
  281. globalProfiler.curToken = globalProfiler.curToken[:0]
  282. globalProfiler.mu.Unlock()
  283. }
  284. // TokenEnd marks the end of a token generation
  285. func TokenEnd() {
  286. if !globalProfiler.enabled.Load() {
  287. return
  288. }
  289. globalProfiler.mu.Lock()
  290. if globalProfiler.inToken {
  291. globalProfiler.inToken = false
  292. dur := time.Since(globalProfiler.tokenStart)
  293. globalProfiler.tokenDurs = append(globalProfiler.tokenDurs, dur)
  294. globalProfiler.curToken = append(globalProfiler.curToken, Event{
  295. Name: "Token/Total",
  296. Type: EventOp,
  297. StartTime: globalProfiler.tokenStart,
  298. Duration: dur,
  299. })
  300. cp := make([]Event, len(globalProfiler.curToken))
  301. copy(cp, globalProfiler.curToken)
  302. globalProfiler.tokenEvents = append(globalProfiler.tokenEvents, cp)
  303. globalProfiler.curToken = globalProfiler.curToken[:0]
  304. }
  305. globalProfiler.mu.Unlock()
  306. }
  307. func recordEvent(e Event) {
  308. globalProfiler.mu.Lock()
  309. defer globalProfiler.mu.Unlock()
  310. // Store event
  311. globalProfiler.events = append(globalProfiler.events, e)
  312. if globalProfiler.inToken {
  313. globalProfiler.curToken = append(globalProfiler.curToken, e)
  314. }
  315. // Update stats
  316. key := fmt.Sprintf("%s:%d", e.Name, e.Type)
  317. s, ok := globalProfiler.stats[key]
  318. if !ok {
  319. s = &Stats{
  320. Name: e.Name,
  321. Type: e.Type,
  322. MinTime: e.Duration,
  323. MaxTime: e.Duration,
  324. }
  325. globalProfiler.stats[key] = s
  326. }
  327. s.Count++
  328. s.TotalTime += e.Duration
  329. s.TotalBytes += e.Bytes
  330. if e.Duration < s.MinTime {
  331. s.MinTime = e.Duration
  332. }
  333. if e.Duration > s.MaxTime {
  334. s.MaxTime = e.Duration
  335. }
  336. // Realtime output
  337. if globalProfiler.realtime && globalProfiler.output != nil {
  338. printEvent(globalProfiler.output, e)
  339. }
  340. }
  341. func printEvent(w io.Writer, e Event) {
  342. var bytesStr string
  343. if e.Bytes > 0 {
  344. bytesStr = fmt.Sprintf(" bytes=%s", formatBytes(e.Bytes))
  345. }
  346. var extraStr string
  347. if e.Extra != "" {
  348. extraStr = fmt.Sprintf(" (%s)", e.Extra)
  349. }
  350. fmt.Fprintf(w, "[PROFILE] %-6s %-30s %12s%s%s\n",
  351. e.Type.String(), e.Name, e.Duration, bytesStr, extraStr)
  352. }
  // formatBytes renders b using binary (1024-based) units with one decimal
// place, e.g. "1.5 KB". Values below 1024, including negatives, are printed
// as a plain integer count followed by " B".
func formatBytes(b int64) string {
	const step = 1024
	if b < step {
		return fmt.Sprintf("%d B", b)
	}
	const prefixes = "KMGTPE"
	scale, idx := int64(step), 0
	// Grow the divisor while at least one whole step remains above it. The
	// int64 range ends below 1024 EiB, so idx never runs past 'E'.
	for b/scale >= step {
		scale *= step
		idx++
	}
	return fmt.Sprintf("%.1f %cB", float64(b)/float64(scale), prefixes[idx])
}
  365. // Reset clears all collected data
  366. func Reset() {
  367. globalProfiler.mu.Lock()
  368. globalProfiler.events = nil
  369. globalProfiler.stats = make(map[string]*Stats)
  370. globalProfiler.tokenEvents = nil
  371. globalProfiler.curToken = nil
  372. globalProfiler.tokenDurs = nil
  373. globalProfiler.totalH2DBytes.Store(0)
  374. globalProfiler.totalD2HBytes.Store(0)
  375. globalProfiler.totalH2DCount.Store(0)
  376. globalProfiler.totalD2HCount.Store(0)
  377. globalProfiler.mu.Unlock()
  378. }
  379. // Report prints a summary report
  380. func Report() {
  381. globalProfiler.mu.Lock()
  382. defer globalProfiler.mu.Unlock()
  383. w := globalProfiler.output
  384. if w == nil {
  385. w = os.Stderr
  386. }
  387. fmt.Fprintln(w, "")
  388. fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗")
  389. fmt.Fprintln(w, "║ PERFORMANCE PROFILE REPORT ║")
  390. fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣")
  391. // Summary
  392. totalEvents := len(globalProfiler.events)
  393. h2dBytes := globalProfiler.totalH2DBytes.Load()
  394. d2hBytes := globalProfiler.totalD2HBytes.Load()
  395. h2dCount := globalProfiler.totalH2DCount.Load()
  396. d2hCount := globalProfiler.totalD2HCount.Load()
  397. fmt.Fprintf(w, "║ Total Events: %-64d║\n", totalEvents)
  398. fmt.Fprintf(w, "║ H2D Transfers: %d (total %s)%s║\n",
  399. h2dCount, formatBytes(h2dBytes), strings.Repeat(" ", 44-len(fmt.Sprintf("%d (total %s)", h2dCount, formatBytes(h2dBytes)))))
  400. fmt.Fprintf(w, "║ D2H Transfers: %d (total %s)%s║\n",
  401. d2hCount, formatBytes(d2hBytes), strings.Repeat(" ", 44-len(fmt.Sprintf("%d (total %s)", d2hCount, formatBytes(d2hBytes)))))
  402. fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣")
  403. // Per-operation stats sorted by total time
  404. var statsList []*Stats
  405. for _, s := range globalProfiler.stats {
  406. statsList = append(statsList, s)
  407. }
  408. sort.Slice(statsList, func(i, j int) bool {
  409. return statsList[i].TotalTime > statsList[j].TotalTime
  410. })
  411. fmt.Fprintln(w, "║ OPERATION BREAKDOWN ║")
  412. fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣")
  413. fmt.Fprintln(w, "║ Operation ║ Count ║ Total Time ║ Avg Time ║ Type ║")
  414. fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣")
  415. for _, s := range statsList {
  416. name := s.Name
  417. if len(name) > 29 {
  418. name = name[:26] + "..."
  419. }
  420. fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n",
  421. name, s.Count, s.TotalTime, s.AvgTime(), s.Type.String())
  422. }
  423. fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝")
  424. // Transfer hotspots
  425. if h2dCount > 0 || d2hCount > 0 {
  426. fmt.Fprintln(w, "")
  427. fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗")
  428. fmt.Fprintln(w, "║ TRANSFER HOTSPOTS ║")
  429. fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣")
  430. fmt.Fprintln(w, "║ Location ║ Count ║ Total Bytes ║ Total Time ║ Type ║")
  431. fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣")
  432. for _, s := range statsList {
  433. if s.Type != EventH2D && s.Type != EventD2H && s.Type != EventD2D {
  434. continue
  435. }
  436. name := s.Name
  437. if len(name) > 29 {
  438. name = name[:26] + "..."
  439. }
  440. fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n",
  441. name, s.Count, formatBytes(s.TotalBytes), s.TotalTime, s.Type.String())
  442. }
  443. fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝")
  444. }
  445. if len(globalProfiler.tokenDurs) > 0 {
  446. minTok := globalProfiler.tokenDurs[0]
  447. maxTok := globalProfiler.tokenDurs[0]
  448. var sumTok time.Duration
  449. for _, d := range globalProfiler.tokenDurs {
  450. sumTok += d
  451. if d < minTok {
  452. minTok = d
  453. }
  454. if d > maxTok {
  455. maxTok = d
  456. }
  457. }
  458. avgTok := sumTok / time.Duration(len(globalProfiler.tokenDurs))
  459. perTok := make(map[string]*Stats)
  460. for _, tokEvents := range globalProfiler.tokenEvents {
  461. for _, e := range tokEvents {
  462. if e.Name == "Token/Total" {
  463. continue
  464. }
  465. key := fmt.Sprintf("%s:%d", e.Name, e.Type)
  466. s, ok := perTok[key]
  467. if !ok {
  468. s = &Stats{Name: e.Name, Type: e.Type, MinTime: e.Duration, MaxTime: e.Duration}
  469. perTok[key] = s
  470. }
  471. s.Count++
  472. s.TotalTime += e.Duration
  473. s.TotalBytes += e.Bytes
  474. if e.Duration < s.MinTime {
  475. s.MinTime = e.Duration
  476. }
  477. if e.Duration > s.MaxTime {
  478. s.MaxTime = e.Duration
  479. }
  480. }
  481. }
  482. var tokStats []*Stats
  483. for _, s := range perTok {
  484. tokStats = append(tokStats, s)
  485. }
  486. sort.Slice(tokStats, func(i, j int) bool { return tokStats[i].TotalTime > tokStats[j].TotalTime })
  487. fmt.Fprintln(w, "")
  488. fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗")
  489. fmt.Fprintln(w, "║ TOKEN BREAKDOWN ║")
  490. fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣")
  491. fmt.Fprintf(w, "║ Tokens: %-69d║\n", len(globalProfiler.tokenDurs))
  492. fmt.Fprintf(w, "║ Token Time: avg=%-14s min=%-14s max=%-14s%24s║\n", avgTok, minTok, maxTok, "")
  493. fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣")
  494. fmt.Fprintln(w, "║ Op (in-token) ║ Count ║ Total Time ║ Avg Time ║ Type ║")
  495. fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣")
  496. limit := 30
  497. if len(tokStats) < limit {
  498. limit = len(tokStats)
  499. }
  500. for i := 0; i < limit; i++ {
  501. s := tokStats[i]
  502. name := s.Name
  503. if len(name) > 29 {
  504. name = name[:26] + "..."
  505. }
  506. fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n",
  507. name, s.Count, s.TotalTime, s.AvgTime(), s.Type.String())
  508. }
  509. fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝")
  510. }
  511. fmt.Fprintln(w, "")
  512. }
  513. // Events returns a copy of all collected events
  514. func Events() []Event {
  515. globalProfiler.mu.Lock()
  516. defer globalProfiler.mu.Unlock()
  517. result := make([]Event, len(globalProfiler.events))
  518. copy(result, globalProfiler.events)
  519. return result
  520. }
  521. // Stats returns a copy of aggregated stats
  522. func GetStats() map[string]*Stats {
  523. globalProfiler.mu.Lock()
  524. defer globalProfiler.mu.Unlock()
  525. result := make(map[string]*Stats)
  526. for k, v := range globalProfiler.stats {
  527. cp := *v
  528. result[k] = &cp
  529. }
  530. return result
  531. }
  532. // Close cleans up resources
  533. func Close() {
  534. globalProfiler.mu.Lock()
  535. defer globalProfiler.mu.Unlock()
  536. if globalProfiler.logFile != nil {
  537. globalProfiler.logFile.Close()
  538. globalProfiler.logFile = nil
  539. }
  540. }