Skip to content

vllm.compilation.fusion

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

FUSED_OPS module-attribute

# Lookup table from a fusion key (quantization scheme + whether the residual
# add is fused) to the fused op overload. RMSNormQuantPattern.__init__ asserts
# that its key is present here and stores the entry as self.FUSED_OP.
# NOTE(review): every value is rendered as `default` on this page; the
# concrete op overloads are not visible here.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False): default,
    FusedRMSQuantKey(kFp8StaticTensorSym, True): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, False): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, True): default,
}

QUANT_OPS module-attribute

# Lookup table from a quantization key to an op overload.
# NOTE(review): presumably the standalone (unfused) quant ops; this table is
# not referenced elsewhere on this page and its values are rendered as
# `default`, so confirm against the real module.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default,
    kFp8DynamicTensorSym: default,
    kFp8DynamicTokenSym: default,
}

RMS_ADD_OP module-attribute

RMS_ADD_OP = default

RMS_OP module-attribute

RMS_OP = default

logger module-attribute

logger = init_logger(__name__)

FusedAddRMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern that replaces fused-add RMSNorm followed by quantization with a
    single call to the fused op selected for its key.

    The scale is not an input of the pattern: the replacement allocates it
    with quant_matcher.make_scale and returns the one written by the op.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        """
        Build the fusion key (fused_add=True) and hand it to the base class.

        epsilon: forwarded to the base class and later to the fused op
        quant_dtype: dtype of the quantized output
        group_shape: group shape of the scale descriptor (PER_TOKEN default)
        symmetric: forwarded to QuantKey
        """
        # NOTE(review): the second ScaleDesc argument (False) is presumably
        # the "static" flag - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """
        Register the pattern/replacement pair on pm_pass via
        pm.register_replacement, using self.rmsnorm_matcher.inputs() as the
        inputs argument.
        """

        # Graph to search for: the rmsnorm matcher (with residual) followed
        # by the quant matcher applied to its normalized output.
        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffers handed to the op: quantized result and scale.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            # NOTE(review): at[0] is skipped - presumably the op's own return
            # value; confirm against auto_functionalized.
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """
    Assemble the fusion key for the fused-add variant (fused_add=True) with a
    float32 scale descriptor of the requested group shape, then delegate to
    the base-class initializer.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=True))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """
    Register the pattern/replacement pair on pm_pass via
    pm.register_replacement, using self.rmsnorm_matcher.inputs() as the
    inputs argument.
    """

    # Graph to search for: the rmsnorm matcher (with residual) followed by
    # the quant matcher applied to its normalized output.
    def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
        result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
        result, scale = self.quant_matcher(result_rms)

        return result, residual, scale

    # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
    def replacement(
        input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
    ):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        # Output buffers handed to the op: quantized result and scale.
        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(input)
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=residual,
        )

        # result, residual, scale
        # NOTE(review): at[0] is skipped - presumably the op's own return
        # value; confirm against auto_functionalized.
        return at[1], at[3], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

FusedAddRMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern that replaces fused-add RMSNorm followed by quantization with a
    single call to the fused op selected for its key.

    The scale is an input tensor of the pattern and is passed through to the
    fused op unchanged.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        """
        Build the fusion key (fused_add=True, scale=kStaticTensorScale) and
        hand it to the base class.
        """
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """
        Register the pattern/replacement pair on pm_pass via
        pm.register_replacement.
        """

        # Graph to search for: the rmsnorm matcher (with residual) followed
        # by the quant matcher called with the provided scale.
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The second value returned by the quant matcher is not needed.
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffer for the quantized result.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            # NOTE(review): at[0] is skipped - presumably the op's own return
            # value; confirm against auto_functionalized.
            return at[1], at[2]

        # The rmsnorm matcher's inputs plus the second entry of the quant
        # matcher's inputs, in the parameter order of pattern/replacement.
        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """
    Assemble the fusion key for the fused-add variant (fused_add=True) using
    kStaticTensorScale as the scale descriptor, then delegate to the
    base-class initializer.
    """
    static_quant = QuantKey(
        dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
    )
    super().__init__(epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=True))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """
    Register the pattern/replacement pair on pm_pass via
    pm.register_replacement.
    """

    # Graph to search for: the rmsnorm matcher (with residual) followed by
    # the quant matcher called with the provided scale.
    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
        scale: torch.Tensor,
    ):
        result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
        # The second value returned by the quant matcher is not needed.
        result, _ = self.quant_matcher(result_rms, scale)

        return result, residual

    # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
    def replacement(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
        scale: torch.Tensor,
    ):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        # Output buffer for the quantized result.
        result = torch.empty_like(input, dtype=self.quant_dtype)
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            residual=residual,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result, residual
        # NOTE(review): at[0] is skipped - presumably the op's own return
        # value; confirm against auto_functionalized.
        return at[1], at[2]

    # The rmsnorm matcher's inputs plus the second entry of the quant
    # matcher's inputs, in the parameter order of pattern/replacement.
    inputs = [
        # input, weight, residual
        *self.rmsnorm_matcher.inputs(),
        self.quant_matcher.inputs()[1],  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedRMSQuantKey

Bases: NamedTuple

Named tuple for identifying the type of RMSNorm + quant fusion. `quant`: the type of quantization. `fused_add`: whether the op also performs the residual add.

Source code in vllm/compilation/fusion.py
class FusedRMSQuantKey(NamedTuple):
    """
    Key that identifies one flavour of RMSNorm + quant fusion.
    quant: the quantization scheme being fused
    fused_add: whether the fused op performs the residual add as well
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # NOTE(review): the printed prefix is "FusedQuantKey", not the class
        # name; kept as-is so the rendered text does not change.
        residual_word = "with" if self.fused_add else "without"
        return f"FusedQuantKey({self.quant}, {residual_word} residual)"

fused_add instance-attribute

fused_add: bool

quant instance-attribute

quant: QuantKey

__str__

__str__()
Source code in vllm/compilation/fusion.py
def __str__(self):
    """Render the key as its quant key plus "with"/"without" residual."""
    residual_word = "with" if self.fused_add else "without"
    return f"FusedQuantKey({self.quant}, {residual_word} residual)"

RMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern that replaces RMSNorm (no residual add) followed by quantization
    with a single call to the fused op selected for its key.

    The scale is not an input of the pattern: the replacement allocates it
    with quant_matcher.make_scale and returns the one written by the op.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        """
        Build the fusion key (fused_add=False) and hand it to the base class.

        epsilon: forwarded to the base class and later to the fused op
        quant_dtype: dtype of the quantized output
        group_shape: group shape of the scale descriptor (PER_TOKEN default)
        symmetric: forwarded to QuantKey
        """
        # NOTE(review): the second ScaleDesc argument (False) is presumably
        # the "static" flag - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """
        Register the pattern/replacement pair on pm_pass via
        pm.register_replacement, using self.rmsnorm_matcher.inputs() as the
        inputs argument.
        """

        # Graph to search for: the rmsnorm matcher followed by the quant
        # matcher applied to its output.
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffers handed to the op: quantized result and scale.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            # No residual is involved in this variant, so residual=None.
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            # NOTE(review): at[0] is skipped - presumably the op's own return
            # value; confirm against auto_functionalized.
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """
    Assemble the fusion key for the plain RMSNorm variant (fused_add=False)
    with a float32 scale descriptor of the requested group shape, then
    delegate to the base-class initializer.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=False))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """
    Register the pattern/replacement pair on pm_pass via
    pm.register_replacement, using self.rmsnorm_matcher.inputs() as the
    inputs argument.
    """

    # Graph to search for: the rmsnorm matcher followed by the quant matcher
    # applied to its output.
    def pattern(input: torch.Tensor, weight: torch.Tensor):
        result_rms = self.rmsnorm_matcher(input, weight)
        # result, scale
        return self.quant_matcher(result_rms)

    # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
    def replacement(input: torch.Tensor, weight: torch.Tensor):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        # Output buffers handed to the op: quantized result and scale.
        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(input)
        # No residual is involved in this variant, so residual=None.
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=None,
        )

        # result, scale
        # NOTE(review): at[0] is skipped - presumably the op's own return
        # value; confirm against auto_functionalized.
        return at[1], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

