import argparse
from argparse import ArgumentParser
import asyncio
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass, fields
import os
from pathlib import Path
import queue
import threading
from typing import Any, cast
import weakref
import torch
from furiosa_llm.api import (
CACHE_DIR,
DEFAULT_JIT_MAX_WORKERS,
DEFAULT_JIT_THRESHOLD,
DEFAULT_JIT_UNIT_SIZE,
LLM,
NativeEngineLike,
PoolingOutput,
PoolingParams,
PoolingRequestOutput,
RequestOutput,
SamplingParams,
SchedulerConfig,
TokenizerModeType,
)
from furiosa_llm.errors import validate_context_length
from furiosa_llm.metadata.tasks import GENERATION_TASKS, POOLING_TASKS, GenerationTask, PoolingTask
from furiosa_llm.multimodal import (
MultimodalConstraints,
call_native_mm,
compute_expanded_prompt_len,
extract_mm_data,
extract_mm_processor_kwargs,
extract_mm_uuids,
load_processor,
process_mm,
resolve_constraints,
)
from furiosa_llm.outputs import NativeOutputConverter
from furiosa_llm.utils import coalesce
from furiosa_llm.vllm_compat import (
AnyTokenizer,
BatchEncoding,
PromptType,
fit_prompt_to_context,
preprocess_prompt,
)
@dataclass
class EngineArgs:
    """
    Engine construction arguments, mirrored one-to-one by CLI flags.

    :meth:`add_cli_args` registers the flags and :meth:`from_cli_args` builds an
    instance by reading one attribute per dataclass field from the parsed
    namespace, so every field needs a matching ``--flag``.
    """

    # Currently only artifact path is supported
    model: str
    revision: str | None = None
    pipeline_parallel_size: int | None = None
    data_parallel_size: int | None = None
    tokenizer: str | None = None
    tokenizer_mode: TokenizerModeType = "auto"
    seed: int | None = None
    devices: str | None = None
    cache_dir: os.PathLike = CACHE_DIR
    # scheduler_config
    npu_queue_limit: int | None = None
    max_processing_samples: int | None = None
    spare_blocks_ratio: float | None = None
    # wiring related arguments
    enable_jit_compilation: bool = False
    jit_threshold: int = DEFAULT_JIT_THRESHOLD
    jit_max_workers: int = DEFAULT_JIT_MAX_WORKERS
    jit_unit_size: int = DEFAULT_JIT_UNIT_SIZE

    def __post_init__(self):
        """
        Validate JIT-related configuration after initialization.

        This method ensures that:
        - jit_unit_size is at least 2, as smaller values break batching semantics.
        - jit_threshold is at least 1, as 0 would trigger JIT on every request and
          is incompatible with the ring buffer logic in JIT triggering.
        - jit_max_workers is at least 1, so background JIT compilation can run.

        Raises:
            ValueError: if any of the above constraints is violated.
        """
        if self.jit_unit_size < 2:
            raise ValueError(f"jit_unit_size must be >= 2, got {self.jit_unit_size}")
        if self.jit_threshold < 1:
            raise ValueError(f"jit_threshold must be >= 1, got {self.jit_threshold}")
        if self.jit_max_workers < 1:
            raise ValueError(f"jit_max_workers must be >= 1, got {self.jit_max_workers}")

    @staticmethod
    def add_cli_args(parser: ArgumentParser) -> ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per :class:`EngineArgs` field on *parser* and
        returns the same parser for chaining.
        """

        def _parse_bool(value: str) -> bool:
            # Strict boolean parser: without it argparse would store the raw
            # string, and any non-empty string (including "false") is truthy.
            lowered = value.strip().lower()
            if lowered in ("1", "true", "yes", "on"):
                return True
            if lowered in ("0", "false", "no", "off"):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            required=True,
            help='The Hugging Face model id, or path to Furiosa model artifact. Currently only one model is supported per server.',
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=EngineArgs.revision,
            help="The specific model revision on Hugging Face Hub if the model is given as a Hugging Face model id. It can be a branch name, a tag name, or a commit id."
            " Its default value is main. However, if a given model belongs to the furiosa-ai organization, the model will use the release model tag by default.",
        )
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='The name or path of a HuggingFace Transformers tokenizer.',
        )
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            help='The tokenizer mode. "auto" will use the fast tokenizer '
            'if available, and "slow" will always use the slow tokenizer.',
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=EngineArgs.seed,
            help='The seed to initialize the random number generator for sampling.',
        )
        # Furiosa LLM specific arguments
        parser.add_argument(
            '--devices',
            type=str,
            default=EngineArgs.devices,
            help='The devices to run the model. It can be a single device or a comma-separated list of devices. '
            'Each device can be either "npu:X" or "npu:X:Y", where X is a device index and Y is a NPU core range notation '
            '(e.g. "npu:0" for whole npu 0, "npu:0:0" for core 0 of NPU 0, and "npu:0:0-3" for fused core 0-3 of npu 0). '
            'If not given, all available unoccupied devices will be used.',
        )
        parser.add_argument(
            '--pipeline-parallel-size',
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help='The size of the pipeline parallelism group. '
            'If not given, it will use the value from artifact.',
        )
        parser.add_argument(
            '--data-parallel-size',
            type=int,
            default=EngineArgs.data_parallel_size,
            help='The size of the data parallelism group. '
            'If not given, it will be inferred from total available PEs and other parallelism degrees.',
        )
        # NOTE(review): with type=Path a ``None`` value cannot be passed from the
        # CLI; the help text's ``None`` refers to the Python-level argument.
        parser.add_argument(
            '--cache-dir',
            type=Path,
            default=EngineArgs.cache_dir,
            help='The cache directory for temporarily generated files for this LLM instance. '
            'When its value is ``None``, caching is disabled. The default is "$HOME/.cache/furiosa/llm".',
        )
        parser.add_argument(
            '--npu-queue-limit',
            type=int,
            default=EngineArgs.npu_queue_limit,
            help='The NPU queue limit of the scheduler config.',
        )
        parser.add_argument(
            '--max-processing-samples',
            type=int,
            default=EngineArgs.max_processing_samples,
            help='The maximum processing samples. Used as an hint for the scheduler.',
        )
        parser.add_argument(
            '--spare-blocks-ratio',
            type=float,
            default=EngineArgs.spare_blocks_ratio,
            help='The spare blocks ratio. Used as an hint for the scheduler.',
        )
        # A bare `--enable-jit-compilation` enables JIT (const=True); an explicit
        # value such as `--enable-jit-compilation false` is parsed as a bool.
        parser.add_argument(
            '--enable-jit-compilation',
            nargs='?',
            const=True,
            type=_parse_bool,
            default=EngineArgs.enable_jit_compilation,
            help='[EXPERIMENTAL] Enable JIT compilation.',
        )
        parser.add_argument(
            '--jit-threshold',
            type=int,
            default=EngineArgs.jit_threshold,
            help='[EXPERIMENTAL] Number of requests before triggering JIT compilation.',
        )
        parser.add_argument(
            '--jit-max-workers',
            type=int,
            default=EngineArgs.jit_max_workers,
            help='[EXPERIMENTAL] Maximum concurrent background JIT compilations.',
        )
        parser.add_argument(
            '--jit-unit-size',
            type=int,
            default=EngineArgs.jit_unit_size,
            help='[EXPERIMENTAL] Number of stages to compile together. Must be >= 2.',
        )
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by :meth:`add_cli_args`.

        Raises:
            AttributeError: if *args* lacks an attribute for any dataclass field.
        """
        attrs = [attr.name for attr in fields(cls)]
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for :class:`AsyncLLMEngine`; currently identical to :class:`EngineArgs`."""

    # TODO: add async-specific arguments

    @staticmethod
    def add_cli_args(parser: ArgumentParser) -> ArgumentParser:
        """Register the async engine's CLI flags (today: exactly the base flags)."""
        # TODO: add async-specific arguments
        return EngineArgs.add_cli_args(parser)
