vllm.model_executor.models.deepseek_ocr

Inference-only Deepseek-OCR model compatible with HuggingFace weights.

_IMAGE_TOKEN module-attribute

_IMAGE_TOKEN = '<image>'
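
A minimal usage sketch for orientation: the checkpoint id, prompt wording, and sampling settings below are assumptions rather than values taken from this module; only the "<image>" placeholder comes from _IMAGE_TOKEN above.

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-OCR", trust_remote_code=True)  # assumed HF repo id

image = Image.open("page.png").convert("RGB")
prompt = "<image>\nFree OCR."  # "<image>" marks where vision embeddings are injected

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=512),
)
print(outputs[0].outputs[0].text)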

DeepseekOCRDummyInputsBuilder

Bases: BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]

Source code in vllm/model_executor/models/deepseek_ocr.py
class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        return {
            "image": self._get_dummy_images(
                width=max_image_size.width,
                height=max_image_size.height,
                num_images=num_images,
            )
        }

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions]
    | None = None,
) -> MultiModalDataDict
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
    num_images = mm_counts.get("image", 0)

    max_image_size = self.info.get_image_size_with_most_features()

    return {
        "image": self._get_dummy_images(
            width=max_image_size.width,
            height=max_image_size.height,
            num_images=num_images,
        )
    }

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    num_images = mm_counts.get("image", 0)

    processor = self.info.get_hf_processor()
    image_token = processor.image_token

    return image_token * num_images
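
A hedged illustration of what this builder produces during memory profiling, assuming mm_counts = {"image": 2} and that builder is an instance created by vLLM's multimodal registry:

mm_counts = {"image": 2}

dummy_text = builder.get_dummy_text(mm_counts)
# -> "<image><image>" (assuming processor.image_token == "<image>")

dummy_data = builder.get_dummy_mm_data(seq_len=4096, mm_counts=mm_counts)
# -> {"image": [...]}: two blank dummy images sized to
#    get_image_size_with_most_features(), i.e. the worst-case token count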

DeepseekOCRForCausalLM

Bases: Module, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/deepseek_ocr.py
@MULTIMODAL_REGISTRY.register_processor(
    DeepseekOCRMultiModalProcessor,
    info=DeepseekOCRProcessingInfo,
    dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # map prefix for language backbone
            "model.embed_tokens.": "language_model.model.embed_tokens.",
            "model.layers.": "language_model.model.layers.",
            "model.norm.": "language_model.model.norm.",
            "lm_head.": "language_model.lm_head.",
            # remove "model." prefix for other components
            "model.": "",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]

        self.sam_model = build_sam_vit_b()
        clip_vision_config = CLIPVisionConfig(
            hidden_size=1024,
            intermediate_size=4096,
            num_attention_heads=16,
            num_hidden_layers=24,
            image_size=224,
            patch_size=14,
            projection_dim=512,
            layer_norm_eps=1e-5,
        )
        self.vision_model = DeepCLIPVisionTransformer(
            config=clip_vision_config,
            quant_config=quant_config,
        )

        self.projector = MlpProjector(self.projector_config)
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special token for image token sequence format
        n_embed = self.projector_config.n_embed
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
            # This is a typo in original implementation
            self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
        else:
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        if self.text_config.topk_method == "noaux_tc":
            architectures = ["DeepseekV3ForCausalLM"]
        elif not self.text_config.use_mla:
            architectures = ["DeepseekForCausalLM"]
        else:
            architectures = ["DeepseekV2ForCausalLM"]

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(self, **kwargs: object):
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        images_crop = kwargs.pop("images_crop", None)

        if pixel_values is None or torch.sum(pixel_values).item() == 0:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                )

            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
                raise ValueError(
                    "Incorrect type of image sizes. "
                    f"Got type: {type(images_spatial_crop)}"
                )

            if not isinstance(images_crop, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
                )

            return [pixel_values, images_crop, images_spatial_crop]

        raise AssertionError("This line should be unreachable.")

    def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        global_features_1 = self.sam_model(image_tensor)
        global_features_2 = self.vision_model(image_tensor, global_features_1)
        features = torch.cat(
            (
                global_features_2[:, 1:],
                global_features_1.flatten(2).permute(0, 2, 1),
            ),
            dim=-1,
        )
        features = self.projector(features)

        _, hw, dim = features.shape
        side = int(hw**0.5)

        features = features.view(side, side, dim)
        newline = self.image_newline[None, None, :].expand(side, 1, dim)
        features = torch.cat([features, newline], dim=1)
        return features.view(-1, dim)

    def _encode_local_features(
        self, patches: torch.Tensor, crop_shape: torch.Tensor
    ) -> torch.Tensor | None:
        if torch.sum(patches).item() == 0:
            return None

        local_features_1 = self.sam_model(patches)
        local_features_2 = self.vision_model(patches, local_features_1)
        features = torch.cat(
            (
                local_features_2[:, 1:],
                local_features_1.flatten(2).permute(0, 2, 1),
            ),
            dim=-1,
        )
        features = self.projector(features)

        _, hw, dim = features.shape
        patch_side = int(hw**0.5)

        width_tiles = int(crop_shape[0].item())
        height_tiles = int(crop_shape[1].item())

        features = (
            features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
            .permute(0, 2, 1, 3, 4)
            .reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
        )
        newline = self.image_newline[None, None, :].expand(
            height_tiles * patch_side, 1, dim
        )
        features = torch.cat([features, newline], dim=1)

        return features.view(-1, dim)

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_crop: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        images_in_this_batch = []

        for jdx in range(images_spatial_crop.size(0)):
            patches = images_crop[jdx][0].to(torch.bfloat16)
            image_ori = pixel_values[jdx]
            crop_shape = images_spatial_crop[jdx][0]

            global_features = self._encode_global_features(image_ori)
            local_features = self._encode_local_features(patches, crop_shape)

            if local_features is not None:
                combined = torch.cat(
                    [local_features, global_features, self.view_seperator[None, :]],
                    dim=0,
                )
            else:
                combined = torch.cat(
                    [global_features, self.view_seperator[None, :]], dim=0
                )

            images_in_this_batch.append(combined)

        return images_in_this_batch

    def _process_image_input(self, image_input) -> torch.Tensor:
        pixel_values = image_input[0].to(torch.bfloat16)
        images_crop = image_input[1]
        images_spatial_crop = image_input[2].to(dtype=torch.long)

        vision_features = self._pixel_values_to_embedding(
            pixel_values=pixel_values,
            images_crop=images_crop,
            images_spatial_crop=images_spatial_crop,
        )

        return vision_features

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        return autoloaded_weights

config instance-attribute

config = config

global_view_pos instance-attribute

global_view_pos = global_view_pos

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.embed_tokens.": "language_model.model.embed_tokens.",
        "model.layers.": "language_model.model.layers.",
        "model.norm.": "language_model.model.norm.",
        "lm_head.": "language_model.lm_head.",
        "model.": "",
    }
)
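
For illustration, a few representative checkpoint keys and how the mapper renames them; the sub-module names below are assumptions about the HF checkpoint layout, shown only to make the prefix rules concrete:

# "model.embed_tokens.weight"              -> "language_model.model.embed_tokens.weight"
# "model.layers.0.self_attn.q_proj.weight" -> "language_model.model.layers.0.self_attn.q_proj.weight"
# "lm_head.weight"                         -> "language_model.lm_head.weight"
# "model.projector.layers.0.weight"        -> "projector.layers.0.weight"  (catch-all "model." prefix stripped)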

image_newline instance-attribute

image_newline = Parameter(randn(n_embed) * embed_std)

image_token_id instance-attribute

image_token_id = vocab[_IMAGE_TOKEN]

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=text_config,
    prefix=maybe_prefix(prefix, "language_model"),
    architectures=architectures,
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

projector instance-attribute

projector = MlpProjector(projector_config)

projector_config instance-attribute

projector_config = projector_config

sam_model instance-attribute

sam_model = build_sam_vit_b()

text_config instance-attribute

text_config = text_config

tile_tag instance-attribute

tile_tag = tile_tag

view_seperator instance-attribute

view_seperator = Parameter(randn(n_embed) * embed_std)

vision_config instance-attribute

vision_config = vision_config

vision_model instance-attribute

vision_model = DeepCLIPVisionTransformer(
    config=clip_vision_config, quant_config=quant_config
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/deepseek_ocr.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()

    config: DeepseekVLV2Config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config

    self.config = config
    self.multimodal_config = multimodal_config

    self.vision_config = config.vision_config
    self.projector_config = config.projector_config
    self.text_config = config.text_config

    model_config = vllm_config.model_config
    tokenizer = cached_tokenizer_from_config(model_config)
    self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]

    self.sam_model = build_sam_vit_b()
    clip_vision_config = CLIPVisionConfig(
        hidden_size=1024,
        intermediate_size=4096,
        num_attention_heads=16,
        num_hidden_layers=24,
        image_size=224,
        patch_size=14,
        projection_dim=512,
        layer_norm_eps=1e-5,
    )
    self.vision_model = DeepCLIPVisionTransformer(
        config=clip_vision_config,
        quant_config=quant_config,
    )

    self.projector = MlpProjector(self.projector_config)
    self.tile_tag = config.tile_tag
    self.global_view_pos = config.global_view_pos

    # special token for image token sequence format
    n_embed = self.projector_config.n_embed
    embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
    if self.tile_tag == "2D":
        # <|view_separator|>, <|\n|>
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        # This is a typo in original implementation
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
    else:
        raise ValueError(
            f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
        )

    if self.text_config.topk_method == "noaux_tc":
        architectures = ["DeepseekV3ForCausalLM"]
    elif not self.text_config.use_mla:
        architectures = ["DeepseekForCausalLM"]
    else:
        architectures = ["DeepseekV2ForCausalLM"]

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=self.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=architectures,
    )

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

_encode_global_features

_encode_global_features(image_tensor: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_ocr.py
def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
    global_features_1 = self.sam_model(image_tensor)
    global_features_2 = self.vision_model(image_tensor, global_features_1)
    features = torch.cat(
        (
            global_features_2[:, 1:],
            global_features_1.flatten(2).permute(0, 2, 1),
        ),
        dim=-1,
    )
    features = self.projector(features)

    _, hw, dim = features.shape
    side = int(hw**0.5)

    features = features.view(side, side, dim)
    newline = self.image_newline[None, None, :].expand(side, 1, dim)
    features = torch.cat([features, newline], dim=1)
    return features.view(-1, dim)
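
A toy shape walk-through of the reshaping at the end of _encode_global_features; the grid side of 16 and feature width of 256 are assumed numbers purely for illustration, the real values come from the projector output.

import torch

side, dim = 16, 256                                  # assumed projected grid side and width
features = torch.randn(1, side * side, dim)          # (batch=1, hw, dim) as from the projector
image_newline = torch.randn(dim)

grid = features.view(side, side, dim)                        # (16, 16, 256)
newline = image_newline[None, None, :].expand(side, 1, dim)  # one learned "newline" per row
grid = torch.cat([grid, newline], dim=1)                     # (16, 17, 256)
flat = grid.view(-1, dim)                                    # (16 * 17, 256) = (272, 256)

The appended column acts as a row separator, so the language model sees the 2-D feature grid serialized row by row.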

_encode_local_features

_encode_local_features(
    patches: Tensor, crop_shape: Tensor
) -> Tensor | None
Source code in vllm/model_executor/models/deepseek_ocr.py
def _encode_local_features(
    self, patches: torch.Tensor, crop_shape: torch.Tensor
) -> torch.Tensor | None:
    if torch.sum(patches).item() == 0:
        return None

    local_features_1 = self.sam_model(patches)
    local_features_2 = self.vision_model(patches, local_features_1)
    features = torch.cat(
        (
            local_features_2[:, 1:],
            local_features_1.flatten(2).permute(0, 2, 1),
        ),
        dim=-1,
    )
    features = self.projector(features)

    _, hw, dim = features.shape
    patch_side = int(hw**0.5)

    width_tiles = int(crop_shape[0].item())
    height_tiles = int(crop_shape[1].item())

    features = (
        features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
        .permute(0, 2, 1, 3, 4)
        .reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
    )
    newline = self.image_newline[None, None, :].expand(
        height_tiles * patch_side, 1, dim
    )
    features = torch.cat([features, newline], dim=1)

    return features.view(-1, dim)
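
The local branch performs a similar serialization, but first stitches the per-tile feature grids back into one large 2-D map. A toy sketch with assumed numbers (a 2 x 3 crop grid, 10 x 10 features per tile, width 256):

import torch

height_tiles, width_tiles = 2, 3          # assumed crop grid from images_spatial_crop
patch_side, dim = 10, 256                 # assumed per-tile feature grid and width

features = torch.randn(height_tiles * width_tiles, patch_side * patch_side, dim)
image_newline = torch.randn(dim)

grid = (
    features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
    .permute(0, 2, 1, 3, 4)               # interleave tile rows with intra-tile rows
    .reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
)                                         # (20, 30, 256): one stitched feature map
newline = image_newline[None, None, :].expand(height_tiles * patch_side, 1, dim)
grid = torch.cat([grid, newline], dim=1)  # (20, 31, 256)
flat = grid.view(-1, dim)                 # (620, 256)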

_parse_and_validate_image_input

_parse_and_validate_image_input(**kwargs: object)
Source code in vllm/model_executor/models/deepseek_ocr.py
def _parse_and_validate_image_input(self, **kwargs: object):
    pixel_values = kwargs.pop("pixel_values", None)
    images_spatial_crop = kwargs.pop("images_spatial_crop", None)
    images_crop = kwargs.pop("images_crop", None)

    if pixel_values is None or torch.sum(pixel_values).item() == 0:
        return None

    if pixel_values is not None:
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
            )

        if not isinstance(images_spatial_crop, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of image sizes. "
                f"Got type: {type(images_spatial_crop)}"
            )

        if not isinstance(images_crop, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of image crop. Got type: {type(images_crop)}"
            )

        return [pixel_values, images_crop, images_spatial_crop]

    raise AssertionError("This line should be unreachable.")

_pixel_values_to_embedding

_pixel_values_to_embedding(
    pixel_values: Tensor,
    images_crop: Tensor,
    images_spatial_crop: Tensor,
) -> NestedTensors
Source code in vllm/model_executor/models/deepseek_ocr.py
def _pixel_values_to_embedding(
    self,
    pixel_values: torch.Tensor,
    images_crop: torch.Tensor,
    images_spatial_crop: torch.Tensor,
) -> NestedTensors:
    images_in_this_batch = []

    for jdx in range(images_spatial_crop.size(0)):
        patches = images_crop[jdx][0].to(torch.bfloat16)
        image_ori = pixel_values[jdx]
        crop_shape = images_spatial_crop[jdx][0]

        global_features = self._encode_global_features(image_ori)
        local_features = self._encode_local_features(patches, crop_shape)

        if local_features is not None:
            combined = torch.cat(
                [local_features, global_features, self.view_seperator[None, :]],
                dim=0,
            )
        else:
            combined = torch.cat(
                [global_features, self.view_seperator[None, :]], dim=0
            )

        images_in_this_batch.append(combined)

    return images_in_this_batch
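
Each element of the returned list is a 2-D tensor of shape (num_placeholder_tokens, n_embed) for one image: the stitched local-tile rows come first (when the image was cropped), then the global-view rows, then the single learned view_seperator row. The row count is expected to match the number of image placeholder tokens inserted by DeepseekOCRMultiModalProcessor for that image, so vLLM can merge the rows into the corresponding token positions.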

_process_image_input

_process_image_input(image_input) -> Tensor
Source code in vllm/model_executor/models/deepseek_ocr.py
def _process_image_input(self, image_input) -> torch.Tensor:
    pixel_values = image_input[0].to(torch.bfloat16)
    images_crop = image_input[1]
    images_spatial_crop = image_input[2].to(dtype=torch.long)

    vision_features = self._pixel_values_to_embedding(
        pixel_values=pixel_values,
        images_crop=images_crop,
        images_spatial_crop=images_spatial_crop,
    )

    return vision_features

compute_logits

compute_logits(hidden_states: Tensor) -> Tensor | None
Source code in vllm/model_executor/models/deepseek_ocr.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor | None:
    return self.language_model.compute_logits(hidden_states)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
)
Source code in vllm/model_executor/models/deepseek_ocr.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
):
    if intermediate_tensors is not None:
        inputs_embeds = None

    hidden_states = self.language_model(
        input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
    )

    return hidden_states

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_language_model(self) -> torch.nn.Module:
    return self.language_model

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> MultiModalEmbeddings | None
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_multimodal_embeddings(
    self, **kwargs: object
) -> MultiModalEmbeddings | None:
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return None
    vision_embeddings = self._process_image_input(image_input)
    return vision_embeddings

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> str | None
Source code in vllm/model_executor/models/deepseek_ocr.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    if modality.startswith("image"):
        return "<image>"

    raise ValueError("Only image modality is supported")

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/deepseek_ocr.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self)
    autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
    return autoloaded_weights

DeepseekOCRMultiModalProcessor

Bases: BaseMultiModalProcessor[DeepseekOCRProcessingInfo]

Source code in vllm/model_executor/models/deepseek_ocr.py
class DeepseekOCRMultiModalProcessor(
    BaseMultiModalProcessor[DeepseekOCRProcessingInfo]
):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            processed_outputs = self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(prompt=prompt, **mm_data),
                mm_kwargs,
            )

        else:
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(
                prompt, add_special_tokens=True, return_tensors="pt"
            )

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            images_crop=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)

                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    cropping=CROP_MODE,
                )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature
Source code in vllm/model_executor/models/deepseek_ocr.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    if mm_data:
        processed_outputs = self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(prompt=prompt, **mm_data),
            mm_kwargs,
        )

    else:
        tokenizer = self.info.get_tokenizer()
        processed_outputs = tokenizer(
            prompt, add_special_tokens=True, return_tensors="pt"
        )

    return processed_outputs

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/deepseek_ocr.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return dict(
        pixel_values=MultiModalFieldConfig.batched("image"),
        images_spatial_crop=MultiModalFieldConfig.batched("image"),
        images_crop=MultiModalFieldConfig.batched("image"),
    )

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/deepseek_ocr.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

    image_token_id = hf_processor.image_token_id
    assert isinstance(image_token_id, int)

    def get_replacement_deepseek_vl2(item_idx: int):
        images = mm_items.get_items(
            "image", (ImageEmbeddingItems, ImageProcessorItems)
        )

        if isinstance(images, ImageEmbeddingItems):
            num_image_tokens = images.get_feature_size(item_idx)
        else:
            size = images.get_image_size(item_idx)

            num_image_tokens = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
                cropping=CROP_MODE,
            )
        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement_deepseek_vl2,
        )
    ]
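
In effect, the single image placeholder token in the tokenized prompt is expanded into one token per vision-embedding row. A hedged sketch (both numbers below are made up for illustration):

image_token_id = 128815       # assumption; the real id comes from DeepseekOCRProcessor
num_image_tokens = 693        # whatever get_num_image_tokens returns for this image
replacement = [image_token_id] * num_image_tokens
# PromptReplacement swaps each single [image_token_id] occurrence for this expanded
# list, so the prompt length matches the number of image feature rows per image.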

DeepseekOCRProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/deepseek_ocr.py
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_num_image_tokens(
        self, *, image_width: int, image_height: int, cropping: bool = True
    ) -> int:
        image_size = IMAGE_SIZE
        base_size = BASE_SIZE
        patch_size = 16
        downsample_ratio = 4

        if CROP_MODE:
            if image_width <= 640 and image_height <= 640:
                crop_ratio = [1, 1]
            else:
                # find the closest aspect ratio to the target
                crop_ratio = count_tiles(
                    image_width, image_height, image_size=IMAGE_SIZE
                )

            num_width_tiles, num_height_tiles = crop_ratio
        else:
            num_width_tiles = num_height_tiles = 1

        h = w = math.ceil((base_size // patch_size) / downsample_ratio)

        h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)

        global_views_tokens = h * (w + 1)
        if num_width_tiles > 1 or num_height_tiles > 1:
            local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
        else:
            local_views_tokens = 0

        return global_views_tokens + local_views_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
            return ImageSize(width=1024 * 2, height=1024 * 2)
        return ImageSize(width=640 * 2, height=640 * 2)

get_hf_config

get_hf_config()
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_hf_config(self):
    return self.ctx.get_hf_config(DeepseekVLV2Config)

get_hf_processor

get_hf_processor(**kwargs: object)
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_hf_processor(self, **kwargs: object):
    return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)

get_image_size_with_most_features

get_image_size_with_most_features() -> ImageSize
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_image_size_with_most_features(self) -> ImageSize:
    if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
        return ImageSize(width=1024 * 2, height=1024 * 2)
    return ImageSize(width=640 * 2, height=640 * 2)

get_num_image_tokens

get_num_image_tokens(
    *,
    image_width: int,
    image_height: int,
    cropping: bool = True,
) -> int
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_num_image_tokens(
    self, *, image_width: int, image_height: int, cropping: bool = True
) -> int:
    image_size = IMAGE_SIZE
    base_size = BASE_SIZE
    patch_size = 16
    downsample_ratio = 4

    if CROP_MODE:
        if image_width <= 640 and image_height <= 640:
            crop_ratio = [1, 1]
        else:
            # find the closest aspect ratio to the target
            crop_ratio = count_tiles(
                image_width, image_height, image_size=IMAGE_SIZE
            )

        num_width_tiles, num_height_tiles = crop_ratio
    else:
        num_width_tiles = num_height_tiles = 1

    h = w = math.ceil((base_size // patch_size) / downsample_ratio)

    h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)

    global_views_tokens = h * (w + 1)
    if num_width_tiles > 1 or num_height_tiles > 1:
        local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
    else:
        local_views_tokens = 0

    return global_views_tokens + local_views_tokens + 1
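
A worked example, assuming BASE_SIZE = 1024, IMAGE_SIZE = 640 and CROP_MODE enabled (these constants are defined elsewhere in the module and may differ in a given configuration):

import math

base_size, image_size = 1024, 640           # assumed values of BASE_SIZE / IMAGE_SIZE
patch_size, downsample_ratio = 16, 4
num_width_tiles, num_height_tiles = 2, 2    # assumed count_tiles() result for a large image

h = w = math.ceil((base_size // patch_size) / downsample_ratio)     # 16
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)  # 10

global_views_tokens = h * (w + 1)                                          # 272
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)  # 420
total = global_views_tokens + local_views_tokens + 1                       # 693 (+1 for the view separator)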

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, int | None]
Source code in vllm/model_executor/models/deepseek_ocr.py
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
    return {"image": None}

NGramPerReqLogitsProcessor

Bases: AdapterLogitsProcessor

Example of overriding the wrapper class __init__() in order to utilize info about the device type

Source code in vllm/model_executor/models/deepseek_ocr.py
class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to utilize
    info about the device type"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        super().__init__(vllm_config, device, is_pin_memory)

    def is_argmax_invariant(self) -> bool:
        return True

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        ngram_size = params.extra_args and params.extra_args.get("ngram_size")
        window_size = params.extra_args and params.extra_args.get("window_size", 100)
        whitelist_token_ids = params.extra_args and params.extra_args.get(
            "whitelist_token_ids", None
        )
        if ngram_size is None:
            return None
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
            )
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError(
                "`window_size` has to be a strictly positive integer, "
                f"got {window_size}."
            )
        if whitelist_token_ids is not None and not isinstance(
            whitelist_token_ids, Iterable
        ):
            raise ValueError(
                "`whitelist_token_ids` has to be a set of integers, "
                f"got {whitelist_token_ids}."
            )
        else:
            whitelist_token_ids = (
                set(whitelist_token_ids) if whitelist_token_ids else None
            )
        return NoRepeatNGramLogitsProcessor(
            ngram_size=ngram_size,
            window_size=window_size,
            whitelist_token_ids=whitelist_token_ids,
        )

__init__

__init__(
    vllm_config: VllmConfig,
    device: device,
    is_pin_memory: bool,
)
Source code in vllm/model_executor/models/deepseek_ocr.py
def __init__(
    self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
    super().__init__(vllm_config, device, is_pin_memory)

is_argmax_invariant

is_argmax_invariant() -> bool
Source code in vllm/model_executor/models/deepseek_ocr.py
def is_argmax_invariant(self) -> bool:
    return True

new_req_logits_processor

new_req_logits_processor(
    params: SamplingParams,
) -> RequestLogitsProcessor | None
Source code in vllm/model_executor/models/deepseek_ocr.py
def new_req_logits_processor(
    self,
    params: SamplingParams,
) -> RequestLogitsProcessor | None:
    ngram_size = params.extra_args and params.extra_args.get("ngram_size")
    window_size = params.extra_args and params.extra_args.get("window_size", 100)
    whitelist_token_ids = params.extra_args and params.extra_args.get(
        "whitelist_token_ids", None
    )
    if ngram_size is None:
        return None
    if not isinstance(ngram_size, int) or ngram_size <= 0:
        raise ValueError(
            f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
        )
    if not isinstance(window_size, int) or window_size <= 0:
        raise ValueError(
            "`window_size` has to be a strictly positive integer, "
            f"got {window_size}."
        )
    if whitelist_token_ids is not None and not isinstance(
        whitelist_token_ids, Iterable
    ):
        raise ValueError(
            "`whitelist_token_ids` has to be a set of integers, "
            f"got {whitelist_token_ids}."
        )
    else:
        whitelist_token_ids = (
            set(whitelist_token_ids) if whitelist_token_ids else None
        )
    return NoRepeatNGramLogitsProcessor(
        ngram_size=ngram_size,
        window_size=window_size,
        whitelist_token_ids=whitelist_token_ids,
    )
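
A hedged sketch of how a request opts in to the n-gram blocker, using the extra_args keys read above. It assumes NGramPerReqLogitsProcessor has been registered with the engine through vLLM's pluggable logits-processor mechanism; the whitelist ids are made-up values.

from vllm import SamplingParams

params = SamplingParams(
    temperature=0.0,
    max_tokens=2048,
    extra_args={
        "ngram_size": 30,                         # ban any repeated 30-gram
        "window_size": 90,                        # only search the last 90 output tokens
        "whitelist_token_ids": {128821, 128822},  # assumed ids exempt from banning
    },
)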

NoRepeatNGramLogitsProcessor

Source code in vllm/model_executor/models/deepseek_ocr.py
class NoRepeatNGramLogitsProcessor:
    def __init__(
        self,
        ngram_size: int,
        window_size: int,
        whitelist_token_ids: set[int] | None = None,
    ):
        self.ngram_size = ngram_size
        self.window_size = window_size
        self.whitelist_token_ids = whitelist_token_ids or set()

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        if len(output_ids) < self.ngram_size:
            return logits

        current_prefix = tuple(output_ids[-(self.ngram_size - 1) :])

        search_start = max(0, len(output_ids) - self.window_size)
        search_end = len(output_ids) - self.ngram_size + 1

        banned_tokens = set()
        for i in range(search_start, search_end):
            ngram = tuple(output_ids[i : i + self.ngram_size])
            if ngram[:-1] == current_prefix:
                banned_tokens.add(ngram[-1])

        banned_tokens = banned_tokens - self.whitelist_token_ids

        if banned_tokens:
            logits[list(banned_tokens)] = -float("inf")

        return logits

ngram_size instance-attribute

ngram_size = ngram_size

whitelist_token_ids instance-attribute

whitelist_token_ids = whitelist_token_ids or set()

window_size instance-attribute

window_size = window_size

__call__

__call__(output_ids: list[int], logits: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_ocr.py
def __call__(
    self,
    output_ids: list[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    if len(output_ids) < self.ngram_size:
        return logits

    current_prefix = tuple(output_ids[-(self.ngram_size - 1) :])

    search_start = max(0, len(output_ids) - self.window_size)
    search_end = len(output_ids) - self.ngram_size + 1

    banned_tokens = set()
    for i in range(search_start, search_end):
        ngram = tuple(output_ids[i : i + self.ngram_size])
        if ngram[:-1] == current_prefix:
            banned_tokens.add(ngram[-1])

    banned_tokens = banned_tokens - self.whitelist_token_ids

    if banned_tokens:
        logits[list(banned_tokens)] = -float("inf")

    return logits

__init__

__init__(
    ngram_size: int,
    window_size: int,
    whitelist_token_ids: set[int] | None = None,
)
Source code in vllm/model_executor/models/deepseek_ocr.py
def __init__(
    self,
    ngram_size: int,
    window_size: int,
    whitelist_token_ids: set[int] | None = None,
):
    self.ngram_size = ngram_size
    self.window_size = window_size
    self.whitelist_token_ids = whitelist_token_ids or set()
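
A self-contained toy run of the processor above (arbitrary token ids), showing what __call__ bans:

import torch

proc = NoRepeatNGramLogitsProcessor(ngram_size=3, window_size=100)

output_ids = [5, 6, 7, 5, 6]     # the 3-gram (5, 6, 7) already occurred, and the current
                                 # prefix (5, 6) matches its first two tokens
logits = torch.zeros(10)
logits = proc(output_ids, logits)

assert logits[7] == -float("inf")        # generating 7 would repeat (5, 6, 7)
assert torch.isfinite(logits[:7]).all()  # other tokens are untouched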