# Source code for furiosa_llm.llm_engine

import argparse
from argparse import ArgumentParser
import asyncio
from dataclasses import dataclass, fields
import os
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union

from transformers import BatchEncoding, PreTrainedTokenizerBase
from typing_extensions import TypedDict

from furiosa.native_runtime.llm import NativeLLMEngine, NativeRequestOutput  # type: ignore
from furiosa_llm.api import (
    CACHE_DIR,
    LLM,
    CompletionOutput,
    LLMBackend,
    RequestOutput,
    SamplingParams,
    SchedulerConfig,
    TokenizerModeType,
)


# adopted from https://github.com/vllm-project/vllm/blob/main/vllm/inputs/data.py
class TextPrompt(TypedDict):
    """Schema for a text prompt.

    A dict carrying raw text under the ``prompt`` key. ``preprocess_prompt``
    in this module reads that key and runs the tokenizer on its value.
    """

    prompt: str
    """The input text to be tokenized before passing to the model."""


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt.

    A dict carrying already-tokenized input under the ``prompt_token_ids``
    key. Note that ``preprocess_prompt`` in this module currently decodes
    these IDs back to text and tokenizes that text again.
    """

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""


# A single prompt: either raw text, a ``TextPrompt`` dict or a ``TokensPrompt`` dict.
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]

# TODO: support ExplicitEncoderDecoderPrompt later
# Until then this is equivalent to SingletonPrompt: a Union with a single
# member collapses to that member.
PromptType = Union[SingletonPrompt]


@dataclass
class EngineArgs:
    """Arguments used to build an engine, together with their CLI flags.

    Every field has a matching ``--flag`` registered by :meth:`add_cli_args`
    (underscores become dashes), which is what allows :meth:`from_cli_args`
    to rebuild an instance from a parsed ``argparse.Namespace``.
    """

    # Currently only artifact path is supported
    model: str
    num_speculative_tokens: Optional[int] = None
    use_speculative_decoding_if_possible: bool = True
    data_parallel_size: Optional[int] = None
    tokenizer: Optional[str] = None
    tokenizer_mode: TokenizerModeType = "auto"
    seed: Optional[int] = None
    devices: Optional[str] = None
    cache_dir: os.PathLike = CACHE_DIR
    backend: Optional[LLMBackend] = None
    paged_attention_num_blocks: Optional[int] = None

    # scheduler_config
    npu_queue_limit: Optional[int] = None
    max_processing_samples: Optional[int] = None
    spare_blocks_ratio: Optional[float] = None
    is_offline: Optional[bool] = None

    speculative_model_paged_attention_num_blocks: Optional[int] = None
    packing_type: Literal["IDENTITY"] = "IDENTITY"

    @staticmethod
    def add_cli_args(parser: ArgumentParser) -> ArgumentParser:
        """Register the shared engine CLI arguments on ``parser``.

        Args:
            parser: The parser to extend. It is modified in place.

        Returns:
            The same ``parser`` object, to allow chaining.
        """

        def _parse_bool(value: str) -> bool:
            # ``type=bool`` must not be used with argparse: ``bool("False")``
            # is True, so every non-empty value would enable the flag.
            normalized = value.strip().lower()
            if normalized in ("true", "t", "yes", "y", "on", "1"):
                return True
            # The empty string is kept as a false value because it was the
            # only spelling that produced False under ``type=bool``.
            if normalized in ("false", "f", "no", "n", "off", "0", ""):
                return False
            raise argparse.ArgumentTypeError(
                f"expected a boolean value (true/false), got {value!r}"
            )

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            required=True,
            help='Path to the LLM engine artifact (Pretrained id will be supported in the future releases).',
        )
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of tokens that speculative model will generate speculatively during each iteration of the decoding process.',
        )
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='The name or path of a HuggingFace Transformers tokenizer.',
        )
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            help='The tokenizer mode. "auto" will use the fast tokenizer '
            'if available, and "slow" will always use the slow tokenizer.',
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=EngineArgs.seed,
            help='The seed to initialize the random number generator for sampling.',
        )

        # Furiosa LLM specific arguments
        parser.add_argument(
            '--devices',
            type=str,
            default=EngineArgs.devices,
            help='The devices to run the model. It can be a single device or a list of devices. '
            'Each device can be either "npu:X" or "npu:X:*" where X is a specific device index. '
            'If not given, available devices will be used.',
        )
        parser.add_argument(
            '--data-parallel-size',
            type=int,
            default=EngineArgs.data_parallel_size,
            help='The size of the data parallelism group. '
            'If not given, it will be inferred from total available PEs and other parallelism degrees.',
        )
        parser.add_argument(
            '--cache-dir',
            type=Path,
            default=EngineArgs.cache_dir,
            # This help text used to be a verbatim copy of the --devices text.
            help='The path of the cache directory.',
        )
        parser.add_argument(
            '--backend',
            type=str,
            default=EngineArgs.backend,
            help='The backend implementation to run forward() of a model for the LLM. '
            'If not specified, the backend will be chosen based on the device kind.',
        )
        parser.add_argument(
            '--paged-attention-num-blocks',
            type=int,
            default=EngineArgs.paged_attention_num_blocks,
            help='The maximum number of blocks that each k/v storage per layer can store. '
            'This argument must be given if model uses paged attention.',
        )
        parser.add_argument(
            '--npu-queue-limit',
            type=int,
            default=EngineArgs.npu_queue_limit,
            help='The NPU queue limit of the scheduler config.',
        )
        parser.add_argument(
            '--max-processing-samples',
            type=int,
            default=EngineArgs.max_processing_samples,
            help='The maximum processing samples. Used as a hint for the scheduler.',
        )
        parser.add_argument(
            '--spare-blocks-ratio',
            type=float,
            default=EngineArgs.spare_blocks_ratio,
            help='The spare blocks ratio. Used as a hint for the scheduler.',
        )
        parser.add_argument(
            '--is-offline',
            type=_parse_bool,
            default=EngineArgs.is_offline,
            help='If True, the scheduler will assume the workload will be offline scenario.',
        )
        parser.add_argument(
            '--use-speculative-decoding-if-possible',
            type=_parse_bool,
            default=EngineArgs.use_speculative_decoding_if_possible,
            help="If True, speculative decoding will be used if possible "
            "(`speculative_model` is given or there's artifacts for speculative model in the artifacts.). "
            "Otherwise, speculative decoding will not be used.",
        )
        parser.add_argument(
            '--speculative-model-paged-attention-num-blocks',
            type=int,
            default=EngineArgs.speculative_model_paged_attention_num_blocks,
            help='The maximum number of blocks that each k/v storage per layer can store for the speculative model. '
            'This argument must be given if the speculative model uses paged attention.',
        )
        parser.add_argument(
            '--packing-type',
            type=str,
            default=EngineArgs.packing_type,
            help='Packing algorithm. Possible values are "IDENTITY" only for now',
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by :meth:`add_cli_args`.

        Args:
            args: Parsed arguments. It must expose one attribute per
                dataclass field; a missing one raises ``AttributeError``.

        Returns:
            A new instance of ``cls`` populated from ``args``.
        """
        attrs = [attr.name for attr in fields(cls)]
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous engine.

    No extra fields are declared yet, so this currently mirrors ``EngineArgs``.
    """

    # TODO: add async-specific arguments

    @staticmethod
    def add_cli_args(parser: ArgumentParser) -> ArgumentParser:
        """Register the CLI arguments on ``parser`` and return it.

        Only the arguments shared with ``EngineArgs`` are registered for now.
        """
        # TODO: add async-specific arguments
        return EngineArgs.add_cli_args(parser)


# TODO: currently LLMEngine wraps LLM, but the relationship must be inverted
class LLMEngine:
    """Synchronous, step-driven engine on top of a ``NativeLLMEngine``.

    :meth:`add_request` schedules a streaming task on a private event loop and
    :meth:`step` drives that loop until the next output is available. Tokens
    are accumulated per request across ``step()`` calls.
    """

    def __init__(
        self,
        native_engine: NativeLLMEngine,
        tokenizer,
    ):
        self.native_engine = native_engine
        self.tokenizer = tokenizer

        # Queue[Tuple[request_id, NativeRequestOutput]]
        self.queue: asyncio.Queue[Tuple[str, NativeRequestOutput]] = asyncio.Queue()
        # Number of requests that have not produced a finished output yet.
        self.unfinished = 0

        # Dict[request_id, Tuple[prompt, prompt_token_ids]]
        self.prompt_cache: Dict[str, Tuple[str, List[int]]] = {}

        # vllm's accumulates output tokens between each step() calls.
        # To mimic the behavior, we store outputs and update them in each step() call.
        self.outputs: Dict[str, List[CompletionOutput]] = {}
        self.aio_loop = asyncio.new_event_loop()

    @classmethod
    def from_llm(cls, llm: LLM) -> "LLMEngine":
        """Wrap the native engine and tokenizer of an existing ``LLM``."""
        return cls(llm.engine, llm.tokenizer)

    @staticmethod
    def _apply_scheduler_overrides(scheduler_config, args) -> None:
        """Copy every non-``None`` matching attribute of ``args`` onto ``scheduler_config``.

        ``scheduler_config`` must be a dataclass instance; its field names
        decide which attributes of ``args`` are looked up.
        """
        for config_field in fields(scheduler_config):
            value = getattr(args, config_field.name, None)
            # Compare the *value* against None. The previous
            # ``v := getattr(...) is not None`` bound ``v`` to the boolean
            # result of the comparison and stored ``True`` in the config.
            if value is not None:
                setattr(scheduler_config, config_field.name, value)

    @classmethod
    def from_engine_args(cls, args: EngineArgs) -> "LLMEngine":
        """Create an engine from ``EngineArgs``.

        The scheduler config is read from ``<args.model>/scheduler_config.json``
        and then overridden by the scheduler-related values of ``args`` that
        are not ``None``. This is best effort: if anything in that step fails,
        ``scheduler_config=None`` is passed to ``LLM.from_artifacts``.
        """
        try:
            scheduler_config_ = SchedulerConfig.load(f"{args.model}/scheduler_config.json")
            cls._apply_scheduler_overrides(scheduler_config_, args)
        except Exception:
            scheduler_config_ = None

        llm = LLM.from_artifacts(
            path=args.model,
            num_speculative_tokens=args.num_speculative_tokens,
            use_speculative_decoding_if_possible=args.use_speculative_decoding_if_possible,
            data_parallel_size=args.data_parallel_size,
            tokenizer=args.tokenizer,
            tokenizer_mode=args.tokenizer_mode,
            seed=args.seed,
            devices=args.devices,
            cache_dir=args.cache_dir,
            backend=args.backend,
            paged_attention_num_blocks=args.paged_attention_num_blocks,
            scheduler_config=scheduler_config_,
            speculative_model_paged_attention_num_blocks=args.speculative_model_paged_attention_num_blocks,
            packing_type=args.packing_type,
        )
        return cls.from_llm(llm)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        """Register a request; its outputs are later retrieved with :meth:`step`.

        Args:
            request_id: Key under which outputs of this request are reported.
            prompt: The prompt, in any form accepted by ``preprocess_prompt``.
            sampling_params: Forwarded to the native engine as is.
        """
        prompt_str, batch_encoding = preprocess_prompt(prompt, self.tokenizer)
        prompt_token_ids = batch_encoding["input_ids"]
        self.unfinished += 1
        self.prompt_cache[request_id] = (
            prompt_str,
            prompt_token_ids,
        )
        # The task only makes progress while step() is running the loop.
        self.aio_loop.create_task(
            self._process_request(request_id, batch_encoding, sampling_params)
        )

    async def _process_request(
        self,
        request_id: str,
        batch_encoding: BatchEncoding,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        """Stream outputs of one request from the native engine into the queue."""
        async for output in self.native_engine.stream_generate(batch_encoding, sampling_params):
            await self.queue.put((request_id, output))

    def has_unfinished_requests(self) -> bool:
        """Return True while at least one request has not finished."""
        return self.unfinished > 0

    def step(self) -> List[RequestOutput]:
        """Wait for the next native output and return it as a ``RequestOutput``.

        Returns:
            A one-element list holding the accumulated output of the request
            that made progress, or an empty list when no request is pending.
        """
        # Nothing is pending, so nothing will ever be queued; waiting on the
        # queue here would block forever.
        if not self.has_unfinished_requests():
            return []

        output: NativeRequestOutput
        req_id, output = self.aio_loop.run_until_complete(self.queue.get())
        # currently only single output is supported inside `hf_compat_stream_generate`
        # if the behavior is changed, this part should be updated accordingly
        assert len(output.outputs) == 1
        prompt_str, prompt_token_ids = self.prompt_cache[req_id]

        if req_id not in self.outputs:
            # First output of this request: start the accumulated state.
            completion_outputs = [
                CompletionOutput(
                    index=0,
                    text=self.tokenizer.decode(
                        o.token_ids, True, clean_up_tokenization_spaces=True
                    ),
                    token_ids=o.token_ids,
                    finish_reason=o.finish_reason,
                )
                for o in output.outputs
            ]
            self.outputs[req_id] = completion_outputs
        else:
            # Later outputs: append the new tokens and decode the whole
            # sequence again.
            for i, o in enumerate(output.outputs):
                self.outputs[req_id][i].token_ids.extend(o.token_ids)
                self.outputs[req_id][i].text = self.tokenizer.decode(
                    self.outputs[req_id][i].token_ids, True, clean_up_tokenization_spaces=True
                )
                self.outputs[req_id][i].finish_reason = o.finish_reason
            completion_outputs = self.outputs[req_id]

        finished = all(o.finish_reason is not None for o in output.outputs)
        if finished:
            # Drop the per-request state once the request is complete.
            self.unfinished -= 1
            del self.prompt_cache[req_id]
            del self.outputs[req_id]

        return [
            RequestOutput(
                request_id=req_id,
                prompt=prompt_str,
                prompt_token_ids=prompt_token_ids,
                outputs=completion_outputs,
                finished=finished,
            )
        ]
class AsyncLLMEngine:
    """Asynchronous engine on top of a ``NativeLLMEngine``.

    Each request is served by an async generator that yields a
    ``RequestOutput`` with the tokens accumulated so far every time the
    native engine streams a new output.
    """

    def __init__(
        self,
        native_engine: NativeLLMEngine,
        tokenizer,
    ):
        self.native_engine = native_engine
        self.tokenizer = tokenizer

    @classmethod
    def from_llm(cls, llm: LLM) -> "AsyncLLMEngine":
        """Wrap the native engine and tokenizer of an existing ``LLM``."""
        return cls(llm.engine, llm.tokenizer)

    @staticmethod
    def _apply_scheduler_overrides(scheduler_config, args) -> None:
        """Copy every non-``None`` matching attribute of ``args`` onto ``scheduler_config``.

        ``scheduler_config`` must be a dataclass instance; its field names
        decide which attributes of ``args`` are looked up.
        """
        for config_field in fields(scheduler_config):
            value = getattr(args, config_field.name, None)
            # Compare the *value* against None. The previous
            # ``v := getattr(...) is not None`` bound ``v`` to the boolean
            # result of the comparison and stored ``True`` in the config.
            if value is not None:
                setattr(scheduler_config, config_field.name, value)

    @classmethod
    def from_engine_args(cls, args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Create an engine from ``AsyncEngineArgs``.

        The scheduler config is read from ``<args.model>/scheduler_config.json``
        and then overridden by the scheduler-related values of ``args`` that
        are not ``None``. This is best effort: if anything in that step fails,
        ``scheduler_config=None`` is passed to ``LLM.from_artifacts``.
        """
        try:
            scheduler_config_ = SchedulerConfig.load(f"{args.model}/scheduler_config.json")
            cls._apply_scheduler_overrides(scheduler_config_, args)
        except Exception:
            scheduler_config_ = None

        llm = LLM.from_artifacts(
            path=args.model,
            num_speculative_tokens=args.num_speculative_tokens,
            use_speculative_decoding_if_possible=args.use_speculative_decoding_if_possible,
            data_parallel_size=args.data_parallel_size,
            tokenizer=args.tokenizer,
            tokenizer_mode=args.tokenizer_mode,
            seed=args.seed,
            devices=args.devices,
            cache_dir=args.cache_dir,
            backend=args.backend,
            paged_attention_num_blocks=args.paged_attention_num_blocks,
            scheduler_config=scheduler_config_,
            speculative_model_paged_attention_num_blocks=args.speculative_model_paged_attention_num_blocks,
            packing_type=args.packing_type,
        )
        return cls.from_llm(llm)

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Yield the outputs of one request; delegates to :meth:`add_request`."""
        async for output in self.add_request(
            request_id,
            prompt,
            sampling_params,
        ):
            yield output

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        sampling_params: Optional[SamplingParams] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run one request and yield its accumulated output after each update.

        Args:
            request_id: Identifier copied into every yielded ``RequestOutput``.
            prompt: The prompt, in any form accepted by ``preprocess_prompt``.
            sampling_params: Forwarded to the native engine as is.
        """
        prompt_str, batch_encoding = preprocess_prompt(prompt, self.tokenizer)
        prompt_token_ids = batch_encoding["input_ids"]
        completion_outputs = None

        async for output in self.native_engine.stream_generate(batch_encoding, sampling_params):
            if completion_outputs is None:
                # First output of this request: start the accumulated state.
                completion_outputs = [
                    CompletionOutput(
                        index=0,
                        text=self.tokenizer.decode(
                            o.token_ids, True, clean_up_tokenization_spaces=True
                        ),
                        token_ids=o.token_ids,
                        finish_reason=o.finish_reason,
                    )
                    for o in output.outputs
                ]
            else:
                # Later outputs: append the new tokens and decode the whole
                # sequence again.
                for i, o in enumerate(output.outputs):
                    completion_outputs[i].token_ids.extend(o.token_ids)
                    completion_outputs[i].text = self.tokenizer.decode(
                        completion_outputs[i].token_ids, True, clean_up_tokenization_spaces=True
                    )
                    completion_outputs[i].finish_reason = o.finish_reason

            yield RequestOutput(
                request_id=request_id,
                prompt=prompt_str,
                prompt_token_ids=prompt_token_ids,
                outputs=completion_outputs,
                finished=all(o.finish_reason is not None for o in output.outputs),
            )
