fp8_test.go 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. package convert
  2. import (
  3. "encoding/binary"
  4. "encoding/json"
  5. "testing"
  6. )
  7. func TestReadSafetensorsRaw_FP8(t *testing.T) {
  8. // Build a minimal safetensors buffer with dtype F8_E4M3.
  9. // Format: [u64 header_len][header_json][tensor_data]
  10. header := map[string]map[string]any{
  11. "t": {
  12. "dtype": "F8_E4M3",
  13. "shape": []uint64{4},
  14. "data_offsets": []uint64{0, 4},
  15. },
  16. }
  17. b, err := json.Marshal(header)
  18. if err != nil {
  19. t.Fatal(err)
  20. }
  21. buf := make([]byte, 8+len(b)+4)
  22. binary.LittleEndian.PutUint64(buf[0:8], uint64(len(b)))
  23. copy(buf[8:], b)
  24. // Four bytes of FP8 data (values are arbitrary; just ensure decode doesn't error)
  25. copy(buf[8+len(b):], []byte{0x00, 0x80, 0x7f, 0xff})
  26. tensors, err := readSafetensorsRaw(buf)
  27. if err != nil {
  28. t.Fatalf("readSafetensorsRaw failed: %v", err)
  29. }
  30. tt, ok := tensors["t"]
  31. if !ok {
  32. t.Fatalf("tensor not found")
  33. }
  34. if tt.DType != "F8_E4M3" {
  35. t.Fatalf("dtype mismatch: %q", tt.DType)
  36. }
  37. if len(tt.Data) != 4 {
  38. t.Fatalf("data len mismatch: %d", len(tt.Data))
  39. }
  40. floats, err := toFloat32(tt.Data, tt.DType)
  41. if err != nil {
  42. t.Fatalf("toFloat32 failed: %v", err)
  43. }
  44. if len(floats) != 4 {
  45. t.Fatalf("float len mismatch: %d", len(floats))
  46. }
  47. }