[2-minute read]
Dequantizing embeddings
Dequantization
This notebook demonstrates how to dequantize various quantized representations (int8, uint8, binary, ubinary) of embeddings back to float values using PyTorch and NumPy.
Installation
Install the required packages:
!pip install torch numpy
Define the dequantization functions
import torch
import numpy as np
def _unpack_sign_bits(x_packed: torch.Tensor, signed: bool, C: int) -> torch.Tensor:
"""
Unpacks a bit-packed tensor into a ±1 float32 tensor.
Parameters:
----------
x_packed : torch.Tensor
A tensor containing packed bit values (usually uint8) representing the sign bits.
signed : bool
Indicates if the original data was signed ('binary') or unsigned ('ubinary').
C : int
The number of original channels or dimensions to unpack (used to trim padding bits).
Returns:
-------
torch.Tensor
A float32 tensor on the same device as x_packed, with values ±1 and shape (..., C).
"""
arr = (x_packed.cpu().numpy() + (128 if signed else 0)).astype(np.uint8)
bits = np.unpackbits(arr, axis=-1)[..., :C] # remove pad bits
return torch.from_numpy(bits * 2 - 1).float().to(x_packed.device)
def dequantize(
q: torch.Tensor | list, quant: str, orig_dim: int | None = None
) -> torch.Tensor:
"""
Dequantizes a quantized tensor or list back to float32 values in [-1, 1].
Parameters:
----------
q : torch.Tensor or list
The quantized data, either as a tensor or list.
quant : str
The quantization type. Supported values:
- 'fp8': already float, just converted.
- 'int8': scaled by 127.
- 'uint8': scaled and shifted to [-1,1].
- 'binary' or 'ubinary': unpacked sign bits (requires orig_dim).
orig_dim : int or None, optional
The original number of dimensions/channels before packing,
required when quant is 'binary' or 'ubinary' to correctly unpack bits.
Returns:
-------
torch.Tensor
A float32 tensor of shape (B, T, C) with values in [-1, 1].
Raises:
------
ValueError
If an unsupported quantization type is provided or if `orig_dim` is missing
for 'binary'/'ubinary' unpacking.
"""
if isinstance(q, list):
q = torch.tensor(q)
if quant == "fp8":
return q.float()
if quant == "int8":
return q.float() / 127.0
if quant == "uint8":
return q.float() / 127.5 - 1.0
if quant in {"binary", "ubinary"}:
if orig_dim is None:
raise ValueError("orig_dim needed for (u)binary unpack")
return _unpack_sign_bits(q, quant == "binary", orig_dim)
raise ValueError(f"Invalid quantization {quant}")Examples
# Sample data: one 18-dimensional float embedding and its quantized forms.
embed_float = [-0.11944580078125,-0.2734375,0.040771484375,0.3056640625,-0.1470947265625,-0.11749267578125,0.0799560546875,0.08282470703125,-0.04205322265625,0.220947265625,0.0015048980712890625,-0.00397491455078125,-0.01099395751953125,-0.052642822265625,0.0504150390625,0.01605224609375,0.029693603515625,-0.024078369140625]
embed_int = [-15,-35,5,39,-19,-15,10,11,-5,28,0,0,-2,-7,6,2,4,-3]
embed_uint = [112,93,133,166,109,112,138,138,122,156,128,127,126,121,134,130,131,124]
# 18 sign bits packed into 3 bytes; each embed_ubin value equals embed_bin + 128.
embed_bin = [-77,-29,0]
embed_ubin = [51,99,128]

# In the notebook each call below was its own cell; print() keeps the
# results visible when this code is run as a plain script.
print(dequantize(embed_bin, quant="binary", orig_dim=18))
print(dequantize(embed_ubin, quant="ubinary", orig_dim=18))
print(dequantize(embed_int, quant="int8"))
print(dequantize(embed_uint, quant="uint8"))