# XXX: Since SamplingParams.max_tokens in Rust is not an Option type,
# we must ensure max_tokens is not None when SamplingParams is converted from Python to Rust.
# That's why the validation logic is duplicated here and in `LLM._verify_token_len_and_finalize_max_tokens`.
# Unfortunately there is no way to avoid this duplication while minimizing unnecessary encode/decode operations
# and keeping the Python API compatible with vLLM at the same time.
# The best solution would be to change SamplingParams.max_tokens in Rust to an Option type in the future.
# related PR: https://github.com/furiosa-ai/furiosa-runtime/pull/1260
class LLMEngineBase:
    """
    State and helpers shared by :class:`LLMEngine` and :class:`AsyncLLMEngine`.

    Subclasses are responsible for setting the attributes annotated below.
    """

    max_model_len: int
    # Model location used to load the multimodal processor; may be None for
    # text-only use.
    model_path: str | None
    trust_remote_code: bool
    native_engine: "NativeEngineLike"
    # IDs of requests that have been registered and not yet released.
    request_ids: set[str]

    # TODO: Also do __verify_sampling_params_with_generator_config
    def try_add_request_id(self, request_id: str) -> None:
        """
        Register *request_id* as in flight.

        Raises:
            ValueError: if *request_id* is not a ``str`` or is already registered.
        """
        # Validate the type before the membership test: an unhashable ID
        # (e.g. a list) would otherwise raise TypeError from the set lookup.
        if not isinstance(request_id, str):
            raise ValueError(f"Request ID {request_id} must be a string.")
        if request_id in self.request_ids:
            raise ValueError(f"Request ID {request_id} already exist in the engine.")
        self.request_ids.add(request_id)

    def try_remove_request_id(self, request_id: str) -> None:
        """Release *request_id*; releasing an unknown ID is a no-op."""
        self.request_ids.discard(request_id)

    def _preprocess_multimodal_inputs(
        self,
        prompt: PromptType,
        prompt_token_ids: list[int],
        max_completion_tokens: int | None,
    ) -> tuple[list, list[str], MultimodalConstraints | None]:
        """Run MM preprocessing and post-expansion length validation.

        Returns `(mm_inputs, mm_modalities, constraints)`. When the prompt
        carries no `multi_modal_data` the tuple is `([], [], None)` and no
        processor is loaded.

        `validate_context_length` is re-run on the post-placeholder-expansion
        length so an oversized image (which the text-side pre-validate only
        counted as 1 token per `<|image_pad|>`) surfaces as
        `ContextLengthExceededError` here rather than as a Rust-side
        bucket-overflow abort.

        Raises:
            ValueError: if the prompt carries multimodal data but
                ``self.model_path`` is None.
        """
        mm_data = extract_mm_data(prompt)
        if not mm_data:
            return [], [], None
        if self.model_path is None:
            cls = type(self).__name__
            raise ValueError(
                f"received a multimodal request but {cls} has no model_path; "
                f"construct it via {cls}.from_llm or pass model_path explicitly"
            )
        processor = load_processor(self.model_path, self.trust_remote_code)
        constraints = resolve_constraints(
            self.native_engine, self.model_path, self.trust_remote_code
        )
        mm_inputs, mm_modalities = process_mm(
            processor,
            mm_data,
            user_kwargs=extract_mm_processor_kwargs(prompt),
            constraints=constraints,
            mm_uuids=extract_mm_uuids(prompt),
        )
        validate_context_length(
            max_model_len=self.max_model_len,
            prompt_tokens=compute_expanded_prompt_len(
                prompt_token_ids=prompt_token_ids,
                mm_inputs=mm_inputs,
                mm_modalities=mm_modalities,
                processor=processor,
                constraints=constraints,
            ),
            max_completion_tokens=max_completion_tokens,
        )
        return mm_inputs, mm_modalities, constraints
