Youtu-RAG
Youtu Embedding

本地部署

以下文档提供将 Youtu Embedding 部署为 Youtu-RAG 后端服务的说明。

下载 Youtu Embedding 模型权重

从官方 Hugging Face 仓库下载预训练模型权重：

git lfs install
git clone https://huggingface.co/tencent/Youtu-Embedding
# 如果您希望使用 BF16 权重以减少内存使用,也可以克隆 BF16 分支:
# git clone -b bfloat16 --single-branch https://huggingface.co/tencent/Youtu-Embedding

安装服务器依赖项

接下来,安装运行 Youtu Embedding 服务器所需的依赖项:

pip install transformers==4.51.3 torch numpy scipy scikit-learn huggingface_hub fastapi uvicorn

运行 Youtu Embedding 服务器

将以下代码保存为 embedding_server.py：

import fastapi
from fastapi.responses import JSONResponse
import uvicorn
from argparse import ArgumentParser
from transformers import AutoModel, AutoTokenizer
from typing import List
from pydantic import BaseModel
import torch
import base64
import numpy as np

class LLMEmbeddingModel:
    """Local Youtu Embedding encoder for search queries and passages.

    Wraps a Hugging Face ``AutoModel``/``AutoTokenizer`` pair loaded with
    ``trust_remote_code=True``. A text is embedded by mean-pooling the
    model's first output over its non-padding tokens, excluding the leading
    instruction-prefix tokens, and then L2-normalising the pooled vector.
    """

    def __init__(self,
                 model_name_or_path,
                 batch_size=128,
                 max_length=1024,
                 gpu_id=0,
                 query_instruction="Given a search query, retrieve passages that answer the question",
                 doc_instruction=""):
        """Load model and tokenizer and move the model to the best device.

        Args:
            model_name_or_path: Local directory or Hugging Face model id.
            batch_size: Maximum number of texts sent through the model in a
                single forward pass by ``encode``. A falsy or non-positive
                value disables chunking.
            max_length: Tokenizer truncation length (in tokens).
            gpu_id: CUDA device index; only used when CUDA is available.
            query_instruction: Task description placed in the query prefix
                ``"Instruction: <task> \\nQuery:"``. An empty string yields
                the bare prefix ``"Query:"``.
            doc_instruction: Prefix prepended to every passage.
        """
        self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        # Right padding keeps the instruction prefix at the start of every
        # row, which the prefix masking in ``_encode_batch`` relies on.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="right", trust_remote_code=True)

        # Device selection: CUDA -> MPS -> CPU
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{gpu_id}")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

        self.model.to(self.device).eval()

        self.max_length = max_length
        self.batch_size = batch_size

        if query_instruction:
            self.query_instruction = f"Instruction: {query_instruction} \nQuery:"
        else:
            self.query_instruction = "Query:"

        self.doc_instruction = doc_instruction
        print(f"Model loaded: {model_name_or_path}")
        print(f"Device: {self.device}")

    def mean_pooling(self, hidden_state, attention_mask):
        """Average ``hidden_state`` over positions where ``attention_mask`` is 1.

        Rows whose mask is entirely zero yield a zero vector rather than
        NaN (the token count is clamped to at least 1).
        """
        s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
        # Counts are integers, so clamping to 1 only affects all-masked rows
        # (e.g. when the instruction prefix fills the whole truncated input).
        d = attention_mask.sum(dim=1, keepdim=True).float().clamp(min=1.0)
        return s / d

    @staticmethod
    def _instruction_lengths(instruction_tokens, num_rows):
        """Return the number of leading tokens to exclude from pooling per row.

        ``instruction_tokens`` is the tokenizer's ``input_ids`` for the
        instruction: a flat list of ids when the instruction is a single
        string (shared by all ``num_rows`` rows), or a list of id lists when
        it is a list of per-text strings (lengths may differ).
        """
        if len(instruction_tokens) > 0 and isinstance(instruction_tokens[0], (list, tuple)):
            return [len(ids) for ids in instruction_tokens]
        return [len(instruction_tokens)] * num_rows

    @torch.no_grad()
    def _encode_batch(self, sentences_batch, instruction):
        """Embed one chunk of texts in a single forward pass (see ``encode``)."""
        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
            add_special_tokens=True,
        )
        # Move inputs to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        last_hidden_state = outputs[0]

        instruction_tokens = self.tokenizer(
            instruction,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=True,
        )["input_ids"]

        # Exclude the instruction prefix from pooling (after the forward
        # pass, so the model still attends to it).
        attention_mask = inputs["attention_mask"]
        prefix_lengths = self._instruction_lengths(instruction_tokens, attention_mask.shape[0])
        for row, prefix_len in enumerate(prefix_lengths):
            if prefix_len > 0:
                attention_mask[row, :prefix_len] = 0

        embeddings = self.mean_pooling(last_hidden_state, attention_mask)
        return torch.nn.functional.normalize(embeddings, dim=-1)

    @torch.no_grad()
    def encode(self, sentences_batch, instruction):
        """Embed already-prefixed texts, in chunks of at most ``batch_size``.

        Args:
            sentences_batch: List of texts, each already starting with its
                instruction prefix.
            instruction: The prefix string shared by all texts, or a list
                with one prefix per text; its tokens are excluded from
                pooling.

        Returns:
            L2-normalised embeddings, one row per input text, on
            ``self.device``.
        """
        batch_size = self.batch_size
        if (isinstance(sentences_batch, str) or not batch_size or batch_size <= 0
                or len(sentences_batch) <= batch_size):
            return self._encode_batch(sentences_batch, instruction)

        per_item_instruction = isinstance(instruction, (list, tuple))
        chunks = []
        for start in range(0, len(sentences_batch), batch_size):
            stop = start + batch_size
            chunk_instruction = instruction[start:stop] if per_item_instruction else instruction
            chunks.append(self._encode_batch(sentences_batch[start:stop], chunk_instruction))
        return torch.cat(chunks, dim=0)

    def encode_queries(self, queries):
        """Embed one query or a list of queries using the query instruction prefix."""
        queries = queries if isinstance(queries, list) else [queries]
        queries = [f"{self.query_instruction}{query}" for query in queries]
        return self.encode(queries, self.query_instruction)

    def encode_passages(self, passages):
        """Embed one passage or a list of passages using the document prefix."""
        passages = passages if isinstance(passages, list) else [passages]
        passages = [f"{self.doc_instruction}{passage}" for passage in passages]
        return self.encode(passages, self.doc_instruction)

