class QuarkOCP_MX(QuarkScheme):
    """Quark linear-layer scheme for OCP MX (MXFP4 / MXFP6) quantization.

    Two execution modes are selected in ``__init__`` via ``self.emulate``:

    * emulation: weights are dequantized with ``self.dequant_func``,
      activations go through ``self.quant_dequant_func`` and the matmul is
      done by ``F.linear``;
    * native: ``torch.ops.vllm.gemm_with_dynamic_quant`` is called, which
      requires the AITER kernels to be importable.
    """

    def __init__(
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
        dynamic_mxfp4_quant: bool = False,
    ):
        """Configure the scheme from the weight / input quantization specs.

        Args:
            weight_quant_spec: Quantization spec for the weights; its
                ``"dtype"`` entry is read and ``"fp"`` is rewritten to
                ``"mxfp"`` to form ``self.weight_dtype``.
            input_quant_spec: Quantization spec for the activations; its
                ``"dtype"`` entry is handled the same way, and a falsy or
                missing ``"is_dynamic"`` entry is treated as static scales.
            dynamic_mxfp4_quant: When true, ``create_weights`` allocates an
                unquantized weight and ``process_weights_after_loading``
                quantizes it with the AITER kernel (native mode only).

        Raises:
            KeyError: If either spec has no ``"dtype"`` entry.
            NotImplementedError: If the input spec is not dynamic, or if
                native mode is selected but the AITER kernels are missing.
        """
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.dynamic_mxfp4_quant = dynamic_mxfp4_quant

        self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
        self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
            self.input_dtype, self.weight_dtype
        )

        # Weight side: packing factor and dequantization routine.
        if self.weight_dtype == "mxfp4":
            # Two 4-bit values per uint8.
            self.packed_factor: int | Fraction = 2
            self.dequant_func = dequant_mxfp4
        else:
            # 6-bit values stored in 8-bit bytes.
            self.packed_factor = Fraction(numerator=8, denominator=6)
            self.dequant_func = partial(
                dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
            )

        # Activation side: quantize-dequantize routine used when emulating.
        if self.input_dtype == "mxfp4":
            self.quant_dequant_func = quant_dequant_mxfp4
        else:
            self.quant_dequant_func = partial(
                quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
            )

        self.static_input_scales = not input_quant_spec.get("is_dynamic")
        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkOCP_MX with static input scales is currently not "
                "implemented. Please open an issue."
            )

        platform_supports_mx = current_platform.supports_mx()
        # Anything other than MXFP4 activations with MXFP4 weights has no
        # integrated kernel and falls back to emulation.
        not_pure_mxfp4 = self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"

        # TODO: integrate (or test) mixed-precision kernel.
        self.emulate = not platform_supports_mx or not_pure_mxfp4

        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()

        # NOTE: the ``dynamic_mxfp4_quant`` argument of this method is a bool
        # flag that shadows the module-level kernel of the same name, so the
        # kernel availability is checked in a helper where the name resolves
        # to the module global.
        if not self.emulate and not self._aiter_kernels_available():
            # Currently need these kernels if not emulating
            raise NotImplementedError(
                f"{self.__class__.__name__} requires AITER to be installed "
                "for non-emulation mode! Please refer to "
                "https://github.com/ROCm/aiter for installation details."
            )

        if not platform_supports_mx:
            logger.warning_once(
                "The current platform does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )
        elif not_pure_mxfp4:
            logger.warning_once(
                "The current platform supports native MXFP4/MXFP6 "
                f"computation, but kernels for input_dtype={self.input_dtype} "
                f"and weight_dtype={self.weight_dtype} are not yet integrated "
                "in vLLM. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

    @staticmethod
    def _aiter_kernels_available() -> bool:
        """Return whether the module-level AITER kernels are not ``None``."""
        return dynamic_mxfp4_quant is not None and gemm_afp4wfp4 is not None

    def get_packed_dim(self, dim: int, quant_dtype: str) -> int:
        """Return the number of bytes needed to store ``dim`` packed values.

        Args:
            dim: Number of logical (unpacked) elements.
            quant_dtype: One of ``"mxfp4"``, ``"mxfp6_e3m2"``, ``"mxfp6_e2m3"``.

        Raises:
            AssertionError: If ``dim`` does not pack into whole bytes.
            NotImplementedError: For any other ``quant_dtype``.
        """
        if quant_dtype == "mxfp4":
            # FP4 packs two values per byte.
            assert dim % 2 == 0
            return dim // 2
        elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
            # FP6 packs 4 * 6 = 24 bits on 3 bytes.
            assert (dim * 3) % 4 == 0
            return (dim * 3) // 4
        else:
            raise NotImplementedError(
                "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
                f"got quant_dtype={quant_dtype}. Something is wrong, please "
                "open an issue."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this scheme (70)."""
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap and, in native mode, re-layout the loaded parameters.

        ``layer.weight`` (and ``layer.weight_scale`` where it is used) are
        replaced by plain ``torch.nn.Parameter`` objects with
        ``requires_grad=False``.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        if self.emulate:
            # Emulation consumes the scales as loaded.
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.dynamic_mxfp4_quant:
                # Quantize the weight now; the returned scales are stored
                # transposed and contiguous.
                w_q, w_s = dynamic_mxfp4_quant(layer.weight)
                layer.weight_scale = torch.nn.Parameter(
                    w_s.T.contiguous(), requires_grad=False
                )
                layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
            elif self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                # NOTE(review): the view requires sm % 32 == 0 and
                # sn % 8 == 0; presumably this is the layout the AITER FP4
                # ASM GEMM expects - confirm against AITER.
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
                weight_scale_shuffle = weight_scale_shuffle.view(
                    sm // 32, 2, 16, sn // 8, 2, 4, 1
                )
                weight_scale_shuffle = weight_scale_shuffle.permute(
                    0, 3, 5, 2, 4, 1, 6
                ).contiguous()
                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
                layer.weight_scale = torch.nn.Parameter(
                    weight_scale_shuffle, requires_grad=False
                )

                # shuffle weight
                weight_shuffle = layer.weight.data
                weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
                layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
            else:
                # Default native path: scales stored transposed, contiguous.
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data.T.contiguous(), requires_grad=False
                )

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate and register the layer's weight parameters.

        With ``self.dynamic_mxfp4_quant`` set, only an unquantized ``weight``
        of dtype ``params_dtype`` is registered. Otherwise a packed ``uint8``
        ``weight`` and a ``uint8`` ``weight_scale`` (one scale per
        ``OCP_MX_BLOCK_SIZE`` input elements) are registered, and
        ``layer.logical_widths`` is set to ``output_partition_sizes``.
        """
        if self.dynamic_mxfp4_quant:
            weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, kwargs)
        else:
            output_size_per_partition = sum(output_partition_sizes)
            layer.logical_widths = output_partition_sizes

            # WEIGHT
            weight = PackedvLLMParameter(
                data=torch.empty(
                    output_size_per_partition,
                    self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=self.packed_factor,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight", weight)

            # WEIGHT SCALE
            weight_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition // OCP_MX_BLOCK_SIZE,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the quantized linear layer to ``x``.

        In emulation mode the weight is dequantized to ``x.dtype``, ``x`` is
        quantize-dequantized and ``F.linear`` is applied with ``bias``. In
        native mode ``torch.ops.vllm.gemm_with_dynamic_quant`` is called;
        ``bias`` is not passed to that op.
        """
        if self.emulate:
            dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
            qdq_x = self.quant_dequant_func(x)
            return F.linear(qdq_x, dq_w, bias)
        else:
            return torch.ops.vllm.gemm_with_dynamic_quant(
                x,
                layer.weight,
                layer.weight_scale,
                self.rocm_use_aiter_fp4_asm_gemm,
                self.out_dtype,
            )