# TODO
# async def engine_step(self): ...
# async def abort(self, request_id): ...


def preprocess_prompt(
    prompt: PromptType, tokenizer: PreTrainedTokenizerBase
) -> Tuple[str, BatchEncoding]:
    """Turn any supported prompt form into its text and its tokenized form.

    Args:
        prompt: A plain string, a dict with a ``"prompt"`` key (text) or a
            dict with a ``"prompt_token_ids"`` key (token IDs). When both
            keys are present, ``"prompt"`` wins.
        tokenizer: Used to decode token IDs and to tokenize the prompt text.

    Returns:
        A ``(prompt_text, encoding)`` tuple, where ``encoding`` is whatever
        ``tokenizer(prompt_text, padding=False, add_special_tokens=True)``
        returns.

    Raises:
        ValueError: If ``prompt`` is neither a string nor a dict, or is a
            dict that has none of the supported keys.
    """
    if isinstance(prompt, str):
        prompt_str = prompt
    elif isinstance(prompt, dict):
        if "prompt" in prompt:
            prompt_str = prompt["prompt"]  # type: ignore
        elif "prompt_token_ids" in prompt:
            # To get BatchEncoding, decode first and then encode again.
            # This is inefficient but since generator takes BatchEncoding as input this is the safest way.
            prompt_str = tokenizer.decode(prompt["prompt_token_ids"], skip_special_tokens=True)
        else:
            # Reporting only ``type(prompt)`` here would just say "dict" and
            # hide the actual problem, which is the missing key.
            raise ValueError(
                "Unsupported prompt dict: expected a 'prompt' or 'prompt_token_ids' key, "
                f"got keys {list(prompt)}"
            )
    else:
        raise ValueError(f"Unsupported prompt type: {type(prompt)}")

    return prompt_str, tokenizer(prompt_str, padding=False, add_special_tokens=True)