# --- Server Logic ---

def parse_args():
    """Read the server's command-line options from ``sys.argv``.

    Options: ``--checkpoint`` (model directory, default
    ``./Youtu-Embedding``), ``--max_length`` (int, default 1024),
    ``--port`` (int, default 8081) and ``--host`` (default ``0.0.0.0``).
    """
    cli = ArgumentParser()
    cli.add_argument("--checkpoint", default="./Youtu-Embedding")
    cli.add_argument("--max_length", type=int, default=1024)
    cli.add_argument("--port", type=int, default=8081)
    cli.add_argument("--host", default="0.0.0.0")
    return cli.parse_args()

args = parse_args()
print(args)

# Load the model once at import time so every request shares one instance.
# A missing or invalid checkpoint is fatal: report it, then re-raise (bare
# ``raise`` preserves the original traceback) so the process exits instead
# of serving endpoints that would fail on every call.
try:
    model_wrapper = LLMEmbeddingModel(
        model_name_or_path=args.checkpoint,
        max_length=args.max_length
    )
except Exception as e:
    print(f"Error loading model from {args.checkpoint}: {e}")
    print("Please ensure the checkpoint path is correct and contains the model files.")
    raise

# FastAPI application; all endpoints below register on it.
app = fastapi.FastAPI()

# Request body for POST /embed_query: a single query string.
class Query(BaseModel):
    query: str

# Request body for POST /embed_docs and POST /embed: a list of texts.
class Doc(BaseModel):
    docs: List[str]

# Request body for POST /embed_texts: texts plus an instruction prefix
# (defaults to the empty string) that is prepended to every text.
class InputText(BaseModel):
    texts: List[str]
    instruction: str = ""

@app.post("/embed_query")
def embed_query(query: Query):
    """Embed one search query using the model's query instruction prefix.

    An empty query is encoded as a single space. The response echoes the
    original query and returns the embedding as base64-encoded raw array
    bytes together with the array shape.
    """
    text = query.query if query.query != "" else " "

    vectors = model_wrapper.encode_queries([text]).cpu().numpy()
    encoded = base64.b64encode(vectors.tobytes()).decode("ascii")

    payload = {
        "query": query.query,
        "embedding": encoded,
        "shape": vectors.shape,
    }
    return JSONResponse(payload)

