|
@@ -49,6 +49,7 @@ class TensorInfo:
|
|
|
class GGUFValue:
|
|
class GGUFValue:
|
|
|
value: Any
|
|
value: Any
|
|
|
type: GGUFValueType
|
|
type: GGUFValueType
|
|
|
|
|
+ sub_type: GGUFValueType | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
class WriterState(Enum):
|
|
class WriterState(Enum):
|
|
@@ -238,7 +239,7 @@ class GGUFWriter:
|
|
|
|
|
|
|
|
for key, val in kv_data.items():
|
|
for key, val in kv_data.items():
|
|
|
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
|
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
|
|
- kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
|
|
|
|
|
|
+ kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
|
|
|
|
|
|
|
|
fout.write(kv_bytes)
|
|
fout.write(kv_bytes)
|
|
|
|
|
|
|
@@ -268,11 +269,11 @@ class GGUFWriter:
|
|
|
fout.flush()
|
|
fout.flush()
|
|
|
self.state = WriterState.TI_DATA
|
|
self.state = WriterState.TI_DATA
|
|
|
|
|
|
|
|
- def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
|
|
|
|
|
|
+ def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
|
|
|
if any(key in kv_data for kv_data in self.kv_data):
|
|
if any(key in kv_data for kv_data in self.kv_data):
|
|
|
raise ValueError(f'Duplicated key name {key!r}')
|
|
raise ValueError(f'Duplicated key name {key!r}')
|
|
|
|
|
|
|
|
- self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
|
|
|
|
|
|
+ self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
|
|
|
|
|
|
|
|
def add_uint8(self, key: str, val: int) -> None:
|
|
def add_uint8(self, key: str, val: int) -> None:
|
|
|
self.add_key_value(key,val, GGUFValueType.UINT8)
|
|
self.add_key_value(key,val, GGUFValueType.UINT8)
|
|
@@ -1022,7 +1023,7 @@ class GGUFWriter:
|
|
|
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
|
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
|
|
return struct.pack(f'{pack_prefix}{fmt}', value)
|
|
return struct.pack(f'{pack_prefix}{fmt}', value)
|
|
|
|
|
|
|
|
- def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
|
|
|
|
|
|
|
+ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
|
|
|
kv_data = bytearray()
|
|
kv_data = bytearray()
|
|
|
|
|
|
|
|
if add_vtype:
|
|
if add_vtype:
|
|
@@ -1043,7 +1044,9 @@ class GGUFWriter:
|
|
|
if len(val) == 0:
|
|
if len(val) == 0:
|
|
|
raise ValueError("Invalid GGUF metadata array. Empty array")
|
|
raise ValueError("Invalid GGUF metadata array. Empty array")
|
|
|
|
|
|
|
|
- if isinstance(val, bytes):
|
|
|
|
|
|
|
+ if sub_type is not None:
|
|
|
|
|
+ ltype = sub_type
|
|
|
|
|
+ elif isinstance(val, bytes):
|
|
|
ltype = GGUFValueType.UINT8
|
|
ltype = GGUFValueType.UINT8
|
|
|
else:
|
|
else:
|
|
|
ltype = GGUFValueType.get_type(val[0])
|
|
ltype = GGUFValueType.get_type(val[0])
|