书接上文,今天起来后感谢合作者的点醒,大概发现了如何用 vllm 去 serve embedding model,并且成功利用 serve 了 gte-7b。
vllm 如何处理 embedding/completion 请求?
这里观察两个位于 /vllm/engine/async_llm_engine.py 下的函数(为了方便将部分注释删去):
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
async for output in await self.add_request(
request_id,
inputs,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
):
yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
async for output in await self.add_request(
request_id,
inputs,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
实际上就是每次用 OpenAI 的结构调用 embedding or completion 接口时,会分别调用上方的 encode 函数和 generate 函数,得到 embedding 或者 completion。看上去在 vllm 中,任何一个模型都可以接受 embedding 与 completion 请求。
如何魔改 Qwen2ForCausalLM 来支持 embedding 请求?
直接用 vllm serve gte-7b:
CUDA_VISIBLE_DEVICES=0 vllm serve 7embed --dtype auto --api-key \
sk-1dwqsdv4r3wef3rvefg34ef1dwRv --tensor-parallel-size 1 \
--max-model-len 32768 --enforce-eager \
--disable-custom-all-reduce --port 7777 --served-model-name e5_7b
然后发送 embedding 请求,会出错误(pooler not implemented)。
我们进一步观察 vllm 里面 support 的 qwen2 模型(vllm/model_executor/models/qwen2.py):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = %s is less than "
"`num_hidden_layers` = %s. Please open an issue "
"to discuss this feature." % (
config.max_window_layers,
config.num_hidden_layers,
))
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
然后观察 SGLang 里面的 qwen2 model 文件(python/sglang/srt/models/qwen2.py):
class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2BaseModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
这里发现了非常有趣的事情,和 SGLang 不同的是,vllm 里每个 model 的 forward 函数仅仅返回了 hidden_states,而 hidden_states 的 logits_processor 是在 compute_logits 函数里实现的。SGLang 的 forward 函数却将 vllm 里面的 forward 和 logits_processor 合在了一起,直接一步返回了 logits。基于如上的设计,vllm 的 generate 请求实际上调用的是 compute_logits 函数,SGLang 的 generate 请求调用的是 forward 函数。
叙述到此处并不能体现出二者的区别,但是考虑到 embedding 请求时,这个事情就颇有意思了。gte 这个模型的 architecture 是 Qwen2ForCausalLM,vllm 将 gte 模型根据 architecture 映射到 Qwen2ForCausalLM 这个类后,处理 embedding 请求时会试图调用 pooler 函数。因此,我们只需要在 vllm 已经实现的 Qwen2ForCausalLM 类下加上 pooler 即可:
class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = %s is less than "
"`num_hidden_layers` = %s. Please open an issue "
"to discuss this feature." % (
config.max_window_layers,
config.num_hidden_layers,
))
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.sampler = Sampler()
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
而 pooler 相关的代码可以直接从 vllm/model_executor/models/llama_embedding.py 上 copy。
如此以来,加入 pooler 后,任意一个 architecture 为 Qwen2ForCausalLM 都可以支持 embedding 请求了。道理是这样的:
对于任意一个 architecture 为 Qwen2ForCausalLM 的模型(譬如 Qwen/Qwen2-72B-Instruct 和 Alibaba-NLP/gte-Qwen2-7B-instruct,这两个模型一个是 completion 模型,另一个是 embedding 模型,但是 architecture 都是 Qwen2ForCausalLM),这个模型会被映射到 Qwen2ForCausalLM 这个类上。当用户调用 completion 请求时,engine 会调用 compute_logits 函数,而用户调用 embedding 请求时,engine 会调用 pooler 函数。从而即便是 embedding model 和 completion model 被映射到同一个类上,分别将 embedding 请求和 completion 请求对应到不同的函数上就可以避免冲突。
反过来,考虑 SGLang 的实现。在 SGLang 中,对于一个 server 的 embedding 请求和 completion 请求都会调用 class 的 forward 函数。如同我上篇博文所讲,forward 函数没法在不加参数的情况下区分用户究竟想要得到 embedding 还是 completion。所以 SGLang 暂时没法通过类似的方法进行更改。
完成上述的更改后,回到 vllm 的 vllm/model_executor/models/__init__.py 文件中,将 gte 映射到 Qwen2ForCausalLM 即可。
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
"Qwen2ForCausalLM": ("Qwen2ForCausalLM"),
}
回顾与问题
回顾下利用 vllm support gte 的过程,我们发现 vllm 将 embedding 请求与 completion 请求分设接口的设计极大的帮助了我们扩展接口。而 SGLang 由于两种请求没有分设接口,因此出现了同一个 architecture 没法映射到两个 class(实现两个 forward 函数)的冲突。
下一步打算利用这种方式修改 SGLang 的接口,然后让 SGLang support gte 模型。
当然,vlllm 的 support 似乎还没有完善。我成功 support 起 gte 模型后,尝试将得到的 embedding 与 sentence_former 的 embedding 进行对拍,发现 vllm 返回的 embedding 存在两个问题:
hidden state 的维度是 sentence_former 的两倍;
hidden state 的偶数维全是 0;
hidden state 的数值远大于 sentence_former 的数值,我怀疑是归一化问题,但是目前没有解决。
总之,麻辣系统真是博大精深呀!