Dequantizing embeddings

This notebook demonstrates how to dequantize several quantized embedding representations (fp8, int8, uint8, binary, ubinary) back to float32 values using PyTorch and NumPy.

Installation

Install the required packages:

!pip install torch numpy

Define the dequantization functions

import torch
import numpy as np

def _unpack_sign_bits(x_packed: torch.Tensor, signed: bool, C: int) -> torch.Tensor:
    """
    Unpacks a bit-packed tensor into a ±1 float32 tensor.

    Parameters:
    ----------
    x_packed : torch.Tensor
        A tensor containing packed bit values (usually uint8) representing the sign bits.
    signed : bool
        Indicates if the original data was signed ('binary') or unsigned ('ubinary').
    C : int
        The number of original channels or dimensions to unpack (used to trim padding bits).

    Returns:
    -------
    torch.Tensor
        A float32 tensor on the same device as x_packed, with values ±1 and shape (..., C).
    """
    # Cast through int16 so the +128 offset cannot overflow an int8/uint8
    # dtype, then view the result as raw uint8 bytes for np.unpackbits.
    arr = (x_packed.cpu().numpy().astype(np.int16) + (128 if signed else 0)).astype(np.uint8)
    bits = np.unpackbits(arr, axis=-1)[..., :C]  # remove pad bits
    # Cast before the affine map: in uint8 arithmetic, 0 * 2 - 1 would wrap to 255.
    return torch.from_numpy(bits.astype(np.float32) * 2 - 1).to(x_packed.device)
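
# Quick sanity check of the helper on an arbitrary test byte (illustrative
# only, not part of the embedding data): 51 == 0b00110011 unpacks from the
# most significant bit to 0,0,1,1,0,0,1,1, which bits*2-1 maps to ±1.
packed = torch.tensor([51], dtype=torch.uint8)
assert _unpack_sign_bits(packed, signed=False, C=8).tolist() == [-1, -1, 1, 1, -1, -1, 1, 1]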

def dequantize(
    q: torch.Tensor | list, quant: str, orig_dim: int | None = None
) -> torch.Tensor:
    """
    Dequantizes a quantized tensor or list back to float32 values in [-1, 1].

    Parameters:
    ----------
    q : torch.Tensor or list
        The quantized data, either as a tensor or list.
    quant : str
        The quantization type. Supported values:
        - 'fp8': already float, just converted.
        - 'int8': scaled by 127.
        - 'uint8': scaled and shifted to [-1,1].
        - 'binary' or 'ubinary': unpacked sign bits (requires orig_dim).
    orig_dim : int or None, optional
        The original number of dimensions/channels before packing,
        required when quant is 'binary' or 'ubinary' to correctly unpack bits.

    Returns:
    -------
    torch.Tensor
        A float32 tensor of shape (B, T, C) with values in [-1, 1].

    Raises:
    ------
    ValueError
        If an unsupported quantization type is provided or if `orig_dim` is missing
        for 'binary'/'ubinary' unpacking.
    """
    if isinstance(q, list):
        q = torch.tensor(q)
    if quant == "fp8":
        return q.float()
    if quant == "int8":
        return q.float() / 127.0
    if quant == "uint8":
        return q.float() / 127.5 - 1.0
    if quant in {"binary", "ubinary"}:
        if orig_dim is None:
            raise ValueError("orig_dim needed for (u)binary unpack")
        return _unpack_sign_bits(q, quant == "binary", orig_dim)
    raise ValueError(f"Invalid quantization {quant}")
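
For reference, the opposite direction can be sketched as below. This quantize helper is not part of the notebook: the int8/uint8 branches simply invert the formulas above (real embedding APIs often calibrate scales per dimension, so these branches may not reproduce the example arrays exactly), while the (u)binary branches pack the sign bit of each dimension.

def quantize(x: torch.Tensor, quant: str) -> torch.Tensor:
    """Hypothetical inverse of dequantize, for illustration only."""
    if quant == "int8":
        return torch.round(x * 127.0).to(torch.int8)
    if quant == "uint8":
        return torch.round((x + 1.0) * 127.5).to(torch.uint8)
    if quant in {"binary", "ubinary"}:
        # 1 bit per dimension: positive -> 1, else 0, padded to full bytes.
        bits = (x > 0).to(torch.uint8).cpu().numpy()
        packed = np.packbits(bits, axis=-1)
        if quant == "ubinary":
            return torch.from_numpy(packed)
        # 'binary' stores the same bytes offset by -128 as int8.
        return torch.from_numpy((packed.astype(np.int16) - 128).astype(np.int8))
    raise ValueError(f"Invalid quantization {quant}")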

Examples

embed_float is the reference full-precision embedding; the other lists are quantized forms of the same vector. Note that embed_bin contains the same packed bytes as embed_ubin, offset by -128 and stored as int8.

embed_float = [-0.11944580078125,-0.2734375,0.040771484375,0.3056640625,-0.1470947265625,-0.11749267578125,0.0799560546875,0.08282470703125,-0.04205322265625,0.220947265625,0.0015048980712890625,-0.00397491455078125,-0.01099395751953125,-0.052642822265625,0.0504150390625,0.01605224609375,0.029693603515625,-0.024078369140625]
embed_int = [-15,-35,5,39,-19,-15,10,11,-5,28,0,0,-2,-7,6,2,4,-3]
embed_uint = [112,93,133,166,109,112,138,138,122,156,128,127,126,121,134,130,131,124]
embed_bin = [-77,-29,0]
embed_ubin = [51,99,128]
dequantize(embed_bin, quant="binary", orig_dim=18)
dequantize(embed_ubin, quant="ubinary", orig_dim=18)
dequantize(embed_int, quant="int8")
dequantize(embed_uint, quant="uint8")
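
As a sanity check (assuming the cells above have been run), the int8 and uint8 round-trips should match embed_float to within roughly one quantization step (~1/127), and the (u)binary round-trips should recover exactly the sign of each dimension:

ref = torch.tensor(embed_float)
print(torch.allclose(dequantize(embed_int, quant="int8"), ref, atol=0.01))    # True
print(torch.allclose(dequantize(embed_uint, quant="uint8"), ref, atol=0.01))  # True
print(torch.equal(dequantize(embed_bin, quant="binary", orig_dim=18), torch.sign(ref)))   # True
print(torch.equal(dequantize(embed_ubin, quant="ubinary", orig_dim=18), torch.sign(ref))) # True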