RMSNormQuantFusionPass

Bases: VllmPatternMatcherPass

This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.

Source code in vllm/compilation/fusion.py
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass that collapses rms_norm followed by a quant custom
    op into one fused rms_norm_quant op. The fused_add_rms_norm variant is
    handled as well.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Create the pattern-matcher pass and register every fusion pattern."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Registration order matters: each fused-add pattern has to come
        # before the matching simple rms-norm pattern, because the latter is
        # a subset of the former in torch ops.
        ordered_pattern_classes = (
            # fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern,
            # rms_norm + static fp8 quant
            RMSNormStaticQuantPattern,
            # fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern,
            # rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in ordered_pattern_classes:
                pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to graph and record the match count."""
        num_replaced = self.patterns.apply(graph)
        self.matched_count = num_replaced
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return hash_source over this pass and all of its pattern classes."""
        hashed_sources = (
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
        return self.hash_source(*hashed_sources)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rmsnorm_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
    """Apply the registered patterns to graph and record the match count."""
    num_replaced = self.patterns.apply(graph)
    self.matched_count = num_replaced
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig):
    """Create the pattern-matcher pass and register every fusion pattern."""
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rmsnorm_quant_fusion_pass"
    )

    # Registration order matters: each fused-add pattern has to come before
    # the matching simple rms-norm pattern, because the latter is a subset
    # of the former in torch ops.
    ordered_pattern_classes = (
        # fused_add_rms_norm + static fp8 quant
        FusedAddRMSNormStaticQuantPattern,
        # rms_norm + static fp8 quant
        RMSNormStaticQuantPattern,
        # fused_add_rms_norm + dynamic per-token fp8 quant
        FusedAddRMSNormDynamicQuantPattern,
        # rms_norm + dynamic per-token fp8 quant
        RMSNormDynamicQuantPattern,
    )
    for epsilon in [1e-5, 1e-6]:
        for pattern_cls in ordered_pattern_classes:
            pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> Any
Source code in vllm/compilation/fusion.py
def uuid(self) -> Any:
    """Return hash_source over this pass and all of its pattern classes."""
    hashed_sources = (
        self,
        RMSNormQuantPattern,
        RMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )
    return self.hash_source(*hashed_sources)

RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormQuantPattern:
    """
    Shared state for the RMSNorm + quant fusion patterns: epsilon, dtypes,
    the fused op looked up in FUSED_OPS, and the two matcher objects.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        """
        epsilon: stored and used to construct the rms-norm matcher
        key: fusion key; must be present in FUSED_OPS
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # model_dtype is None when the current config has no model_config.
        config = get_current_vllm_config()
        if config.model_config:
            self.model_dtype = config.model_config.dtype
        else:
            self.model_dtype = None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # The fused-add variant needs the matcher that also takes a residual.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(key.quant)

FUSED_OP instance-attribute

FUSED_OP = FUSED_OPS[key]

epsilon instance-attribute

epsilon = epsilon

model_dtype instance-attribute

model_dtype = dtype if model_config else None

quant_dtype instance-attribute

quant_dtype = dtype

quant_matcher instance-attribute

quant_matcher = MatcherQuantFP8(quant)

rmsnorm_matcher instance-attribute

rmsnorm_matcher = (
    MatcherRMSNorm(epsilon)
    if not fused_add
    else MatcherFusedAddRMSNorm(epsilon)
)

__init__

