Skip to content

vllm.model_executor.layers.quantization.utils.mxfp8_utils

MXFP8_BLOCK_SIZE module-attribute

MXFP8_BLOCK_SIZE = 32

MXFP8_SCALE_DTYPE module-attribute

MXFP8_SCALE_DTYPE = uint8

MXFP8_VALUE_DTYPE module-attribute

MXFP8_VALUE_DTYPE = float8_e4m3fn

logger module-attribute

logger = init_logger(__name__)

Mxfp8LinearOp

This class executes an MXFP8 linear layer.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
class Mxfp8LinearOp:
    """
    This class executes an MXFP8 linear layer.

    The MXFP8 weight is dequantized to BF16 and then multiplied with a standard
    ``torch.nn.functional.linear`` call.
    """

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute ``linear(input, dequant(weight), bias)`` cast to ``out_dtype``.

        Args:
            input: activation tensor passed to ``torch.nn.functional.linear``.
            weight: 2D ``[out_features, in_features]`` quantized weight.
            weight_scale: per-block scales (e.g. float8_e8m0fnu after
                process_weights_after_loading). Either ``[out_features, K/32]``
                or padded along N and possibly flattened; only the first
                ``out_features`` rows are used.
            out_dtype: dtype of the returned tensor.
            bias: optional bias forwarded to ``linear``.

        Returns:
            The linear output cast to ``out_dtype``.

        Raises:
            ValueError: if ``weight_scale`` holds fewer than ``out_features``
                rows of ``in_features // MXFP8_BLOCK_SIZE`` scales.
        """
        # Reinterpret the scale bits as raw uint8 exponents for dequantization.
        weight_scale_uint8 = weight_scale.view(MXFP8_SCALE_DTYPE)

        out_features, in_features = weight.shape
        # Number of scale blocks along K dimension.
        scale_k = in_features // MXFP8_BLOCK_SIZE

        # Infer the (possibly padded) row count from the data instead of
        # assuming N was padded to a multiple of 128. This accepts both the
        # padded layout produced by process_weights_after_loading and unpadded
        # scales; reshape also tolerates non-contiguous inputs.
        weight_scale_2d_padded = weight_scale_uint8.reshape(-1, scale_k)
        num_rows = weight_scale_2d_padded.shape[0]
        if num_rows < out_features:
            raise ValueError(
                f"weight_scale provides {num_rows} rows of {scale_k} block "
                f"scales, but the weight has {out_features} output features."
            )
        # Drop any padding rows to recover the original [N, K/32] scales.
        weight_scale_2d = weight_scale_2d_padded[:out_features, :]

        # Dequantize weight to bf16.
        weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale_2d)

        # Standard linear operation.
        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)

apply

apply(
    input: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    out_dtype: dtype,
    bias: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Compute ``linear(input, dequant(weight), bias)`` cast to ``out_dtype``.

    Args:
        input: activation tensor passed to ``torch.nn.functional.linear``.
        weight: 2D ``[out_features, in_features]`` quantized weight.
        weight_scale: per-block scales (e.g. float8_e8m0fnu after
            process_weights_after_loading). Either ``[out_features, K/32]``
            or padded along N and possibly flattened; only the first
            ``out_features`` rows are used.
        out_dtype: dtype of the returned tensor.
        bias: optional bias forwarded to ``linear``.

    Returns:
        The linear output cast to ``out_dtype``.

    Raises:
        ValueError: if ``weight_scale`` holds fewer than ``out_features``
            rows of ``in_features // MXFP8_BLOCK_SIZE`` scales.
    """
    # Reinterpret the scale bits as raw uint8 exponents for dequantization.
    weight_scale_uint8 = weight_scale.view(MXFP8_SCALE_DTYPE)

    out_features, in_features = weight.shape
    # Number of scale blocks along K dimension.
    scale_k = in_features // MXFP8_BLOCK_SIZE

    # Infer the (possibly padded) row count from the data instead of assuming
    # N was padded to a multiple of 128. This accepts both the padded layout
    # produced by process_weights_after_loading and unpadded scales; reshape
    # also tolerates non-contiguous inputs.
    weight_scale_2d_padded = weight_scale_uint8.reshape(-1, scale_k)
    num_rows = weight_scale_2d_padded.shape[0]
    if num_rows < out_features:
        raise ValueError(
            f"weight_scale provides {num_rows} rows of {scale_k} block "
            f"scales, but the weight has {out_features} output features."
        )
    # Drop any padding rows to recover the original [N, K/32] scales.
    weight_scale_2d = weight_scale_2d_padded[:out_features, :]

    # Dequantize weight to bf16.
    weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale_2d)

    # Standard linear operation.
    output = torch.nn.functional.linear(input, weight_bf16, bias)
    return output.to(out_dtype)

_cast_mxfp8_scales_to_bf16

_cast_mxfp8_scales_to_bf16(scales: Tensor) -> Tensor

Cast MXFP8 scales from uint8 to BF16. The scales are stored as uint8 exponents and are converted to BF16 by left-shifting by 7 bits (placing them in the BF16 exponent field) and reinterpreting the bits as bfloat16. Args: scales — uint8 tensor containing MXFP8 scales. Returns: a BF16 tensor with the converted scales.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def _cast_mxfp8_scales_to_bf16(scales: torch.Tensor) -> torch.Tensor:
    """
    Cast MXFP8 scales from uint8 to BF16.
    The scales are stored in uint8 format and need to be converted to BF16
    by left-shifting by 7 bits (to form the exponent) and reinterpreting
    as bfloat16.
    Args:
        scales: uint8 tensor containing MXFP8 scales
    Returns:
        BF16 tensor with the converted scales
    """
    return (scales.to(torch.int16) << 7).view(torch.bfloat16)

