| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- package compute
- import (
- "context"
- "makarna/pkg/backend/device"
- )
- type dispatcherContextKey struct{}
- func WithDispatcher(ctx context.Context, dispatcher *device.DeviceDispatcher) context.Context {
- if ctx == nil {
- ctx = context.Background()
- }
- return context.WithValue(ctx, dispatcherContextKey{}, dispatcher)
- }
- func DispatcherFromContext(ctx context.Context) *device.DeviceDispatcher {
- if ctx == nil {
- return nil
- }
- d, _ := ctx.Value(dispatcherContextKey{}).(*device.DeviceDispatcher)
- return d
- }
- type scratchContextKey struct{}
- func WithScratch(ctx context.Context, scratch *ScratchSpace) context.Context {
- if ctx == nil {
- ctx = context.Background()
- }
- return context.WithValue(ctx, scratchContextKey{}, scratch)
- }
- func ScratchFromContext(ctx context.Context) *ScratchSpace {
- if ctx == nil {
- return nil
- }
- s, _ := ctx.Value(scratchContextKey{}).(*ScratchSpace)
- return s
- }
- type scratchSetContextKey struct{}
- func WithScratchSet(ctx context.Context, scratch *ScratchSet) context.Context {
- if ctx == nil {
- ctx = context.Background()
- }
- return context.WithValue(ctx, scratchSetContextKey{}, scratch)
- }
- func ScratchSetFromContext(ctx context.Context) *ScratchSet {
- if ctx == nil {
- return nil
- }
- s, _ := ctx.Value(scratchSetContextKey{}).(*ScratchSet)
- return s
- }
- type cpuMoEContextKey struct{}
- // WithCPUMoE adds CPUMoE flag to context.
- // When true, MoE expert weights stay on CPU to save GPU memory.
- func WithCPUMoE(ctx context.Context, cpuMoE bool) context.Context {
- if ctx == nil {
- ctx = context.Background()
- }
- return context.WithValue(ctx, cpuMoEContextKey{}, cpuMoE)
- }
- // CPUMoEFromContext returns whether CPUMoE is enabled.
- func CPUMoEFromContext(ctx context.Context) bool {
- if ctx == nil {
- return false
- }
- v, _ := ctx.Value(cpuMoEContextKey{}).(bool)
- return v
- }
|