__init__(epsilon: float, key: FusedRMSQuantKey)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
    """
    Store epsilon and dtypes, look up the fused op for key in FUSED_OPS, and
    build the rms-norm and quant matchers.
    """
    self.epsilon = epsilon
    self.quant_dtype = key.quant.dtype

    # model_dtype is None when the current config has no model_config.
    config = get_current_vllm_config()
    if config.model_config:
        self.model_dtype = config.model_config.dtype
    else:
        self.model_dtype = None

    assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
    self.FUSED_OP = FUSED_OPS[key]

    # The fused-add variant needs the matcher that also takes a residual.
    if key.fused_add:
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
    else:
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
    self.quant_matcher = MatcherQuantFP8(key.quant)

RMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern that replaces RMSNorm (no residual add) followed by quantization
    with a single call to the fused op selected for its key.

    The scale is an input tensor of the pattern and is passed through to the
    fused op unchanged.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        """
        Build the fusion key (fused_add=False, scale=kStaticTensorScale) and
        hand it to the base class.
        """
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """
        Register the pattern/replacement pair on pm_pass via
        pm.register_replacement.
        """

        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the first value returned by the quant matcher is kept.
            return self.quant_matcher(result_rms, scale)[0]

        # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffer for the quantized result, with the input's shape
            # and device.
            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            # NOTE(review): at[0] is skipped - presumably the op's own return
            # value; confirm against auto_functionalized.
            return at[1]

        # The rmsnorm matcher's inputs plus the second entry of the quant
        # matcher's inputs, in the parameter order of pattern/replacement.
        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the result of this direct call is discarded and the
        # other patterns in this module do not make it - looks like a
        # leftover sanity call; confirm before removing.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """
    Assemble the fusion key for the plain RMSNorm variant (fused_add=False)
    using kStaticTensorScale as the scale descriptor, then delegate to the
    base-class initializer.
    """
    static_quant = QuantKey(
        dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
    )
    super().__init__(epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=False))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """
    Register the pattern/replacement pair on pm_pass via
    pm.register_replacement.
    """

    # Cannot use methods, as the self argument affects tracing
    def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
        result_rms = self.rmsnorm_matcher(input, weight)
        # Only the first value returned by the quant matcher is kept.
        return self.quant_matcher(result_rms, scale)[0]

    # Graph to substitute: one auto_functionalized call to self.FUSED_OP.
    def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        # Output buffer for the quantized result, with the input's shape and
        # device.
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_dtype
        )
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result
        # NOTE(review): at[0] is skipped - presumably the op's own return
        # value; confirm against auto_functionalized.
        return at[1]

    # The rmsnorm matcher's inputs plus the second entry of the quant
    # matcher's inputs, in the parameter order of pattern/replacement.
    inputs = [
        # input, weight
        *self.rmsnorm_matcher.inputs(),
        self.quant_matcher.inputs()[1],  # scale
    ]
    # NOTE(review): the result of this direct call is discarded and the other
    # patterns in this module do not make it - looks like a leftover sanity
    # call; confirm before removing.
    pattern(*inputs)

    pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

empty_bf16

empty_bf16(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_bf16(*args, **kwargs):
    """
    Return an uninitialized bfloat16 tensor via torch.empty.

    Positional and keyword arguments are forwarded to torch.empty. The device
    defaults to "cuda" but may be overridden with a device keyword (before,
    passing device raised a duplicate-keyword TypeError). The dtype is fixed.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.bfloat16, **kwargs)

empty_fp32

empty_fp32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_fp32(*args, **kwargs):
    """
    Return an uninitialized float32 tensor via torch.empty.

    Positional and keyword arguments are forwarded to torch.empty. The device
    defaults to "cuda" but may be overridden with a device keyword (before,
    passing device raised a duplicate-keyword TypeError). The dtype is fixed.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.float32, **kwargs)

empty_i32

empty_i32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_i32(*args, **kwargs):
    """
    Return an uninitialized int32 tensor via torch.empty.

    Positional and keyword arguments are forwarded to torch.empty. The device
    defaults to "cuda" but may be overridden with a device keyword (before,
    passing device raised a duplicate-keyword TypeError). The dtype is fixed.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.int32, **kwargs)