@app.post("/embed_docs")
def embed_doc(docs: Doc):
    """Embed up to 100 documents using the model's passage prefix.

    Requests with more than 100 documents are rejected with a plain-text
    501 response. Empty documents are encoded as a single space, but the
    response echoes the documents exactly as received.
    """
    too_many = len(docs.docs) > 100
    if too_many:
        return fastapi.responses.PlainTextResponse(
            "number of docs too large", status_code=501
        )

    cleaned = [text if text != "" else " " for text in docs.docs]

    vectors = model_wrapper.encode_passages(cleaned).cpu().numpy()
    encoded = base64.b64encode(vectors.tobytes()).decode("ascii")

    return JSONResponse({
        "docs": docs.docs,
        "embedding": encoded,
        "shape": vectors.shape,
    })

@app.post("/embed")
def embed(docs: Doc):
    """Generic embedding endpoint: same encoding as /embed_docs.

    In addition, any output row containing NaN is logged together with
    the text that produced it.
    """
    if len(docs.docs) > 100:
        return fastapi.responses.PlainTextResponse(
            "number of texts too large", status_code=501
        )

    cleaned = [text if text != "" else " " for text in docs.docs]

    vectors = model_wrapper.encode_passages(cleaned).cpu().numpy()

    # Report which inputs produced NaN vectors (diagnostics only).
    for position, row in enumerate(vectors):
        if np.isnan(row).any():
            print(f"nan vec doc: {[cleaned[position]]}")

    encoded = base64.b64encode(vectors.tobytes()).decode("ascii")
    return JSONResponse({
        "docs": docs.docs,
        "embedding": encoded,
        "shape": vectors.shape,
    })

@app.post("/embed_texts")
def embed_texts(inputs: InputText):
    """Embed up to 100 texts with a caller-supplied instruction prefix.

    The instruction is prepended to every text (empty texts become a single
    space first), and its tokens are excluded from pooling by
    ``model_wrapper.encode``.
    """
    if len(inputs.texts) > 100:
        return fastapi.responses.PlainTextResponse(
            "number of texts too large", status_code=501
        )

    prefix = inputs.instruction
    prefixed = [prefix + (text if text != "" else " ") for text in inputs.texts]

    vectors = model_wrapper.encode(prefixed, prefix).cpu().numpy()
    encoded = base64.b64encode(vectors.tobytes()).decode("ascii")

    return JSONResponse({
        "texts": inputs.texts,
        "embedding": encoded,
        "shape": vectors.shape,
    })

@app.get("/model_id")
def model_id():
    """Return the ``--checkpoint`` path the server was started with."""
    return args.checkpoint


@app.get("/health")
def health():
    """Health-check endpoint (GET /health) listed in the API table.

    The model is loaded at import time and a loading failure aborts startup,
    so any response from this route means the model is available.
    """
    return {"status": "ok"}

if __name__ == "__main__":
    # Serve the FastAPI app with uvicorn on the host/port given on the CLI.
    server_host, server_port = args.host, args.port
    uvicorn.run(app, host=server_host, port=server_port)

然后，使用以下命令运行服务器，并通过 --checkpoint 指定已下载模型检查点的路径（--port 默认为 8081，下例及后续示例使用 8501）：

python embedding_server.py --checkpoint ./Youtu-Embedding --port 8501

API 端点

运行后,以下端点可用:

| 端点 | 方法 | 描述 |
|------|------|------|
| /embed_query | POST | 嵌入单个查询 |
| /embed_docs | POST | 嵌入多个文档 |
| /embed | POST | 通用嵌入端点 |
| /embed_texts | POST | 使用自定义指令嵌入文本 |
| /model_id | GET | 获取模型检查点路径 |
| /health | GET | 健康检查端点 |

使用示例

嵌入查询:

curl -X POST http://localhost:8501/embed_query \
    -H "Content-Type: application/json" \
    -d '{"query": "What is machine learning?"}'

嵌入多个文档:

curl -X POST http://localhost:8501/embed_docs \
    -H "Content-Type: application/json" \
    -d '{"docs": ["Document 1 text", "Document 2 text"]}'

On this page