1
0

config_test.go 659 B

1234567891011121314151617181920212223242526
  1. package quant
  2. import (
  3. "testing"
  4. )
  5. func TestMatchPattern(t *testing.T) {
  6. tests := []struct {
  7. name string
  8. pattern string
  9. expected bool
  10. }{
  11. {"model.embed_tokens.weight", "*embed_tokens*", true},
  12. {"model.layers.0.self_attn.v_proj.weight", "*v_proj*", true},
  13. {"model.layers.0.mlp.down_proj.weight", "*down_proj*", true},
  14. {"model.norm.weight", "*norm*", true},
  15. {"model.layers.0.self_attn.q_proj.weight", "*embed_tokens*", false},
  16. }
  17. for _, tt := range tests {
  18. result := matchPattern(tt.name, tt.pattern)
  19. if result != tt.expected {
  20. t.Errorf("matchPattern(%q, %q) = %v, want %v", tt.name, tt.pattern, result, tt.expected)
  21. }
  22. }
  23. }