// Package profile provides performance profiling for the Makarna inference engine. // Tracks operation timings, CPU-GPU transfers, memory allocations, and more. // // Usage: // profile.Enable() // profile.Start("LinearQ8K") // // ... operation ... // profile.End("LinearQ8K") // profile.Report() package profile import ( "fmt" "io" "os" "sort" "strings" "sync" "sync/atomic" "time" ) // EventType categorizes profile events type EventType int const ( EventUnknown EventType = iota EventOp // Neural network operation (Linear, RMSNorm, etc.) EventH2D // Host to Device transfer EventD2H // Device to Host transfer EventD2D // Device to Device transfer EventAlloc // Memory allocation EventFree // Memory free EventSync // Device synchronization EventKernel // CUDA kernel launch ) func (e EventType) String() string { switch e { case EventOp: return "OP" case EventH2D: return "H2D" case EventD2H: return "D2H" case EventD2D: return "D2D" case EventAlloc: return "ALLOC" case EventFree: return "FREE" case EventSync: return "SYNC" case EventKernel: return "KERNEL" default: return "UNKNOWN" } } // Event represents a single profiled event type Event struct { Name string Type EventType StartTime time.Time Duration time.Duration Bytes int64 // For transfers/allocations GPU int Extra string // Additional context } // Stats holds aggregated statistics for an operation type Stats struct { Name string Type EventType Count int64 TotalTime time.Duration MinTime time.Duration MaxTime time.Duration TotalBytes int64 } func (s *Stats) AvgTime() time.Duration { if s.Count == 0 { return 0 } return s.TotalTime / time.Duration(s.Count) } // Profiler is the main profiling controller type Profiler struct { mu sync.Mutex enabled atomic.Bool output io.Writer logFile *os.File realtime bool // Print each event as it happens // Active spans for Start/End pattern activeSpans sync.Map // map[string]time.Time // Collected events events []Event // Aggregated stats by operation name stats map[string]*Stats // Per-token stats tokenEvents [][]Event tokenStart time.Time inToken bool curToken []Event tokenDurs []time.Duration // Counters totalH2DBytes atomic.Int64 totalD2HBytes atomic.Int64 totalH2DCount atomic.Int64 totalD2HCount atomic.Int64 } var globalProfiler = &Profiler{ output: os.Stderr, stats: make(map[string]*Stats), } // Enable turns on profiling func Enable() { globalProfiler.enabled.Store(true) } // Disable turns off profiling func Disable() { globalProfiler.enabled.Store(false) } // Enabled returns whether profiling is active func Enabled() bool { return globalProfiler.enabled.Load() } // SetOutput sets the output writer for profile logs func SetOutput(w io.Writer) { globalProfiler.mu.Lock() globalProfiler.output = w globalProfiler.mu.Unlock() } // SetLogFile sets a file path for profile output. Pass "" to close. func SetLogFile(path string) error { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() // Close existing if globalProfiler.logFile != nil { globalProfiler.logFile.Close() globalProfiler.logFile = nil } if path == "" { globalProfiler.output = os.Stderr return nil } f, err := os.Create(path) if err != nil { return err } globalProfiler.logFile = f globalProfiler.output = f return nil } // SetRealtime enables/disables printing each event as it happens func SetRealtime(enabled bool) { globalProfiler.mu.Lock() globalProfiler.realtime = enabled globalProfiler.mu.Unlock() } // Start begins timing an operation func Start(name string) { if !globalProfiler.enabled.Load() { return } globalProfiler.activeSpans.Store(name, time.Now()) } // End finishes timing and records the event func End(name string) time.Duration { if !globalProfiler.enabled.Load() { return 0 } now := time.Now() v, ok := globalProfiler.activeSpans.LoadAndDelete(name) if !ok { return 0 } start := v.(time.Time) dur := now.Sub(start) e := Event{ Name: name, Type: EventOp, StartTime: start, Duration: dur, } recordEvent(e) return dur } // StartOp starts a named operation span with type func StartOp(name string, etype EventType) { if !globalProfiler.enabled.Load() { return } key := fmt.Sprintf("%s:%d", name, etype) globalProfiler.activeSpans.Store(key, time.Now()) } // EndOp ends a typed operation span func EndOp(name string, etype EventType) time.Duration { if !globalProfiler.enabled.Load() { return 0 } now := time.Now() key := fmt.Sprintf("%s:%d", name, etype) v, ok := globalProfiler.activeSpans.LoadAndDelete(key) if !ok { return 0 } start := v.(time.Time) dur := now.Sub(start) e := Event{ Name: name, Type: etype, StartTime: start, Duration: dur, } recordEvent(e) return dur } // RecordTransfer records a memory transfer event func RecordTransfer(name string, etype EventType, bytes int64, duration time.Duration, gpu int) { if !globalProfiler.enabled.Load() { return } e := Event{ Name: name, Type: etype, StartTime: time.Now().Add(-duration), Duration: duration, Bytes: bytes, GPU: gpu, } recordEvent(e) // Update counters switch etype { case EventH2D: globalProfiler.totalH2DBytes.Add(bytes) globalProfiler.totalH2DCount.Add(1) case EventD2H: globalProfiler.totalD2HBytes.Add(bytes) globalProfiler.totalD2HCount.Add(1) } } // RecordAlloc records a memory allocation func RecordAlloc(name string, bytes int64, duration time.Duration, gpu int) { if !globalProfiler.enabled.Load() { return } e := Event{ Name: name, Type: EventAlloc, StartTime: time.Now().Add(-duration), Duration: duration, Bytes: bytes, GPU: gpu, } recordEvent(e) } // Record records a complete event func Record(e Event) { if !globalProfiler.enabled.Load() { return } recordEvent(e) } // Instant records a point-in-time event (for logging) func Instant(name string, etype EventType, extra string) { if !globalProfiler.enabled.Load() { return } e := Event{ Name: name, Type: etype, StartTime: time.Now(), Extra: extra, } recordEvent(e) } // TokenStart marks the beginning of a new token generation func TokenStart() { if !globalProfiler.enabled.Load() { return } globalProfiler.mu.Lock() globalProfiler.tokenStart = time.Now() globalProfiler.inToken = true globalProfiler.curToken = globalProfiler.curToken[:0] globalProfiler.mu.Unlock() } // TokenEnd marks the end of a token generation func TokenEnd() { if !globalProfiler.enabled.Load() { return } globalProfiler.mu.Lock() if globalProfiler.inToken { globalProfiler.inToken = false dur := time.Since(globalProfiler.tokenStart) globalProfiler.tokenDurs = append(globalProfiler.tokenDurs, dur) globalProfiler.curToken = append(globalProfiler.curToken, Event{ Name: "Token/Total", Type: EventOp, StartTime: globalProfiler.tokenStart, Duration: dur, }) cp := make([]Event, len(globalProfiler.curToken)) copy(cp, globalProfiler.curToken) globalProfiler.tokenEvents = append(globalProfiler.tokenEvents, cp) globalProfiler.curToken = globalProfiler.curToken[:0] } globalProfiler.mu.Unlock() } func recordEvent(e Event) { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() // Store event globalProfiler.events = append(globalProfiler.events, e) if globalProfiler.inToken { globalProfiler.curToken = append(globalProfiler.curToken, e) } // Update stats key := fmt.Sprintf("%s:%d", e.Name, e.Type) s, ok := globalProfiler.stats[key] if !ok { s = &Stats{ Name: e.Name, Type: e.Type, MinTime: e.Duration, MaxTime: e.Duration, } globalProfiler.stats[key] = s } s.Count++ s.TotalTime += e.Duration s.TotalBytes += e.Bytes if e.Duration < s.MinTime { s.MinTime = e.Duration } if e.Duration > s.MaxTime { s.MaxTime = e.Duration } // Realtime output if globalProfiler.realtime && globalProfiler.output != nil { printEvent(globalProfiler.output, e) } } func printEvent(w io.Writer, e Event) { var bytesStr string if e.Bytes > 0 { bytesStr = fmt.Sprintf(" bytes=%s", formatBytes(e.Bytes)) } var extraStr string if e.Extra != "" { extraStr = fmt.Sprintf(" (%s)", e.Extra) } fmt.Fprintf(w, "[PROFILE] %-6s %-30s %12s%s%s\n", e.Type.String(), e.Name, e.Duration, bytesStr, extraStr) } func formatBytes(b int64) string { const unit = 1024 if b < unit { return fmt.Sprintf("%d B", b) } div, exp := int64(unit), 0 for n := b / unit; n >= unit; n /= unit { div *= unit exp++ } return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp]) } // Reset clears all collected data func Reset() { globalProfiler.mu.Lock() globalProfiler.events = nil globalProfiler.stats = make(map[string]*Stats) globalProfiler.tokenEvents = nil globalProfiler.curToken = nil globalProfiler.tokenDurs = nil globalProfiler.totalH2DBytes.Store(0) globalProfiler.totalD2HBytes.Store(0) globalProfiler.totalH2DCount.Store(0) globalProfiler.totalD2HCount.Store(0) globalProfiler.mu.Unlock() } // Report prints a summary report func Report() { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() w := globalProfiler.output if w == nil { w = os.Stderr } fmt.Fprintln(w, "") fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗") fmt.Fprintln(w, "║ PERFORMANCE PROFILE REPORT ║") fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣") // Summary totalEvents := len(globalProfiler.events) h2dBytes := globalProfiler.totalH2DBytes.Load() d2hBytes := globalProfiler.totalD2HBytes.Load() h2dCount := globalProfiler.totalH2DCount.Load() d2hCount := globalProfiler.totalD2HCount.Load() fmt.Fprintf(w, "║ Total Events: %-64d║\n", totalEvents) fmt.Fprintf(w, "║ H2D Transfers: %d (total %s)%s║\n", h2dCount, formatBytes(h2dBytes), strings.Repeat(" ", 44-len(fmt.Sprintf("%d (total %s)", h2dCount, formatBytes(h2dBytes))))) fmt.Fprintf(w, "║ D2H Transfers: %d (total %s)%s║\n", d2hCount, formatBytes(d2hBytes), strings.Repeat(" ", 44-len(fmt.Sprintf("%d (total %s)", d2hCount, formatBytes(d2hBytes))))) fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣") // Per-operation stats sorted by total time var statsList []*Stats for _, s := range globalProfiler.stats { statsList = append(statsList, s) } sort.Slice(statsList, func(i, j int) bool { return statsList[i].TotalTime > statsList[j].TotalTime }) fmt.Fprintln(w, "║ OPERATION BREAKDOWN ║") fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣") fmt.Fprintln(w, "║ Operation ║ Count ║ Total Time ║ Avg Time ║ Type ║") fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣") for _, s := range statsList { name := s.Name if len(name) > 29 { name = name[:26] + "..." } fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n", name, s.Count, s.TotalTime, s.AvgTime(), s.Type.String()) } fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝") // Transfer hotspots if h2dCount > 0 || d2hCount > 0 { fmt.Fprintln(w, "") fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗") fmt.Fprintln(w, "║ TRANSFER HOTSPOTS ║") fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣") fmt.Fprintln(w, "║ Location ║ Count ║ Total Bytes ║ Total Time ║ Type ║") fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣") for _, s := range statsList { if s.Type != EventH2D && s.Type != EventD2H && s.Type != EventD2D { continue } name := s.Name if len(name) > 29 { name = name[:26] + "..." } fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n", name, s.Count, formatBytes(s.TotalBytes), s.TotalTime, s.Type.String()) } fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝") } if len(globalProfiler.tokenDurs) > 0 { minTok := globalProfiler.tokenDurs[0] maxTok := globalProfiler.tokenDurs[0] var sumTok time.Duration for _, d := range globalProfiler.tokenDurs { sumTok += d if d < minTok { minTok = d } if d > maxTok { maxTok = d } } avgTok := sumTok / time.Duration(len(globalProfiler.tokenDurs)) perTok := make(map[string]*Stats) for _, tokEvents := range globalProfiler.tokenEvents { for _, e := range tokEvents { if e.Name == "Token/Total" { continue } key := fmt.Sprintf("%s:%d", e.Name, e.Type) s, ok := perTok[key] if !ok { s = &Stats{Name: e.Name, Type: e.Type, MinTime: e.Duration, MaxTime: e.Duration} perTok[key] = s } s.Count++ s.TotalTime += e.Duration s.TotalBytes += e.Bytes if e.Duration < s.MinTime { s.MinTime = e.Duration } if e.Duration > s.MaxTime { s.MaxTime = e.Duration } } } var tokStats []*Stats for _, s := range perTok { tokStats = append(tokStats, s) } sort.Slice(tokStats, func(i, j int) bool { return tokStats[i].TotalTime > tokStats[j].TotalTime }) fmt.Fprintln(w, "") fmt.Fprintln(w, "╔══════════════════════════════════════════════════════════════════════════════╗") fmt.Fprintln(w, "║ TOKEN BREAKDOWN ║") fmt.Fprintln(w, "╠══════════════════════════════════════════════════════════════════════════════╣") fmt.Fprintf(w, "║ Tokens: %-69d║\n", len(globalProfiler.tokenDurs)) fmt.Fprintf(w, "║ Token Time: avg=%-14s min=%-14s max=%-14s%24s║\n", avgTok, minTok, maxTok, "") fmt.Fprintln(w, "╠═══════════════════════════════╦════════╦══════════════╦════════════╦═════════╣") fmt.Fprintln(w, "║ Op (in-token) ║ Count ║ Total Time ║ Avg Time ║ Type ║") fmt.Fprintln(w, "╠═══════════════════════════════╬════════╬══════════════╬════════════╬═════════╣") limit := 30 if len(tokStats) < limit { limit = len(tokStats) } for i := 0; i < limit; i++ { s := tokStats[i] name := s.Name if len(name) > 29 { name = name[:26] + "..." } fmt.Fprintf(w, "║ %-29s ║ %6d ║ %12s ║ %10s ║ %-7s ║\n", name, s.Count, s.TotalTime, s.AvgTime(), s.Type.String()) } fmt.Fprintln(w, "╚═══════════════════════════════╩════════╩══════════════╩════════════╩═════════╝") } fmt.Fprintln(w, "") } // Events returns a copy of all collected events func Events() []Event { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() result := make([]Event, len(globalProfiler.events)) copy(result, globalProfiler.events) return result } // Stats returns a copy of aggregated stats func GetStats() map[string]*Stats { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() result := make(map[string]*Stats) for k, v := range globalProfiler.stats { cp := *v result[k] = &cp } return result } // Close cleans up resources func Close() { globalProfiler.mu.Lock() defer globalProfiler.mu.Unlock() if globalProfiler.logFile != nil { globalProfiler.logFile.Close() globalProfiler.logFile = nil } }