@@ -278,15 +278,14 @@ class ModelBase:
                 # The scale is inverted
                 return data / scale.float()
 
-            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+            def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
                 scale = scale.float()
 
-                if (weight_block_size := quant_config.get("weight_block_size")):
-                    # TODO: make sure it's a list of integers
-                    for i, size in enumerate(weight_block_size):
+                if block_size is not None:
+                    for i, size in enumerate(block_size):
                         scale = scale.repeat_interleave(size, i)
-                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
 
                 return weight.float() * scale
 
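Note (annotation, not part of the patch): `dequant_simple` now takes the block size explicitly instead of reading `weight_block_size` from the enclosing `quant_config`. The scale is expanded from one value per block to one value per element with `repeat_interleave` along each dimension, then sliced back to the weight's shape in case the tensor size is not a multiple of the block size. A minimal standalone sketch of that expansion, with purely illustrative shapes:

    import torch

    # Illustrative shapes only: a (2, 3) grid of scales with 128x128 blocks
    # covers a weight of up to (256, 384); slicing trims the padding.
    weight = torch.randn(250, 300, dtype=torch.float32)   # not a multiple of 128
    scale = torch.rand(2, 3, dtype=torch.float32)
    block_size = (128, 128)

    for i, size in enumerate(block_size):
        scale = scale.repeat_interleave(size, i)                        # (256, 384)
    scale = scale[tuple(slice(0, size) for size in weight.shape)]       # (250, 300)

    dequantized = weight * scale
    assert dequantized.shape == weight.shape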
@@ -333,6 +332,40 @@ class ModelBase:
                 return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
 
 
+            def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+                assert w.dtype == torch.int32
+                shape = tuple(shape_tensor.tolist())
+                assert len(shape) == 2
+                mask = (1 << num_bits) - 1
+
+                shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+                if self.lazy:
+                    shifts = LazyTorchTensor.from_eager(shifts)
+
+                if zero_point is None:
+                    offset = 1 << (num_bits - 1)
+                else:
+                    assert len(zero_point.shape) == 2
+                    offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                    offset = offset.reshape(-1, zero_point.shape[1])
+                    # trim padding, and prepare for broadcast
+                    # NOTE: the zero-point is packed along dim 0
+                    offset = offset[:shape[0], :].unsqueeze(-1)
+
+                # extract values
+                # NOTE: the weights are packed along dim 1
+                unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+                unpacked = unpacked.reshape(shape[0], -1)
+
+                # trim padding
+                unpacked = unpacked[:, :shape[1]]
+
+                # prepare for broadcast of the scale
+                unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+                unpacked = unpacked - offset
+
+                return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
             if quant_method == "bitnet":
                 for name in self.model_tensors.keys():
                     if name.endswith(".weight_scale"):
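Note (annotation, not part of the patch): `dequant_packed` unpacks `32 // num_bits` values from each int32 by shifting and masking, trims the padding introduced by packing, groups the result along the last dimension, and subtracts either the unpacked zero-point or, when none is stored, the constant `1 << (num_bits - 1)` before applying the per-group scale. A small self-contained sketch of just the shift-and-mask step, assuming 4-bit values packed lowest-bits-first into a single int32:

    import torch

    num_bits = 4
    mask = (1 << num_bits) - 1
    shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)  # [0, 4, ..., 28]

    # Pack the values 0..7 into one int32 (the lowest bits hold the first value).
    values = torch.arange(8, dtype=torch.int32)
    packed = torch.zeros(1, dtype=torch.int32)
    for i, v in enumerate(values):
        packed |= v << (i * num_bits)

    # Unpack with the same shift-and-mask pattern used in dequant_packed.
    unpacked = (packed.unsqueeze(-1) >> shifts) & mask
    assert torch.equal(unpacked.flatten(), values)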
@@ -342,12 +375,13 @@ class ModelBase:
                         self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                         tensors_to_remove.append(name)
             elif quant_method == "fp8":
+                block_size = quant_config.get("weight_block_size")
                 for name in self.model_tensors.keys():
                     if name.endswith(".weight_scale_inv"):
                         weight_name = name.removesuffix("_scale_inv")
                         w = self.model_tensors[weight_name]
                         s = self.model_tensors[name]
-                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                         tensors_to_remove.append(name)
             elif quant_method == "gptq":
                 for name in self.model_tensors.keys():
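Note (annotation, not part of the patch): the lambdas bind `w`, `s`, and `bs` through default arguments because a plain closure would capture the loop variables by reference, and every deferred entry would end up dequantizing the last tensor seen in the loop. A minimal illustration of the difference:

    # Late binding: both closures see the final value of x.
    late = [lambda: x for x in (1, 2)]
    assert [f() for f in late] == [2, 2]

    # Default-argument binding (as used in the patch): each closure keeps its own x.
    bound = [lambda x=x: x for x in (1, 2)]
    assert [f() for f in bound] == [1, 2]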
@@ -371,6 +405,49 @@ class ModelBase:
                                 ".scales",
                             )
                         ]
+            elif quant_method == "compressed-tensors":
+                quant_format = quant_config["format"]
+                groups = quant_config["config_groups"]
+                if len(groups) > 1:
+                    raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
+                weight_config = tuple(groups.values())[0]["weights"]
+
+                if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+                    block_size = weight_config.get("block_structure", None)
+                    strategy = weight_config.get("strategy")
+                    assert strategy == "channel" or strategy == "block"
+                    assert weight_config.get("group_size") is None  # didn't find a model using this yet
+                    for name in self.model_tensors.keys():
+                        if name.endswith(".weight_scale"):
+                            weight_name = name.removesuffix("_scale")
+                            w = self.model_tensors[weight_name]
+                            s = self.model_tensors[name]
+                            self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+                            tensors_to_remove.append(name)
+                elif quant_format == "pack-quantized":
+                    assert weight_config.get("strategy") == "group"
+                    assert weight_config.get("type", "int") == "int"
+                    num_bits = weight_config.get("num_bits")
+                    group_size = weight_config.get("group_size")
+                    assert isinstance(num_bits, int)
+                    assert isinstance(group_size, int)
+                    for name in self.model_tensors.keys():
+                        if name.endswith(".weight_packed"):
+                            base_name = name.removesuffix("_packed")
+                            w = self.model_tensors[name]
+                            scale = self.model_tensors[base_name + "_scale"]
+                            shape = self.model_tensors[base_name + "_shape"]
+                            zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+                            new_tensors[base_name] = (
+                                lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+                                    w(), scale(), shape(), zero_point(), num_bits, group_size,
+                                )
+                            )
+                            tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+                            if (base_name + "_zero_point") in self.model_tensors:
+                                tensors_to_remove.append(base_name + "_zero_point")
+                else:
+                    raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
             else:
                 raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
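For reference, the keys consulted in the compressed-tensors branch come from the model's `quantization_config`. A hypothetical, abbreviated pack-quantized config with a single config group might look roughly like the dict below; only the keys the code above actually reads are shown, and the values are illustrative rather than taken from any specific model:

    quant_config = {
        "quant_method": "compressed-tensors",
        "format": "pack-quantized",
        "config_groups": {
            "group_0": {
                "weights": {
                    "type": "int",
                    "num_bits": 4,
                    "strategy": "group",
                    "group_size": 128,
                },
            },
        },
    }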