def _engine_finalize(aio_loop, loop_thread) -> None:
    """
    Ask *aio_loop* to stop from a foreign thread, then wait for *loop_thread*.

    Registered via ``weakref.finalize`` by :class:`LLMEngine`, so it receives
    the loop and thread directly rather than the engine object.
    """
    request_stop = aio_loop.stop
    # stop() is not thread-safe; hand it to the loop's own thread.
    aio_loop.call_soon_threadsafe(request_stop)
    # Bounded wait so shutdown cannot block indefinitely on a stuck loop.
    grace_seconds = 5.0
    loop_thread.join(timeout=grace_seconds)
class LLMEngine(LLMEngineBase):
    """
    LLMEngine receives requests and generates texts.

    Implements the API interface compatible with vLLM's `LLMEngine`, but this class is based on furiosa-runtime and FuriosaAI NPU.

    The request scheduling approach of this engine is different from that of vLLM's. While vLLM provides
    fine-grained control over decoding via the `step` method, this engine immediately begins
    text generation in the background as soon as a request is submitted via :meth:`add_request`,
    continuing asynchronously until completion. The generated results are placed in a queue that
    clients can retrieve by calling :meth:`step`.

    The Furiosa native engine handles scheduling and batching internally,
    allowing clients to retrieve results via :meth:`step` calls without needing to manage the decoding schedule.
    """

    @dataclass(frozen=True)
    class _RequestFailure:
        """Queue entry recording that a request's background coroutine raised."""

        request_id: str
        error: BaseException

    def __init__(
        self,
        native_engine: NativeEngineLike,
        tokenizer: AnyTokenizer,
        task_type: GenerationTask | PoolingTask,
        max_model_len: int,
        llm: LLM | None = None,  # keep reference to LLM to prevent native engine shutdown
        model_path: str | None = None,
        trust_remote_code: bool = False,
    ):
        self.native_engine = native_engine
        self.tokenizer = tokenizer
        self.task_type = task_type
        self.max_model_len = max_model_len
        self._llm = llm
        # Falls back to `llm.model_id_or_path` when no explicit path is given.
        self.model_path = model_path or (
            getattr(llm, "model_id_or_path", None) if llm is not None else None
        )
        self.trust_remote_code = trust_remote_code
        # Outputs (and background failures) produced on the loop thread and
        # consumed by `step()` on the caller's thread.
        self.queue: queue.Queue[
            RequestOutput | PoolingRequestOutput | LLMEngine._RequestFailure
        ] = queue.Queue()
        self.request_ids = set()
        self.aio_loop = asyncio.new_event_loop()
        # Capture the loop, not `self`: if the thread target referenced the
        # engine, the running thread would keep it alive and the weakref
        # finalizer below could never fire on garbage collection.
        loop = self.aio_loop

        def run_loop():
            asyncio.set_event_loop(loop)
            try:
                loop.run_forever()
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

        self._loop_thread = threading.Thread(target=run_loop, daemon=True)
        self._loop_thread.start()
        self._finalizer = weakref.finalize(self, _engine_finalize, self.aio_loop, self._loop_thread)

    def shutdown(self):
        """Stop the background event loop (idempotent via ``weakref.finalize``)."""
        self._finalizer()

    def __enter__(self) -> "LLMEngine":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.shutdown()

    @classmethod
    def from_llm(
        cls,
        llm: LLM,
    ) -> "LLMEngine":
        """Create an engine that shares *llm*'s native engine and tokenizer."""
        task_type = llm.model_metadata.task
        assert task_type is not None, "Failed to infer task type from model metadata."
        return cls(
            llm.engine,
            llm.tokenizer,
            task_type,
            llm.max_model_len,
            llm=llm,
            trust_remote_code=bool(llm.model_metadata.trust_remote_code),
        )

    @classmethod
    def from_engine_args(cls, args: EngineArgs) -> "LLMEngine":
        """
        Creates an LLMEngine from EngineArgs.
        """
        scheduler_config = SchedulerConfig.load_from_args(args)
        llm = LLM(
            model_id_or_path=args.model,
            revision=args.revision,
            pipeline_parallel_size=args.pipeline_parallel_size,
            data_parallel_size=args.data_parallel_size,
            tokenizer=args.tokenizer,
            tokenizer_mode=args.tokenizer_mode,
            seed=args.seed,
            devices=args.devices,
            cache_dir=args.cache_dir,
            scheduler_config=scheduler_config,
            enable_jit_compilation=args.enable_jit_compilation,
            jit_threshold=args.jit_threshold,
            jit_max_workers=args.jit_max_workers,
            jit_unit_size=args.jit_unit_size,
        )
        return cls.from_llm(llm)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
    ) -> None:
        """
        Adds a new request to the engine.

        The decoding iteration starts immediately after adding the request.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to the LLM.
            params: The sampling | pooling parameters of the request.

        Raises:
            ValueError: if the request ID is invalid or already in use, the
                model does not support the requested task, or *params* has an
                unsupported type. Context-length validation errors are
                propagated as well. On any error the request ID is released.
        """
        self.try_add_request_id(request_id)
        try:
            if isinstance(params, SamplingParams):
                self._add_generation_request(request_id, prompt, params)
            elif isinstance(params, PoolingParams):
                self._add_pooling_request(request_id, prompt, params)
            else:
                raise ValueError(
                    f"Unsupported sampling parameters type: {type(params)}. "
                    "Expected SamplingParams or PoolingParams."
                )
        except BaseException:
            # The request was never scheduled: release its ID so it can be
            # reused and so has_unfinished_requests()/step() do not wait for
            # output that will never arrive.
            self.try_remove_request_id(request_id)
            raise

    def _add_generation_request(
        self, request_id: str, prompt: PromptType, params: SamplingParams
    ) -> None:
        """Validate and tokenize a text-generation request, then schedule it."""
        if self.task_type not in GENERATION_TASKS:
            raise ValueError("The model does not support text generation tasks.")
        # Note: `add_special_tokens` is set to `False` for text generation requests.
        # This is because the majority of use cases rely on chat templates, which already include special tokens.
        # If special tokens need to be added manually, the caller must handle encoding themselves.
        # While this approach may seem unconventional, it is necessary for compatibility with vLLM,
        # as there is no straightforward way to pass `add_special_tokens` in this context.
        batch_encoding, prompt_getter = preprocess_prompt(
            prompt, self.tokenizer, {"padding": False, "add_special_tokens": False}
        )
        prompt_token_ids = batch_encoding.input_ids
        validate_context_length(
            max_model_len=self.max_model_len,
            prompt_tokens=len(prompt_token_ids),
            max_completion_tokens=params.max_tokens,
        )
        mm_inputs, mm_modalities, constraints = self._preprocess_multimodal_inputs(
            prompt, prompt_token_ids, params.max_tokens
        )
        # TODO: call prompt_getter after calling `self.native_engine.stream_generate` to reduce latency
        prompt_str = prompt_getter()
        converter = NativeOutputConverter(
            self.tokenizer,
            params.n,
            params.output_kind,
            params.skip_special_tokens,
            request_id,
            prompt_str,
            prompt_token_ids,
        )
        self._schedule(
            request_id,
            self._process_generation_request(
                request_id,
                converter,
                batch_encoding,
                params,
                mm_inputs=mm_inputs,
                mm_modalities=mm_modalities,
                constraints=constraints,
            ),
        )

    def _add_pooling_request(
        self, request_id: str, prompt: PromptType, params: PoolingParams
    ) -> None:
        """Validate and tokenize a pooling request, then schedule it."""
        if self.task_type not in POOLING_TASKS:
            raise ValueError("The model does not support pooling tasks.")
        # If params.task is not set, infer the task from model metadata.
        params.task = coalesce(params.task, cast(PoolingTask, self.task_type))
        batch_encoding, _ = preprocess_prompt(prompt, self.tokenizer)
        fit_prompt_to_context(
            batch_encoding,
            truncate_prompt_tokens=params.truncate_prompt_tokens,
            max_model_len=self.max_model_len,
        )
        self._schedule(
            request_id,
            self._process_encoding_request(request_id, batch_encoding, params),
        )

    def _schedule(self, request_id: str, coro) -> None:
        """
        Run *coro* on the background event loop.

        If the coroutine raises, the exception is pushed onto the output queue
        so :meth:`step` can report it; otherwise it would stay on the discarded
        future and `step()` would block waiting for output that never comes.
        """
        future = asyncio.run_coroutine_threadsafe(coro, self.aio_loop)
        # Bind locals instead of `self` so the callback holds no engine reference.
        output_queue = self.queue
        failure_cls = self._RequestFailure

        def _forward_exception(fut) -> None:
            if fut.cancelled():
                return
            error = fut.exception()
            if error is not None:
                output_queue.put(failure_cls(request_id, error))

        future.add_done_callback(_forward_exception)

    def abort_request(self, request_id: str | Iterable[str]):
        """
        Aborts request(s) with the given ID.

        Aborted requests stop being tracked; any output still queued for them
        is dropped by :meth:`step`.
        """
        if isinstance(request_id, str):
            request_id = [request_id]
        for rid in request_id:
            self.native_engine.abort_request(rid)
            self.try_remove_request_id(rid)

    async def _process_generation_request(
        self,
        request_id: str,
        converter: NativeOutputConverter,
        batch_encoding: BatchEncoding,
        sampling_params: SamplingParams,
        mm_inputs: list[Any] | None = None,
        mm_modalities: list[str] | None = None,
        constraints: Any | None = None,
    ) -> None:
        """Stream native outputs for one request into `self.queue` (runs on the loop thread)."""
        if mm_inputs:
            assert constraints is not None
            assert mm_modalities is not None
            native_output_generator = call_native_mm(
                self.native_engine,
                batch_encoding,
                sampling_params,
                mm_inputs,
                mm_modalities,
                constraints,
                request_id=request_id,
            )
        else:
            native_output_generator = self.native_engine.stream_generate(
                batch_encoding, sampling_params, request_id
            )
        async for request_output in converter.convert_stream(native_output_generator):
            self.queue.put(request_output)

    async def _process_encoding_request(
        self, request_id: str, batch_encoding: BatchEncoding, pooling_params: PoolingParams
    ) -> None:
        """Run one pooling request and queue its single, finished output (loop thread)."""
        assert (
            pooling_params.task is not None
        ), "PoolingParams.task must be set for encoding requests."
        native_outputs = await self.native_engine.encode(batch_encoding, pooling_params, request_id)
        if not native_outputs:
            raise ValueError("No outputs returned from encode; cannot process encoding request.")
        native_output = native_outputs[0]
        pooling_output = PoolingRequestOutput(
            request_id=request_id,
            prompt_token_ids=batch_encoding.input_ids,
            outputs=PoolingOutput(data=torch.Tensor(native_output.data)),
            finished=True,
        )
        self.queue.put(pooling_output)

    def has_unfinished_requests(self) -> bool:
        """
        Returns True if there are unfinished requests.
        """
        return len(self.request_ids) > 0

    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """
        Returns newly generated results of one decoding iteration from the queue.

        Blocks until at least one queue entry is available, then drains every
        entry already queued. Outputs of aborted (no longer tracked) requests
        are dropped, and requests whose output is ``finished`` stop being
        tracked.

        Raises:
            Exception: the error raised by a tracked request's background
                processing. It is raised only from a call that has no regular
                outputs to return (otherwise it is deferred to the next call),
                and the failed request stops being tracked.
        """
        # ensure at least one entry is returned, then get as many as possible
        entries = [self.queue.get()]
        while True:
            try:
                entries.append(self.queue.get_nowait())
            except queue.Empty:
                break

        outputs: list[RequestOutput | PoolingRequestOutput] = []
        failures: list[LLMEngine._RequestFailure] = []
        for entry in entries:
            if isinstance(entry, self._RequestFailure):
                failures.append(entry)
            elif entry.request_id in self.request_ids:  # ignore aborted requests
                outputs.append(entry)

        # remove finished requests
        for output in outputs:
            if output.finished:
                self.try_remove_request_id(output.request_id)

        live_failures = [f for f in failures if f.request_id in self.request_ids]
        if not live_failures:
            return outputs
        if outputs:
            # Deliver what we have now; report the failures on the next call
            # so no successful output is lost.
            for failure in live_failures:
                self.queue.put(failure)
            return outputs
        first, *rest = live_failures
        for failure in rest:
            self.queue.put(failure)
        self.try_remove_request_id(first.request_id)
        raise first.error
