class Executor(ABC):
"""Abstract base class for vLLM executors."
An executor is responsible for executing the model on one device,
or it can be a distributed executor that can execute the model on multiple devices.
"""
uses_ray: bool = False  # whether the executor uses Ray for orchestration.
supports_pp: bool = False  # whether the executor supports pipeline parallelism (PP).
@staticmethod
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
executor_class: type[Executor]
parallel_config = vllm_config.parallel_config
distributed_executor_backend = parallel_config.distributed_executor_backend
# distributed_executor_backend must be set in VllmConfig.__post_init__
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, Executor):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"Executor. Got {distributed_executor_backend}."
)
executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
from vllm.v1.executor.ray_executor import RayDistributedExecutor
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
from vllm.v1.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# TODO: make v1 scheduling deterministic
# to support external launcher
executor_class = ExecutorWithExternalLauncher
elif isinstance(distributed_executor_backend, str):
executor_class = resolve_obj_by_qualname(distributed_executor_backend)
if not issubclass(executor_class, Executor):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"Executor. Got {executor_class}."
)
else:
raise ValueError(
f"Unknown distributed executor backend: {distributed_executor_backend}"
)
return executor_class
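# A minimal usage sketch (illustrative): `get_class` resolves the backend
# configured on `vllm_config.parallel_config` to a concrete Executor type,
# which is then instantiated with the same config.
#
#     executor_cls = Executor.get_class(vllm_config)
#     executor = executor_cls(vllm_config)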
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
self.kv_output_aggregator: KVOutputAggregator | None = None
@abstractmethod
def _init_executor(self) -> None:
raise NotImplementedError
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model")
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
"""
Register a function to be called if the executor enters a permanent
failed state.
"""
pass
def determine_available_memory(self) -> list[int]: # in bytes
return self.collective_rpc("determine_available_memory")
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
return self.collective_rpc("get_kv_cache_spec")
@overload
def collective_rpc(
self,
method: str | Callable[[WorkerBase], _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: Literal[False] = False,
) -> list[_R]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
non_block: If `True`, returns a list of Futures instead of waiting
for the results.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API only for passing control messages,
and to set up separate data-plane communication for passing data.
"""
pass
@overload
def collective_rpc(
self,
method: str | Callable[[WorkerBase], _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: Literal[True] = True,
) -> list[Future[_R]]:
pass
@abstractmethod
def collective_rpc(
self, method, timeout=None, args=(), kwargs=None, non_block: bool = False
):
raise NotImplementedError
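# Example (illustrative): besides method names, `collective_rpc` accepts a
# callable that is serialized and invoked on every worker, with the worker
# object bound as its first argument. Querying CUDA memory here is only an
# assumption about the worker's device, made for illustration.
#
#     def _allocated_bytes(worker: WorkerBase) -> int:
#         import torch
#         return torch.cuda.memory_allocated()
#
#     per_worker_bytes = executor.collective_rpc(_allocated_bytes)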
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
pass
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
pass
def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
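# Example (illustrative): with `non_block=True` the call returns a Future,
# letting a scheduler overlap CPU-side work with model execution.
#
#     future = executor.execute_model(scheduler_output, non_block=True)
#     ...  # do other scheduling work
#     model_output = future.result()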
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
def take_draft_token_ids(self) -> DraftTokenIds | None:
output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids")
return output[0]
@property
def max_concurrent_batches(self) -> int:
return 1
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start,))
def save_sharded_state(
self,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
self.collective_rpc(
"save_sharded_state",
kwargs=dict(path=path, pattern=pattern, max_size=max_size),
)
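# Example (illustrative; the path and size are hypothetical): save per-rank
# weight shards, e.g. to speed up subsequent loads of the same deployment.
#
#     executor.save_sharded_state(path="/tmp/sharded-model", max_size=5 * 2**30)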
@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
def shutdown(self) -> None:
"""Shutdown the executor."""
self.collective_rpc("shutdown")
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size
)
@cached_property # Avoid unnecessary RPC calls
def supported_tasks(self) -> tuple[SupportedTask, ...]:
output: list[tuple[SupportedTask, ...]]
output = self.collective_rpc("get_supported_tasks")
return output[0]
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("add_lora", args=(lora_request,)))
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("remove_lora", args=(lora_id,)))
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
def list_loras(self) -> set[int]:
sets: list[set[int]] = self.collective_rpc("list_loras")
for s in sets:
assert s == sets[0], "All workers should have the same LoRAs."
return sets[0]
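# Example (illustrative): the LoRA RPCs fan out to every worker and report
# success only if all workers agree. `lora_request` is an assumed, pre-built
# LoRARequest with a positive `lora_int_id`.
#
#     assert executor.add_lora(lora_request)
#     print(executor.list_loras())                    # e.g. {1}
#     executor.remove_lora(lora_request.lora_int_id)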
def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache")
def start_profile(self) -> None:
self.collective_rpc("start_profile")
def stop_profile(self) -> None:
self.collective_rpc("stop_profile")
def sleep(self, level: int = 1):
if self.is_sleeping:
logger.warning("Executor is already sleeping.")
return
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info(
"It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
)
def wake_up(self, tags: list[str] | None = None):
if not self.is_sleeping:
logger.warning("Executor is not sleeping.")
return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning(
"Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
)
return
time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter()
logger.info(
"It took %.6f seconds to wake up tags %s.",
time_after_wakeup - time_before_wakeup,
tags if tags is not None else self.sleeping_tags,
)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
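# Example (illustrative): `sleep` tags both resources as sleeping; tags can
# then be woken selectively, e.g. weights first (such as for a weight update)
# and everything else afterwards.
#
#     executor.sleep(level=1)
#     executor.wake_up(tags=["weights"])   # sleeping_tags becomes {"kv_cache"}
#     executor.wake_up()                   # wakes the rest; is_sleeping -> False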
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
raise NotImplementedError