| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- package convert
- import (
- "encoding/binary"
- "encoding/json"
- "testing"
- )
- func TestReadSafetensorsRaw_FP8(t *testing.T) {
- // Build a minimal safetensors buffer with dtype F8_E4M3.
- // Format: [u64 header_len][header_json][tensor_data]
- header := map[string]map[string]any{
- "t": {
- "dtype": "F8_E4M3",
- "shape": []uint64{4},
- "data_offsets": []uint64{0, 4},
- },
- }
- b, err := json.Marshal(header)
- if err != nil {
- t.Fatal(err)
- }
- buf := make([]byte, 8+len(b)+4)
- binary.LittleEndian.PutUint64(buf[0:8], uint64(len(b)))
- copy(buf[8:], b)
- // Four bytes of FP8 data (values are arbitrary; just ensure decode doesn't error)
- copy(buf[8+len(b):], []byte{0x00, 0x80, 0x7f, 0xff})
- tensors, err := readSafetensorsRaw(buf)
- if err != nil {
- t.Fatalf("readSafetensorsRaw failed: %v", err)
- }
- tt, ok := tensors["t"]
- if !ok {
- t.Fatalf("tensor not found")
- }
- if tt.DType != "F8_E4M3" {
- t.Fatalf("dtype mismatch: %q", tt.DType)
- }
- if len(tt.Data) != 4 {
- t.Fatalf("data len mismatch: %d", len(tt.Data))
- }
- floats, err := toFloat32(tt.Data, tt.DType)
- if err != nil {
- t.Fatalf("toFloat32 failed: %v", err)
- }
- if len(floats) != 4 {
- t.Fatalf("float len mismatch: %d", len(floats))
- }
- }
|