vllm.model_executor.models.transformers.utils

Transformers backend utilities.

Style module-attribute

Style = Literal[
    "colwise",
    "colwise_rep",
    "rowwise",
    "rowwise_rep",
    "replicate",
]

logger module-attribute

logger = init_logger(__name__)

can_enable_torch_compile

can_enable_torch_compile(vllm_config: VllmConfig) -> bool

Callable to be passed to @support_torch_compile's enable_if argument.

Defaults to True but is disabled in the following situations:

  • The model uses dynamic rope scaling.
Source code in vllm/model_executor/models/transformers/utils.py
def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Defaults to `True` but is disabled in the following situations:

    - The model uses dynamic rope scaling.
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # Dynamic rope scaling is not compatible with torch.compile
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    return rope_scaling.get("rope_type") != "dynamic"
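
Example (illustrative sketch, not the actual backend code): how this callable might be wired up, assuming `support_torch_compile` from `vllm.compilation.decorators` accepts an `enable_if` callable as the docstring above implies.

import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile


# torch.compile is only enabled when can_enable_torch_compile(vllm_config)
# returns True, i.e. when the model does not use dynamic rope scaling.
@support_torch_compile(enable_if=can_enable_torch_compile)
class MyTransformersBackendModel(nn.Module):
    ...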

get_feature_request_tip

get_feature_request_tip(
    model: str, trust_remote_code: bool
) -> str
Source code in vllm/model_executor/models/transformers/utils.py
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    url = hf_url if trust_remote_code else gh_url
    prefix = f"Please open {url} to request support for this feature. "
    if Path(model).exists():
        prefix = ""
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    return f"{prefix}{tip}"

init_on_device_without_buffers

init_on_device_without_buffers(device: device)

A context manager under which models are initialized with all parameters on the specified device. Buffers, however, are not initialized on that device.

Parameters:

Name    Type            Description                              Default
device  `torch.device`  Device to initialize all parameters on.  required
Source code in vllm/model_executor/models/transformers/utils.py
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs
            )

    tensor_constructors_to_patch = {}
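    # Note: this dict is intentionally left empty, so no torch tensor
    # constructors are actually patched below and buffers keep their
    # default device, as described in the docstring above.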

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        for torch_function_name in tensor_constructors_to_patch:
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
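
Example (usage sketch): modules constructed inside the context get their parameters moved to the target device, while buffers stay wherever the module's own code creates them. nn.BatchNorm1d is used here purely as an example of a module that has both parameters and buffers.

import torch
import torch.nn as nn

from vllm.model_executor.models.transformers.utils import (
    init_on_device_without_buffers,
)

with init_on_device_without_buffers(torch.device("meta")):
    layer = nn.BatchNorm1d(8)

print(layer.weight.device)        # meta - parameters are moved
print(layer.running_mean.device)  # cpu  - buffers are not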

log_replacement

log_replacement(
    name: str, old_module: Module, new_module: Module
)
Source code in vllm/model_executor/models/transformers/utils.py
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    logger.debug("%s: %s -> %s", name, old_module, new_module)

replace_linear_class

replace_linear_class(
    linear: Linear,
    style: Style = "replicate",
    quant_config: QuantizationConfig | None = None,
    *,
    prefix: str = "",
) -> (
    ColumnParallelLinear
    | RowParallelLinear
    | ReplicatedLinear
)

Replace nn.Linear with one of vLLM's tensor parallel linear classes.

Parameters:

Name          Type                       Description                                                Default
linear        Linear                     nn.Linear to be replaced.                                  required
style         Style                      Tensor parallel style of the new linear, e.g. "colwise".   'replicate'
quant_config  QuantizationConfig | None  Quantization config for the new linear.                    None

Returns: The new linear.

Source code in vllm/model_executor/models/transformers/utils.py
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: "QuantizationConfig | None" = None,
    *,
    prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
        quant_config: Quantization config for the new linear.
    Returns:
        The new linear.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )
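
Example (illustrative sketch): swapping a plain nn.Linear for a column-parallel linear. This assumes vLLM's tensor parallel process groups have already been initialized, which the parallel linear classes require; the prefix is illustrative only.

import torch.nn as nn

from vllm.model_executor.models.transformers.utils import (
    log_replacement,
    replace_linear_class,
)

linear = nn.Linear(4096, 4096, bias=False)
# "colwise" shards the output dimension across tensor parallel ranks.
new_linear = replace_linear_class(
    linear,
    style="colwise",
    quant_config=None,
    prefix="model.layers.0.self_attn.q_proj",  # illustrative prefix
)
log_replacement("model.layers.0.self_attn.q_proj", linear, new_linear)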

replace_rms_norm_class

replace_rms_norm_class(
    rms_norm: Module, hidden_size: int
) -> RMSNorm

Replace a Transformers RMSNorm with vLLM's RMSNorm.

This method assumes:

  • Weight is stored as weight.
  • Epsilon is stored as eps or variance_epsilon.
  • with_scale indicates whether the layer has a weight (Gemma3n only).
  • var_hidden_size is only ever used for the Intern vision encoder in vLLM and Transformers doesn't appear to have the same concept.

Source code in vllm/model_executor/models/transformers/utils.py
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    kwargs = {"hidden_size": hidden_size, "eps": eps}
    # Update hidden size if weight is available
    weight_meta = getattr(rms_norm, "weight", None)
    if weight_meta is not None:
        kwargs["hidden_size"] = weight_meta.size(0)
    # Check if weight is all zeros, which indicates GemmaRMSNorm
    # We must create a new instance because rms_norm is on meta
    try:
        with torch.device("cpu"):
            weight_test = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        logger.warning(
            "Failed to determine if RMSNorm weight is centered on zero or one. "
            "Defaulting to one."
        )
        weight_test = None
    if weight_test is not None and torch.all(weight_test == 0):
        return GemmaRMSNorm(**kwargs)
    # Otherwise assume it's a regular RMSNorm
    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
    if weight_meta is not None:
        kwargs["dtype"] = weight_meta.dtype
    else:
        # No weight, fall back to weightless RMSNorm
        kwargs["has_weight"] = False
    return RMSNorm(**kwargs)
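
Example (illustrative sketch): LlamaRMSNorm is used here purely as an example of a Transformers RMSNorm; the backend normally performs this replacement while the model is still on the meta device.

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from vllm.model_executor.models.transformers.utils import replace_rms_norm_class

# Build the Transformers norm on the meta device, as the backend would.
with torch.device("meta"):
    hf_norm = LlamaRMSNorm(hidden_size=4096, eps=1e-5)

vllm_norm = replace_rms_norm_class(hf_norm, hidden_size=4096)
print(type(vllm_norm).__name__)  # RMSNorm (LlamaRMSNorm's weight is ones, not zeros)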