class AsyncLLMEngine(LLMEngineBase):
    """
    AsyncLLMEngine receives requests and generates texts asynchronously.

    Implements the API interface compatible with vLLM's `AsyncLLMEngine`, but this class is based on furiosa-runtime and FuriosaAI NPU.
    """

    def __init__(
        self,
        native_engine: NativeEngineLike,
        tokenizer: AnyTokenizer,
        task_type: GenerationTask | PoolingTask,
        max_model_len: int,
        llm: LLM | None = None,  # keep reference to LLM to prevent native engine shutdown
        model_path: str | None = None,
        trust_remote_code: bool = False,
    ):
        self.native_engine = native_engine
        self.tokenizer = tokenizer
        self.task_type = task_type
        self.max_model_len = max_model_len
        self._llm = llm
        # Resolved model snapshot dir for AutoProcessor; required when MM
        # requests arrive. Falls back to `llm.model_id_or_path` when given.
        self.model_path = model_path or (
            getattr(llm, "model_id_or_path", None) if llm is not None else None
        )
        self.trust_remote_code = trust_remote_code
        self.request_ids = set()

    def __enter__(self) -> "AsyncLLMEngine":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # AsyncLLMEngine owns no resources of its own (no aio loop, no native
        # engine — both are held by the caller and the LLM, respectively),
        # so __exit__ is a no-op. Provided only for API parity with LLMEngine
        # and LLM so `with AsyncLLMEngine.from_llm(llm) as engine:` works.
        return None

    @classmethod
    def from_llm(
        cls,
        llm: LLM,
    ) -> "AsyncLLMEngine":
        """Create an engine that shares *llm*'s native engine and tokenizer."""
        task_type = llm.model_metadata.task
        assert task_type is not None, "Failed to infer task type from model metadata."
        return cls(
            llm.engine,
            llm.tokenizer,
            task_type,
            llm.max_model_len,
            llm=llm,
            trust_remote_code=bool(llm.model_metadata.trust_remote_code),
        )

    @classmethod
    def from_engine_args(cls, args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """
        Creates an AsyncLLMEngine from AsyncEngineArgs.
        """
        scheduler_config = SchedulerConfig.load_from_args(args)
        llm = LLM(
            model_id_or_path=args.model,
            revision=args.revision,
            pipeline_parallel_size=args.pipeline_parallel_size,
            data_parallel_size=args.data_parallel_size,
            tokenizer=args.tokenizer,
            tokenizer_mode=args.tokenizer_mode,
            seed=args.seed,
            devices=args.devices,
            cache_dir=args.cache_dir,
            scheduler_config=scheduler_config,
            enable_jit_compilation=args.enable_jit_compilation,
            jit_threshold=args.jit_threshold,
            jit_max_workers=args.jit_max_workers,
            jit_unit_size=args.jit_unit_size,
        )
        return cls.from_llm(llm)

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Generates text completions for a given prompt.

        Args:
            prompt: The prompt to the LLM. See :class:`~PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.

        The request ID is released when the stream ends, fails, or is closed
        early by the consumer. If the consumer cancels or closes the stream
        before a finished output, the native request is aborted.

        Raises:
            ValueError: if the model does not support text generation, or the
                request ID is invalid or already in use.
        """
        if self.task_type not in GENERATION_TASKS:
            raise ValueError("The model does not support text generation tasks.")
        self.try_add_request_id(request_id)
        native_started = False
        finished = False
        try:
            # XXX, ditto:`add_special_tokens` is set to `False` for text generation requests.
            # See LLMEngine.add_request for more details.
            batch_encoding, prompt_getter = preprocess_prompt(
                prompt, self.tokenizer, {"padding": False, "add_special_tokens": False}
            )
            prompt_token_ids = batch_encoding.input_ids
            validate_context_length(
                max_model_len=self.max_model_len,
                prompt_tokens=len(prompt_token_ids),
                max_completion_tokens=sampling_params.max_tokens,
            )
            mm_inputs, mm_modalities, constraints = self._preprocess_multimodal_inputs(
                prompt, prompt_token_ids, sampling_params.max_tokens
            )
            if mm_inputs:
                assert constraints is not None
                native_output_generator = call_native_mm(
                    self.native_engine,
                    batch_encoding,
                    sampling_params,
                    mm_inputs,
                    mm_modalities,
                    constraints,
                    request_id=request_id,
                )
            else:
                native_output_generator = self.native_engine.stream_generate(
                    batch_encoding, sampling_params, request_id
                )
            native_started = True
            prompt_str = prompt_getter()
            converter = NativeOutputConverter(
                self.tokenizer,
                sampling_params.n,
                sampling_params.output_kind,
                sampling_params.skip_special_tokens,
                request_id,
                prompt_str,
                prompt_token_ids,
            )
            async for request_output in converter.convert_stream(native_output_generator):
                finished = request_output.finished
                yield request_output
        except (asyncio.CancelledError, GeneratorExit):
            # The consumer went away (task cancelled or stream closed early):
            # stop the native request so it does not keep generating.
            if native_started and not finished:
                self.native_engine.abort_request(request_id)
            raise
        finally:
            self.try_remove_request_id(request_id)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request=None,
        trace_headers=None,
        priority=None,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Apply pooling to the hidden states corresponding to the input
        prompts.

        `lora_request`, `trace_headers`, `truncate_prompt_tokens`, `tokenization_kwargs`,
        and `priority` are not supported. They are just placeholders for compatibility
        with the vLLM API; truncation is driven by `pooling_params.truncate_prompt_tokens`.

        Yields exactly one finished :class:`PoolingRequestOutput`. The request ID
        is released once the native call completes, fails, or is cancelled.

        Raises:
            ValueError: if the model does not support pooling, the request ID is
                invalid or already in use, or the native engine returns no outputs.
        """
        if self.task_type not in POOLING_TASKS:
            raise ValueError("The model does not support pooling tasks.")
        # If pooling_params.task is not set, infer the task from model metadata.
        pooling_params.task = coalesce(pooling_params.task, cast(PoolingTask, self.task_type))
        assert (
            pooling_params.task is not None
        ), "PoolingParams.task must be set for encoding requests."
        batch_encoding, _ = preprocess_prompt(prompt, self.tokenizer)
        fit_prompt_to_context(
            batch_encoding,
            truncate_prompt_tokens=pooling_params.truncate_prompt_tokens,
            max_model_len=self.max_model_len,
        )
        self.try_add_request_id(request_id)
        try:
            native_outputs = await self.native_engine.encode(
                batch_encoding, pooling_params, request_id
            )
        except asyncio.CancelledError:
            # Caller cancelled while the native call was in flight.
            self.native_engine.abort_request(request_id)
            raise
        finally:
            self.try_remove_request_id(request_id)
        if not native_outputs:
            raise ValueError("No outputs returned from encode; cannot process encoding request.")
        native_output = native_outputs[0]
        yield PoolingRequestOutput(
            request_id=request_id,
            prompt_token_ids=batch_encoding.input_ids,
            outputs=PoolingOutput(data=torch.Tensor(native_output.data)),
            finished=True,
        )

    async def abort(self, request_id: str) -> None:
        """
        Aborts a request with the given ID.
        """
        self.native_engine.abort_request(request_id)
        self.try_remove_request_id(request_id)

    # TODO
    # async def engine_step(self): ...