Youtu Embedding
Local Deployment
The following documentation explains how to deploy Youtu Embedding as a backend service for Youtu-RAG.
Download the Youtu Embedding Model Weights
Download the pretrained model weights from the official repository:
```bash
git lfs install
git clone https://huggingface.co/tencent/Youtu-Embedding
# To reduce memory usage, you can instead clone the BF16 branch:
# git clone -b bfloat16 --single-branch https://huggingface.co/tencent/Youtu-Embedding
```
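If you prefer not to use Git LFS, here is a minimal alternative sketch that fetches the weights with huggingface_hub (installed as part of the next step); the target directory is our choice here, matching the paths used below:

```python
from huggingface_hub import snapshot_download

# Download the full repository into ./Youtu-Embedding so the commands below
# can keep using that path. Pass revision="bfloat16" for the BF16 branch.
snapshot_download(
    repo_id="tencent/Youtu-Embedding",
    local_dir="./Youtu-Embedding",
)
```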
Install Server Dependencies
Next, install the dependencies required to run the Youtu Embedding server:
```bash
pip install transformers==4.51.3 torch numpy scipy scikit-learn huggingface_hub fastapi uvicorn
```
Run the Youtu Embedding Server
Save the following code as embedding_server.py:
```python
import fastapi
from fastapi.responses import JSONResponse
import uvicorn
from argparse import ArgumentParser
from transformers import AutoModel, AutoTokenizer
from typing import List
from pydantic import BaseModel
import torch
import base64
import numpy as np
class LLMEmbeddingModel:
def __init__(self,
model_name_or_path,
batch_size=128,
max_length=1024,
gpu_id=0):
"""Local embedding model with automatic device selection"""
self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="right", trust_remote_code=True)
# Device selection: CUDA -> MPS -> CPU
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{gpu_id}")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
self.model.to(self.device).eval()
self.max_length = max_length
self.batch_size = batch_size
        # Queries get a retrieval instruction prefix; documents are embedded as-is.
        query_instruction = "Given a search query, retrieve passages that answer the question"
        self.query_instruction = f"Instruction: {query_instruction} \nQuery:"
        self.doc_instruction = ""
print(f"Model loaded: {model_name_or_path}")
print(f"Device: {self.device}")
    def mean_pooling(self, hidden_state, attention_mask):
        # Mask-aware mean pooling: average token embeddings, ignoring padded
        # (and instruction-masked) positions.
        s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
        d = attention_mask.sum(dim=1, keepdim=True).float()
        return s / d
@torch.no_grad()
def encode(self, sentences_batch, instruction):
inputs = self.tokenizer(
sentences_batch,
padding=True,
truncation=True,
return_tensors="pt",
max_length=self.max_length,
add_special_tokens=True,
)
# Move inputs to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Gradients are already disabled by the @torch.no_grad() decorator.
        outputs = self.model(**inputs)
        last_hidden_state = outputs[0]
instruction_tokens = self.tokenizer(
instruction,
padding=False,
truncation=True,
max_length=self.max_length,
add_special_tokens=True,
)["input_ids"]
        # Zero out the instruction tokens in the attention mask so they are
        # excluded from mean pooling; only the actual query/passage contributes.
        if len(np.shape(np.array(instruction_tokens))) == 1:
            # A single instruction string shared by the whole batch.
            if len(instruction_tokens) > 0:
                inputs["attention_mask"][:, :len(instruction_tokens)] = 0
        else:
            # One instruction per sentence in the batch.
            instruction_length = [len(item) for item in instruction_tokens]
            for idx in range(len(instruction_length)):
                inputs["attention_mask"][idx, :instruction_length[idx]] = 0
embeddings = self.mean_pooling(last_hidden_state, inputs["attention_mask"])
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
return embeddings
def encode_queries(self, queries):
queries = queries if isinstance(queries, list) else [queries]
queries = [f"{self.query_instruction}{query}" for query in queries]
return self.encode(queries, self.query_instruction)
def encode_passages(self, passages):
passages = passages if isinstance(passages, list) else [passages]
passages = [f"{self.doc_instruction}{passage}" for passage in passages]
return self.encode(passages, self.doc_instruction)
# --- Server Logic ---
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint",
default="./Youtu-Embedding"
)
parser.add_argument("--max_length", default=1024, type=int)
parser.add_argument("--port", default=8081, type=int)
parser.add_argument("--host", default="0.0.0.0")
args = parser.parse_args()
return args
args = parse_args()
print(args)
# Initialize the global model at startup. If loading fails, we print a hint
# and re-raise so the server does not start without a usable model.
try:
    model_wrapper = LLMEmbeddingModel(
        model_name_or_path=args.checkpoint,
        max_length=args.max_length
    )
except Exception as e:
    print(f"Error loading model from {args.checkpoint}: {e}")
    print("Please ensure the checkpoint path is correct and contains the model files.")
    raise
app = fastapi.FastAPI()
class Query(BaseModel):
query: str
class Doc(BaseModel):
docs: List[str]
class InputText(BaseModel):
texts: List[str]
instruction: str = ""
@app.post("/embed_query")
def embed_query(query: Query):
if query.query == "":
text = " "
else:
text = query.query
# Use model's encode_queries logic
embedding_tensor = model_wrapper.encode_queries([text])
embedding = embedding_tensor.cpu().numpy()
rsp = {
"query": query.query,
"embedding": base64.b64encode(embedding.tobytes()).decode("ascii"),
"shape": embedding.shape,
}
return JSONResponse(rsp)
@app.post("/embed_docs")
def embed_doc(docs: Doc):
if len(docs.docs) > 100:
return fastapi.responses.PlainTextResponse(
"number of docs too large", status_code=501
)
texts = []
for text in docs.docs:
texts.append(" " if text == "" else text)
# Use model's encode_passages logic
embedding_tensor = model_wrapper.encode_passages(texts)
embedding = embedding_tensor.cpu().numpy()
rsp = dict(
docs=docs.docs,
embedding=base64.b64encode(embedding.tobytes()).decode("ascii"),
shape=embedding.shape,
)
return JSONResponse(rsp)
@app.post("/embed")
def embed(docs: Doc):
# This endpoint seems to imply generic embedding similar to docs
if len(docs.docs) > 100:
return fastapi.responses.PlainTextResponse(
"number of texts too large", status_code=501
)
texts = []
for text in docs.docs:
texts.append(" " if text == "" else text)
embedding_tensor = model_wrapper.encode_passages(texts)
embedding = embedding_tensor.cpu().numpy()
    # Log any inputs that produced NaN embeddings, to aid debugging.
    if np.isnan(embedding).any():
        for idx in range(len(texts)):
            if np.isnan(embedding[idx]).any():
                print(f"nan vec doc: {[texts[idx]]}")
rsp = dict(
docs=docs.docs,
embedding=base64.b64encode(embedding.tobytes()).decode("ascii"),
shape=embedding.shape,
)
return JSONResponse(rsp)
@app.post("/embed_texts")
def embed_texts(inputs: InputText):
if len(inputs.texts) > 100:
return fastapi.responses.PlainTextResponse(
"number of texts too large", status_code=501
)
texts = []
for text in inputs.texts:
texts.append(" " if text == "" else text)
# We use the provided instruction as a prefix
instruction = inputs.instruction
full_texts = [f"{instruction}{text}" for text in texts]
# Run encode directly
embedding_tensor = model_wrapper.encode(full_texts, instruction)
embedding = embedding_tensor.cpu().numpy()
rsp = dict(
texts=inputs.texts,
embedding=base64.b64encode(embedding.tobytes()).decode("ascii"),
shape=embedding.shape,
)
return JSONResponse(rsp)
@app.get("/model_id")
def model_id():
return args.checkpoint
if __name__ == "__main__":
    uvicorn.run(app, host=args.host, port=args.port)
```
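Before wiring the server into Youtu-RAG, you can sanity-check the model directly. A minimal sketch, assuming you copy the LLMEmbeddingModel class above into a separate module named embedding_model.py (a hypothetical file name; the server script parses command-line arguments at import time, so importing from it directly is inconvenient):

```python
# Hypothetical quick check; embedding_model.py is assumed to hold the
# LLMEmbeddingModel class from the server script above.
from embedding_model import LLMEmbeddingModel

model = LLMEmbeddingModel(model_name_or_path="./Youtu-Embedding")
q = model.encode_queries(["What is machine learning?"])               # shape (1, dim)
d = model.encode_passages(["Machine learning is a subfield of AI."])  # shape (1, dim)
# Embeddings are L2-normalized, so the dot product is cosine similarity.
print((q @ d.T).item())
```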
You can then run the server, pointing --checkpoint at the downloaded model weights:
```bash
python embedding_server.py --checkpoint ./Youtu-Embedding --port 8501
```
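Once the server is up, a quick way to confirm the model loaded is to hit the health and model-id endpoints (a sketch, again assuming port 8501 and that the requests package is installed):

```python
import requests

# Both endpoints are defined in embedding_server.py above.
print(requests.get("http://localhost:8501/health").json())    # {"status": "ok"}
print(requests.get("http://localhost:8501/model_id").json())  # "./Youtu-Embedding"
```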
API Endpoints
Once the server is running, the following endpoints are available:
| Endpoint | Method | Description |
|---|---|---|
| /embed_query | POST | Embed a single query |
| /embed_docs | POST | Embed multiple documents |
| /embed | POST | General-purpose embedding endpoint |
| /embed_texts | POST | Embed texts with a custom instruction |
| /model_id | GET | Return the model checkpoint path |
| /health | GET | Health check |
Usage Examples
Embed a query:
```bash
curl -X POST http://localhost:8501/embed_query \
  -H "Content-Type: application/json" \
  -d '{"query": "What is machine learning?"}'
```
Embed multiple documents:
```bash
curl -X POST http://localhost:8501/embed_docs \
  -H "Content-Type: application/json" \
  -d '{"docs": ["Document 1 text", "Document 2 text"]}'
```