dequant_mxfp8_to_bf16

dequant_mxfp8_to_bf16(x: Tensor, scales: Tensor) -> Tensor

Dequantize an MXFP8 tensor to BF16. Args: x — FP8 E4M3 tensor to dequantize; scales — uint8 tensor containing the MXFP8 scales. Returns: the dequantized BF16 tensor.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """
    Dequantize MXFP8 tensor to BF16.

    Each scale covers one ``MXFP8_BLOCK_SIZE``-wide block along the last
    dimension of ``x``. If that dimension is not a multiple of the block size,
    the final block is partial and the expanded scales are trimmed to match.

    Args:
        x: FP8 E4M3 tensor to dequantize
        scales: uint8 tensor containing MXFP8 scales; must be reshapeable to
            ``(*x.shape[:-1], num_blocks)``

    Returns:
        BF16 dequantized tensor
    """
    scales_bf16 = _cast_mxfp8_scales_to_bf16(scales)
    # Repeat each block scale MXFP8_BLOCK_SIZE times along the last dimension,
    # then trim so a partial trailing block still broadcasts against x.
    scales_expanded = scales_bf16.reshape(*x.shape[:-1], -1).repeat_interleave(
        MXFP8_BLOCK_SIZE, dim=-1
    )[..., : x.shape[-1]]
    return x.to(torch.bfloat16) * scales_expanded

mxfp8_e4m3_quantize

mxfp8_e4m3_quantize(
    x: Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def mxfp8_e4m3_quantize(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``x`` to MXFP8 using FlashInfer's ``mxfp8_quantize``.

    Args:
        x: tensor to quantize.
        is_sf_swizzled_layout: forwarded to FlashInfer; selects the swizzled
            scale-factor layout.

    Returns:
        ``(x_q, x_scales)`` as returned by FlashInfer. If FlashInfer returns
        flattened (1-D) scales, they are viewed as 2-D: ``[M_padded, K_padded]``
        when swizzled (M rounded up to 128, K blocks rounded up to 4),
        otherwise ``[x.size(0), -1]``.

    Raises:
        ImportError: if the ``flashinfer`` package is not installed.
    """
    try:
        # Imported lazily so flashinfer stays optional; aliased so it does not
        # shadow this function's own name inside the body.
        from flashinfer import mxfp8_quantize as _flashinfer_mxfp8_quantize
    except ImportError as err:
        raise ImportError(
            "The package `flashinfer` is required to do "
            "MX-FP8 quantization. Please install it with "
            "`pip install flashinfer`"
        ) from err

    x_q, x_scales = _flashinfer_mxfp8_quantize(
        x, is_sf_swizzled_layout=is_sf_swizzled_layout
    )
    if x_scales.ndim == 1:
        # NOTE(review): for inputs with more than 2 dims these views key off
        # x.size(0) only, whereas mxfp8_e4m3_quantize_fake reports
        # (B, M, K)-shaped scales for 3-D inputs — confirm the intended shape.
        if is_sf_swizzled_layout:
            # TODO: check this, maybe not required?
            # When swizzled, scales are padded: M to a multiple of 128 and K to
            # a multiple of 4, so the padded dimensions must be used here
            # rather than the original input dimensions.
            def _round_up(val: int, mult: int) -> int:
                return (val + mult - 1) // mult * mult

            M = x.size(0)
            K = x.size(-1) // MXFP8_BLOCK_SIZE
            M_padded = _round_up(M, 128)
            K_padded = _round_up(K, 4)
            x_scales = x_scales.view(M_padded, K_padded)
        else:
            x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales

mxfp8_e4m3_quantize_fake

mxfp8_e4m3_quantize_fake(
    x: Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[Tensor, Tensor]

Fake implementation for torch.compile tracing. Returns empty tensors with the correct shapes and dtypes.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def mxfp8_e4m3_quantize_fake(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Shape-only stand-in for ``mxfp8_e4m3_quantize`` used during torch.compile
    tracing.

    Returns uninitialized tensors whose shapes and dtypes match the real
    kernel's outputs:

    * quantized values: same shape as ``x``, dtype ``MXFP8_VALUE_DTYPE``;
    * scales: ``x.shape`` with the last dim replaced by the number of
      ``MXFP8_BLOCK_SIZE`` blocks (rounded up), dtype ``MXFP8_SCALE_DTYPE``.
      For 2-D and 3-D inputs in the swizzled layout, the row dim is padded to a
      multiple of 128 and the block dim to a multiple of 4.
    """

    def _ceil_to(value: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= value (value >= 0).
        return -(-value // multiple) * multiple

    # FP8 values keep the input's shape.
    quantized = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)

    # One scale per MXFP8_BLOCK_SIZE elements along the last dimension.
    num_blocks = _ceil_to(x.shape[-1], MXFP8_BLOCK_SIZE) // MXFP8_BLOCK_SIZE
    scale_shape = [*x.shape[:-1], num_blocks]

    # Swizzled padding only applies to 2-D and 3-D inputs; other ranks keep
    # the plain per-block shape.
    if is_sf_swizzled_layout and x.ndim in (2, 3):
        scale_shape[-2] = _ceil_to(scale_shape[-2], 128)
        scale_shape[-1] = _ceil_to(scale_shape[-1], 4)

    block_scales = torch.empty(
        scale_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device
    )
    return quantized, block_scales