vllm.compilation.matcher_utils

QUANT_OPS module-attribute

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,
}

RMS_ADD_OP module-attribute

RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

RMS_OP module-attribute

RMS_OP = torch.ops._C.rms_norm.default

SILU_MUL_OP module-attribute

SILU_MUL_OP = torch.ops._C.silu_and_mul.default
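
These attributes bind the matchers to vLLM's compiled custom ops. A minimal lookup sketch (the quant-key import path below is an assumption, not taken from this page):

from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,  # import path assumed
)

# Each QuantKey maps to the torch custom-op overload implementing that scheme.
static_fp8_quant_op = QUANT_OPS[kFp8StaticTensorSym]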

MatcherCustomOp

Bases: ABC

Source code in vllm/compilation/matcher_utils.py
class MatcherCustomOp(ABC):
    def __init__(self, enabled: bool):
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

        self.enabled = enabled
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, *args, **kws):
        pass

    @abstractmethod
    def forward_native(self, *args, **kws):
        pass

    def __call__(self, *args, **kws):
        return self.forward(*args, **kws)

    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError

device instance-attribute

device = device if device_config else None

enabled instance-attribute

enabled = enabled

forward instance-attribute

forward = forward_custom if enabled else forward_native

model_dtype instance-attribute

model_dtype = dtype if model_config else None

__call__

__call__(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def __call__(self, *args, **kws):
    return self.forward(*args, **kws)

__init__

__init__(enabled: bool)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, enabled: bool):
    config = get_current_vllm_config()
    self.model_dtype = config.model_config.dtype if config.model_config else None
    self.device = config.device_config.device if config.device_config else None

    self.enabled = enabled
    self.forward = self.forward_custom if enabled else self.forward_native

empty

empty(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def empty(self, *args, **kws):
    return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

empty_f32

empty_f32(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def empty_f32(self, *args, **kws):
    return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

forward_custom abstractmethod

forward_custom(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
@abstractmethod
def forward_custom(self, *args, **kws):
    pass

forward_native abstractmethod

forward_native(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
@abstractmethod
def forward_native(self, *args, **kws):
    pass

inputs

inputs() -> list[Tensor]

Return example tensors used as inputs when tracing the pattern

Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    """Utility for inputs to the pattern"""
    raise NotImplementedError
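
To illustrate the base-class contract, here is a minimal hypothetical subclass (MatcherRelu is illustrative only, not part of vLLM): __call__ routes to forward_custom when the custom op is enabled, otherwise to forward_native, and inputs() supplies the example tensors the pattern matcher traces over.

import torch

from vllm.compilation.matcher_utils import MatcherCustomOp

class MatcherRelu(MatcherCustomOp):
    """Hypothetical matcher used only to illustrate the base-class contract."""

    def forward_custom(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for a call into a fused custom op.
        return torch.relu(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference traced when the custom op is disabled.
        return torch.where(x > 0, x, torch.zeros_like(x))

    def inputs(self) -> list[torch.Tensor]:
        return [self.empty(5, 16)]

# Construction requires an active vLLM config context, since __init__ calls
# get_current_vllm_config(); matcher = MatcherRelu(enabled=True) then makes
# matcher(x) dispatch to forward_custom.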

MatcherFusedAddRMSNorm

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherFusedAddRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        _, result, residual = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )

epsilon instance-attribute

epsilon = epsilon

__init__

__init__(epsilon: float, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, epsilon: float, enabled: bool | None = None):
    if enabled is None:
        enabled = RMSNorm.enabled()

    super().__init__(enabled)
    self.epsilon = epsilon

forward_custom

forward_custom(
    input: Tensor, weight: Tensor, residual: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    _, result, residual = auto_functionalized(
        RMS_ADD_OP,
        input=input,
        residual=residual,
        weight=weight,
        epsilon=self.epsilon,
    )

    return result, residual

forward_native

forward_native(
    input: Tensor, weight: Tensor, residual: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return RMSNorm.forward_static(
        input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
    )

inputs

inputs()
Source code in vllm/compilation/matcher_utils.py
def inputs(self):
    input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
    weight = self.empty(16)
    residual = self.empty(5, 16)
    return [input, weight, residual]
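
A calling-convention sketch, assuming an active vLLM config context (the constructor reads it via get_current_vllm_config): the pattern matcher traces this matcher on its own example inputs, and a call returns the normalized output together with the updated residual.

matcher = MatcherFusedAddRMSNorm(epsilon=1e-6)
input, weight, residual = matcher.inputs()
# Dispatches to forward_custom (fused op) or forward_native (reference impl).
out, new_residual = matcher(input, weight, residual)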

MatcherQuantFP8

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherQuantFP8(MatcherCustomOp):
    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert quant_key.dtype == current_platform.fp8_dtype(), (
            "Only QuantFP8 supported by"
        )
        assert quant_key.scale2 is None
        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.static:
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor):
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]

QUANT_OP instance-attribute

QUANT_OP = QUANT_OPS[quant_key]

quant_fp8 instance-attribute

quant_fp8 = QuantFP8(static, group_shape)

quant_key instance-attribute

quant_key = quant_key

__init__

__init__(quant_key: QuantKey, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
    if enabled is None:
        enabled = QuantFP8.enabled()

    super().__init__(enabled)
    self.quant_key = quant_key
    assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
    self.QUANT_OP = QUANT_OPS[quant_key]

    assert quant_key.dtype == current_platform.fp8_dtype(), (
        "Only QuantFP8 supported by"
    )
    assert quant_key.scale2 is None
    self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

forward_custom

forward_custom(
    input: Tensor, scale: Tensor | None = None
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    result = torch.empty(
        input.shape, device=input.device, dtype=self.quant_key.dtype
    )

    if self.quant_key.scale.static:
        assert scale is not None
        _, result = auto_functionalized(
            self.QUANT_OP, result=result, input=input, scale=scale
        )
        return result, scale
    else:
        assert scale is None
        scale = self.make_scale(input)
        _, result, scale = auto_functionalized(
            self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
        )
        return result, scale

forward_native

forward_native(
    input: Tensor, scale: Tensor | None = None
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return self.quant_fp8(input, scale)

inputs

inputs() -> list[Tensor]
Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    input = self.empty(5, 16)
    if self.quant_key.scale.static:
        return [input, self.empty_f32(1, 1)]

    return [input]

make_scale

make_scale(input: Tensor)
Source code in vllm/compilation/matcher_utils.py
def make_scale(self, input: torch.Tensor):
    normalized_group_shape = _normalize_quant_group_shape(
        input, self.quant_key.scale.group_shape
    )
    scale_shape = (
        input.shape[0] // normalized_group_shape[0],
        input.shape[1] // normalized_group_shape[1],
    )

    return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
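
A sketch of the dynamic per-token mode (same config-context assumption as above; the quant-key import path is assumed): inputs() supplies only the activation, and the matcher allocates the per-token scale via make_scale.

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8DynamicTokenSym,  # import path assumed
)

matcher = MatcherQuantFP8(kFp8DynamicTokenSym)
(x,) = matcher.inputs()  # shape (5, 16); dynamic modes take no scale argument
q, scale = matcher(x)    # q: FP8 tensor with x's shape; scale: float32, shape (5, 1)

The (5, 1) scale shape follows from make_scale: a per-token group shape normalizes to (1, 16) on a (5, 16) input, so the scale shape is (5 // 1, 16 // 16).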

MatcherRMSNorm

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        result = torch.empty_like(input)
        _, result = auto_functionalized(
            RMS_OP,
            result=result,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )

epsilon instance-attribute

epsilon = epsilon

__init__

__init__(epsilon: float, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, epsilon: float, enabled: bool | None = None):
    if enabled is None:
        enabled = RMSNorm.enabled()

    super().__init__(enabled)
    self.epsilon = epsilon

forward_custom

forward_custom(input: Tensor, weight: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    result = torch.empty_like(input)
    _, result = auto_functionalized(
        RMS_OP,
        result=result,
        input=input,
        weight=weight,
        epsilon=self.epsilon,
    )

    return result

forward_native

forward_native(input: Tensor, weight: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    return RMSNorm.forward_static(
        input, self.epsilon, input.size(-1), self.model_dtype, weight
    )

inputs

inputs()
Source code in vllm/compilation/matcher_utils.py
def inputs(self):
    input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
    weight = self.empty(16)
    return [input, weight]
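
A calling-convention sketch under the same config-context assumption. Note that when the custom op is disabled, inputs() deliberately produces a float32 input, so the native pattern is traced at the dtype the unfused graph actually uses.

matcher = MatcherRMSNorm(epsilon=1e-6)
input, weight = matcher.inputs()
out = matcher(input, weight)  # normalized tensor with input's shape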

MatcherSiluAndMul

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherSiluAndMul(MatcherCustomOp):
    def __init__(self, enabled: bool | None = None):
        if enabled is None:
            enabled = SiluAndMul.enabled()
        super().__init__(enabled)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 4)
        return [input]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
        return result[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        return SiluAndMul.forward_native(x)

__init__

__init__(enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, enabled: bool | None = None):
    if enabled is None:
        enabled = SiluAndMul.enabled()
    super().__init__(enabled)

forward_custom

forward_custom(x: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
    return result[1]

forward_native

forward_native(x: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    return SiluAndMul.forward_native(x)

inputs

inputs() -> list[Tensor]
Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    input = self.empty(5, 4)
    return [input]
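
A shape sketch (same config-context assumption): SiluAndMul applies SiLU to the first half of the last dimension and multiplies it by the second half, so the output's last dimension is half the input's.

matcher = MatcherSiluAndMul()
(x,) = matcher.inputs()  # shape (5, 4)
y = matcher(x)           # shape (5, 2): silu(x[..., :2]) * x[..., 2:]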