Skip to content

vllm.compilation.passes.fusion.rope_kvcache_fusion

logger module-attribute

logger = init_logger(__name__)

RopeKVCacheFusionPass

Bases: VllmPatternMatcherPass

This pass fuses the rotary embedding and KV cache update operations into a single fused kernel if available.

It uses the pattern matcher and matches each layer manually, as strings cannot be wildcarded. This also lets us check support on attention layers upon registration instead of during pattern matching.

This fusion eliminates the need for separate kernel launches and intermediate memory operations between the RoPE and cache update steps.

Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses the rotary embedding and KV cache update operations
    into a single fused kernel if available.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    This fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Dedicated inductor pattern-matcher pass that holds every
        # registered RoPE + KV-cache fusion pattern.
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        cc = config.compilation_config
        # Upper bound on the token count for which this pass applies
        # (consumed by is_applicable_for_range below).
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        # Register one pattern per attention layer (layer names cannot be
        # wildcarded) and per RoPE flavor (neox / non-neox), but only for
        # layers whose backend reports support for the fused kernel.
        attn_layers = get_layers_from_vllm_config(config, Attention)
        for _, layer in attn_layers.items():
            if layer.impl.fused_rope_kvcache_supported():
                for is_neox in [True, False]:
                    RopeReshapeKVCachePattern(
                        layer=layer,
                        is_neox=is_neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered fusion patterns to the graph in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True iff fusion should run for this compile range."""
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        return compile_range.end <= self.max_token_num

    def uuid(self) -> str:
        """Cache key derived from this pass' and the pattern class' source."""
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)

max_token_num instance-attribute

max_token_num = rope_kvcache_fusion_max_token_num

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rope_kv_cache_fusion_pass"
)

__call__

__call__(graph: Graph) -> None
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
    """Apply the registered fusion patterns to the graph in place."""
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig) -> None
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
    """Register a fusion pattern for every supported attention layer."""
    super().__init__(config)

    # Dedicated inductor pattern-matcher pass holding all registered patterns.
    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rope_kv_cache_fusion_pass"
    )

    cc = config.compilation_config
    # Upper bound on the token count for which this pass applies.
    self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

    # One pattern per layer (names cannot be wildcarded) and per RoPE flavor,
    # only for layers whose backend supports the fused kernel.
    attn_layers = get_layers_from_vllm_config(config, Attention)
    for _, layer in attn_layers.items():
        if layer.impl.fused_rope_kvcache_supported():
            for is_neox in [True, False]:
                RopeReshapeKVCachePattern(
                    layer=layer,
                    is_neox=is_neox,
                ).register(self.patterns)

    self.dump_patterns(config, self.patterns)

is_applicable_for_range

is_applicable_for_range(compile_range: Range) -> bool
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Return True iff fusion should run for this compile range."""
    # This pass works best for the small-batch decode setting.
    # For large-batch e.g. prefill, it is better to use two separate kernels
    # since they are compute bound and the fused kernels require further tuning.
    return compile_range.end <= self.max_token_num

uuid

uuid() -> str
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def uuid(self) -> str:
    """Cache key derived from this pass' and the pattern class' source."""
    return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)

RopeReshapeKVCachePattern

This pattern matches the following unfused inplace ops

q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

and replaces it with the fused inplace op

kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
    q, k, v, positions, cos_sin_cache, is_neox, layer_name
)

Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
class RopeReshapeKVCachePattern:
    """
    This pattern matches the following unfused inplace ops:
      q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
      kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and replaces it with the fused inplace op:
      kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
        q, k, v, positions, cos_sin_cache, is_neox, layer_name
      )
    """

    # The fused op that replaces the rotary_embedding + kv_cache_update pair.
    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v
        self.is_neox = is_neox

        # Flattened per-token widths of the q/k/v slices inside the packed
        # qkv tensor, used for the split in pattern/replacement.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        # Matches the unfused rotary-embedding op for this layer's geometry.
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        # Sample inputs to help pattern tracing
        # T: token count, L: cos/sin cache length — arbitrary tracing sizes.
        T = 5
        L = 4096
        qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        return [
            qkv,
            positions,
            cos_sin_cache,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern -> replacement rewrite with the given pass."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: split packed qkv, apply rope to q/k, reshape to
            # (tokens, heads, head_size), then write k/v into the KV cache.
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Fused form: same split/reshape, then one op doing rope + cache
            # update together (auto_functionalized for the in-place mutation).
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            results = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            # Outputs line up with the pattern's (dummy, q, k, v); v is not
            # modified by the fused op so it is passed through unchanged.
            return results[0], results[1], results[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            gm = pm.fwd_only(*args, **kwargs)
            view_to_reshape(gm)
            return gm

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
        )

FUSED_OP class-attribute instance-attribute

FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

head_size instance-attribute

head_size = head_size

head_size_v instance-attribute

head_size_v = head_size_v

is_neox instance-attribute

is_neox = is_neox

k_size instance-attribute

k_size = num_kv_heads * head_size

layer_name instance-attribute

layer_name = layer_name

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

q_size instance-attribute

q_size = num_heads * head_size

rope_matcher instance-attribute

rope_matcher = MatcherRotaryEmbedding(
    is_neox=is_neox,
    head_size=head_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
)

v_size instance-attribute

v_size = num_kv_heads * head_size_v

__init__

__init__(layer: Attention, is_neox: bool) -> None
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def __init__(
    self,
    layer: Attention,
    is_neox: bool,
) -> None:
    """Capture the layer's geometry needed for matching and replacement."""
    self.layer_name = layer.layer_name
    self.num_heads = layer.num_heads
    self.num_kv_heads = layer.num_kv_heads
    self.head_size = layer.head_size
    self.head_size_v = layer.head_size_v
    self.is_neox = is_neox

    # Flattened per-token widths of the q/k/v slices in the packed qkv tensor.
    self.q_size = self.num_heads * self.head_size
    self.k_size = self.num_kv_heads * self.head_size
    self.v_size = self.num_kv_heads * self.head_size_v

    # Matches the unfused rotary-embedding op for this layer's geometry.
    self.rope_matcher = MatcherRotaryEmbedding(
        is_neox=self.is_neox,
        head_size=self.head_size,
        num_heads=self.num_heads,
        num_kv_heads=self.num_kv_heads,
    )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    # Sample inputs to help pattern tracing
    # T: token count, L: cos/sin cache length — arbitrary tracing sizes.
    T = 5
    L = 4096
    qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
    positions = empty_i64(T)
    cos_sin_cache = empty_bf16(L, self.head_size)
    return [
        qkv,
        positions,
        cos_sin_cache,
    ]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern -> replacement rewrite with the given pass."""

    def pattern(
        qkv: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Unfused form: split packed qkv, apply rope to q/k, reshape to
        # (tokens, heads, head_size), then write k/v into the KV cache.
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
        q = q.view(-1, self.num_heads, self.head_size)
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size_v)
        dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
        return dummy, q, k, v

    def replacement(
        qkv: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Fused form: same split/reshape, then one op doing rope + cache
        # update together (auto_functionalized for the in-place mutation).
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_size)
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size_v)
        results = auto_functionalized(
            self.FUSED_OP,
            query=q,
            key=k,
            value=v,
            positions=positions,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
            layer_name=self.layer_name,
        )
        # Outputs line up with the pattern's (dummy, q, k, v); v is not
        # modified by the fused op so it is passed through unchanged.
        return results[0], results[1], results[2], v

    # NOTE: use view_to_reshape to unify view/reshape to simplify
    # pattern and increase matching opportunities
    def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
        gm = pm.fwd_only(*args, **kwargs)
        view_to_reshape(gm)
        return gm

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
    )

fused_rope_and_unified_kv_cache_update_fake

fused_rope_and_unified_kv_cache_update_fake(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    positions: Tensor,
    cos_sin_cache: Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> Tensor
Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """Fake (meta) implementation of the fused op.

    Produces a zero-element dummy tensor on the query's device with the
    query's dtype, mirroring the shape/device/dtype of the real op's
    dummy output so tracing can proceed without a real KV cache.
    """
    return query.new_empty((0,))

fused_rope_and_unified_kv_cache_update_impl

fused_rope_and_unified_kv_cache_update_impl(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    positions: Tensor,
    cos_sin_cache: Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> Tensor

This impl fetches the KV cache and slot mapping from the forward context, then calls the layer impl's AttentionImpl.do_rope_and_kv_cache_update method. It also returns a dummy tensor, similar to Attention.unified_kv_cache_update, that is passed to unified_attention to signal a side effect and the data dependency between them to ensure torch.compile preserves ordering.

Source code in vllm/compilation/passes/fusion/rope_kvcache_fusion.py
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    This impl fetches the KV cache and slot mapping from the forward context,
    then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
    It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
    that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    # Skip the write when no slot mapping exists for this layer — presumably
    # there is nothing to cache in this step (e.g. profiling/warmup runs);
    # TODO confirm against get_attention_context's contract.
    if layer_slot_mapping is not None:
        attn_layer.impl.do_rope_and_kv_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )

    # Zero-element dummy output; exists only to express the data dependency
    # described in the docstring.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)