dispatcher_context.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. package compute
  2. import (
  3. "context"
  4. "makarna/pkg/backend/device"
  5. )
  6. type dispatcherContextKey struct{}
  7. func WithDispatcher(ctx context.Context, dispatcher *device.DeviceDispatcher) context.Context {
  8. if ctx == nil {
  9. ctx = context.Background()
  10. }
  11. return context.WithValue(ctx, dispatcherContextKey{}, dispatcher)
  12. }
  13. func DispatcherFromContext(ctx context.Context) *device.DeviceDispatcher {
  14. if ctx == nil {
  15. return nil
  16. }
  17. d, _ := ctx.Value(dispatcherContextKey{}).(*device.DeviceDispatcher)
  18. return d
  19. }
  20. type scratchContextKey struct{}
  21. func WithScratch(ctx context.Context, scratch *ScratchSpace) context.Context {
  22. if ctx == nil {
  23. ctx = context.Background()
  24. }
  25. return context.WithValue(ctx, scratchContextKey{}, scratch)
  26. }
  27. func ScratchFromContext(ctx context.Context) *ScratchSpace {
  28. if ctx == nil {
  29. return nil
  30. }
  31. s, _ := ctx.Value(scratchContextKey{}).(*ScratchSpace)
  32. return s
  33. }
  34. type scratchSetContextKey struct{}
  35. func WithScratchSet(ctx context.Context, scratch *ScratchSet) context.Context {
  36. if ctx == nil {
  37. ctx = context.Background()
  38. }
  39. return context.WithValue(ctx, scratchSetContextKey{}, scratch)
  40. }
  41. func ScratchSetFromContext(ctx context.Context) *ScratchSet {
  42. if ctx == nil {
  43. return nil
  44. }
  45. s, _ := ctx.Value(scratchSetContextKey{}).(*ScratchSet)
  46. return s
  47. }
  48. type cpuMoEContextKey struct{}
  49. // WithCPUMoE adds CPUMoE flag to context.
  50. // When true, MoE expert weights stay on CPU to save GPU memory.
  51. func WithCPUMoE(ctx context.Context, cpuMoE bool) context.Context {
  52. if ctx == nil {
  53. ctx = context.Background()
  54. }
  55. return context.WithValue(ctx, cpuMoEContextKey{}, cpuMoE)
  56. }
  57. // CPUMoEFromContext returns whether CPUMoE is enabled.
  58. func CPUMoEFromContext(ctx context.Context) bool {
  59. if ctx == nil {
  60. return false
  61. }
  62. v, _ := ctx.Value(cpuMoEContextKey{}).(bool)
  63. return v
  64. }