What is this post about?
For text embedding, I've been thinking that Sentence Transformers is probably the way to go, but if possible I'd like to use it as a server that runs standalone.
For that kind of setup, LocalAI seems like the closest fit, but it takes quite a long time to get ready.
Embeddings / Huggingface embeddings
Given that, I figured I might as well build a fairly simple API server myself, so that's what I decided to do.
Installing Sentence Transformers takes a while, but once that's done, running text embedding doesn't need a large amount of resources.
Building it with FastAPI
The goal is "a REST API that performs text embedding using Sentence Transformers".
FastAPI seemed like a good choice for building it.
I decided to go as far as writing some simple tests.
Environment
Here is the environment used this time.
    $ python3 --version
    Python 3.10.12

    $ pip3 --version
    pip 22.0.2 from /usr/lib/python3/dist-packages/pip (python 3.10)
Building a text embedding API with FastAPI and Sentence Transformers
First, install the libraries. I'll use Uvicorn as the ASGI server.
$ pip3 install sentence-transformers fastapi uvicorn[standard]
Install the libraries for testing as well.
$ pip3 install pytest httpx
Here is the list of installed libraries.
    $ pip3 list
    Package                  Version
    ------------------------ ----------
    annotated-types          0.6.0
    anyio                    4.3.0
    certifi                  2024.2.2
    charset-normalizer       3.3.2
    click                    8.1.7
    exceptiongroup           1.2.1
    fastapi                  0.110.2
    filelock                 3.13.4
    fsspec                   2024.3.1
    h11                      0.14.0
    httpcore                 1.0.5
    httptools                0.6.1
    httpx                    0.27.0
    huggingface-hub          0.22.2
    idna                     3.7
    iniconfig                2.0.0
    Jinja2                   3.1.3
    joblib                   1.4.0
    MarkupSafe               2.1.5
    mpmath                   1.3.0
    networkx                 3.3
    numpy                    1.26.4
    nvidia-cublas-cu12       12.1.3.1
    nvidia-cuda-cupti-cu12   12.1.105
    nvidia-cuda-nvrtc-cu12   12.1.105
    nvidia-cuda-runtime-cu12 12.1.105
    nvidia-cudnn-cu12        8.9.2.26
    nvidia-cufft-cu12        11.0.2.54
    nvidia-curand-cu12       10.3.2.106
    nvidia-cusolver-cu12     11.4.5.107
    nvidia-cusparse-cu12     12.1.0.106
    nvidia-nccl-cu12         2.20.5
    nvidia-nvjitlink-cu12    12.4.127
    nvidia-nvtx-cu12         12.1.105
    packaging                24.0
    pillow                   10.3.0
    pip                      22.0.2
    pluggy                   1.5.0
    pydantic                 2.7.1
    pydantic_core            2.18.2
    pytest                   8.2.0
    python-dotenv            1.0.1
    PyYAML                   6.0.1
    regex                    2024.4.28
    requests                 2.31.0
    safetensors              0.4.3
    scikit-learn             1.4.2
    scipy                    1.13.0
    sentence-transformers    2.7.0
    setuptools               59.6.0
    sniffio                  1.3.1
    starlette                0.37.2
    sympy                    1.12
    threadpoolctl            3.4.0
    tokenizers               0.19.1
    tomli                    2.0.1
    torch                    2.3.0
    tqdm                     4.66.2
    transformers             4.40.1
    triton                   2.3.0
    typing_extensions        4.11.0
    urllib3                  2.2.1
    uvicorn                  0.29.0
    uvloop                   0.19.0
    watchfiles               0.21.0
    websockets               12.0
Here is the source code I wrote.
api.py
    from fastapi import FastAPI
    from pydantic import BaseModel
    import os
    from sentence_transformers import SentenceTransformer

    app = FastAPI()


    class EmbeddingRequest(BaseModel):
        model: str
        text: str


    class EmbeddingResponse(BaseModel):
        model: str
        embedding: list[float]
        dimention: int


    @app.post("/embeddings/encode")
    def encode(request: EmbeddingRequest) -> EmbeddingResponse:
        # the device defaults to "cpu" and can be overridden via EMBEDDING_API_DEVICE
        sentence_transformer_model = SentenceTransformer(
            request.model,
            device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
        )

        # encode() takes a list of texts; only one text is passed here
        embeddings = sentence_transformer_model.encode([request.text])
        embedding = embeddings[0]

        # numpy array to float list
        embedding_as_float = embedding.tolist()

        return EmbeddingResponse(
            model=request.model,
            embedding=embedding_as_float,
            dimention=sentence_transformer_model.get_sentence_embedding_dimension()
        )
The request carries the model to use for text embedding and the text to embed.
    class EmbeddingRequest(BaseModel):
        model: str
        text: str
The response returns the model specified in the request, the text embedding result, and the number of dimensions of the vector.
    class EmbeddingResponse(BaseModel):
        model: str
        embedding: list[float]
        dimention: int
The API implementation looks like this.
    @app.post("/embeddings/encode")
    def encode(request: EmbeddingRequest) -> EmbeddingResponse:
        sentence_transformer_model = SentenceTransformer(
            request.model,
            device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
        )

        embeddings = sentence_transformer_model.encode([request.text])
        embedding = embeddings[0]

        # numpy array to float list
        embedding_as_float = embedding.tolist()

        return EmbeddingResponse(
            model=request.model,
            embedding=embedding_as_float,
            dimention=sentence_transformer_model.get_sentence_embedding_dimension()
        )
The model is downloaded automatically from the Hugging Face Hub at runtime.
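As written, the endpoint constructs a `SentenceTransformer` on every request, so the model is loaded again each time. One possible refinement, which is not part of the code in this post, is to keep loaded models in a small in-memory cache keyed by model name. The following is only a minimal sketch under assumptions: `ModelCache` and `fake_loader` are hypothetical names, and the stand-in loader just returns a string so the example runs without downloading anything. In the real API the loader would presumably be something like `lambda name: SentenceTransformer(name, device=...)`.

```python
from typing import Any, Callable


class ModelCache:
    """Keep already-loaded models in memory, keyed by model name (hypothetical helper)."""

    def __init__(self, loader: Callable[[str], Any]) -> None:
        self._loader = loader
        self._models: dict[str, Any] = {}

    def get(self, name: str) -> Any:
        # Call the loader only the first time a model name is requested
        if name not in self._models:
            self._models[name] = self._loader(name)
        return self._models[name]


# Stand-in loader (assumption): records each call instead of loading a real model
load_calls: list[str] = []


def fake_loader(name: str) -> str:
    load_calls.append(name)
    return f"model:{name}"


cache = ModelCache(fake_loader)
first = cache.get("all-MiniLM-L6-v2")
second = cache.get("all-MiniLM-L6-v2")  # served from the cache, no second load
```

Note that this sketch does no locking. FastAPI runs plain `def` endpoints in a thread pool, so two simultaneous first requests for the same model could still load it twice.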
About the only thing that gave me trouble was having to convert the numpy array into a list.
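To illustrate why that conversion is needed: the standard `json` encoder only accepts built-in types such as `list` and `float`, so an array object has to be turned into a plain list first. The sketch below uses the standard library's `array.array` as a stand-in for a one-dimensional numpy array, purely so that it runs without numpy installed. Both types expose a `tolist()` method. `to_float_list` is a hypothetical helper name, not something from the API code.

```python
import json
from array import array


def to_float_list(vector) -> list[float]:
    """Convert a 1-D array-like vector into a plain list of Python floats."""
    # numpy.ndarray and array.array both provide tolist()
    if hasattr(vector, "tolist"):
        return [float(x) for x in vector.tolist()]
    return [float(x) for x in vector]


# Stand-in (assumption) for one row of the encode() result, stored as 32-bit floats
vector = array("f", [0.5, -0.25, 0.125])

# Passing the array object straight to the JSON encoder raises TypeError
try:
    json.dumps(vector)
    serializable_as_is = True
except TypeError:
    serializable_as_is = False

# After conversion to a plain list, serialization works
payload = json.dumps({"embedding": to_float_list(vector)})
```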
Start it up.
    $ uvicorn api:app

    # or

    $ uvicorn api:app --reload
Check that it works.
    $ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World"}' | jq
    {
      "model": "all-MiniLM-L6-v2",
      "embedding": [
        -0.03447727486491203,
        0.03102317824959755,
        0.006734995171427727,
        0.026108944788575172,
        -0.039361994713544846,

        ~ omitted ~

        0.03323201462626457,
        0.02379228174686432,
        -0.022889817133545876,
        0.03893755003809929,
        0.0302068330347538
      ],
      "dimention": 384
    }
Let's check once more with a different model.
    $ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World"}' | jq
    {
      "model": "intfloat/multilingual-e5-base",
      "embedding": [
        0.03324141725897789,
        0.04988044500350952,
        0.00241446984000504,
        0.011555945500731468,
        0.03409387916326523,

        ~ omitted ~

        -0.018477996811270714,
        0.04818818345665932,
        -0.04364151135087013,
        -0.04888230562210083,
        0.03604992479085922
      ],
      "dimention": 768
    }
Looks OK.
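The same call can also be made from a program instead of curl. The following is a rough sketch of a client that uses only the Python standard library. The function names (`build_encode_request`, `parse_encode_response`, `encode`) are hypothetical, and the URL assumes the server is running locally on Uvicorn's default port 8000, as in the curl commands.

```python
import json
import urllib.request

# Assumption: the API is running locally on Uvicorn's default port
API_URL = "http://localhost:8000/embeddings/encode"


def build_encode_request(model: str, text: str, url: str = API_URL) -> urllib.request.Request:
    """Build the POST request that the curl commands send."""
    body = json.dumps({"model": model, "text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_encode_response(raw: bytes) -> list[float]:
    """Extract the embedding and check it against the reported dimension."""
    payload = json.loads(raw)
    embedding = payload["embedding"]
    # "dimention" is the field name used by the API in this post
    if len(embedding) != payload["dimention"]:
        raise ValueError("embedding length does not match the reported dimension")
    return embedding


def encode(model: str, text: str) -> list[float]:
    """Call the API; this needs the Uvicorn server to be up."""
    with urllib.request.urlopen(build_encode_request(model, text)) as response:
        return parse_encode_response(response.read())
```

With the server running, `encode("all-MiniLM-L6-v2", "Hello World")` should return the same vector that the first curl command printed.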
Finally, I'll write some tests too.
test_api.py
    from fastapi.testclient import TestClient
    from api import app, EmbeddingRequest, EmbeddingResponse

    client = TestClient(app)


    def test_encode_basic():
        request = EmbeddingRequest(model="all-MiniLM-L6-v2", text="Hello World")

        # send the request model as a JSON body
        raw_response = client.post("/embeddings/encode", json=request.model_dump())

        assert raw_response.status_code == 200

        # validate the JSON body against the response model
        response = EmbeddingResponse.model_validate(raw_response.json())

        assert response.model == "all-MiniLM-L6-v2"
        assert len(response.embedding) == 384
        assert response.dimention == 384


    def test_encode_e5():
        # E5 models expect a "query: " or "passage: " prefix on the input text
        request = EmbeddingRequest(model="intfloat/multilingual-e5-base", text="passage: Hello World")

        raw_response = client.post("/embeddings/encode", json=request.model_dump())

        assert raw_response.status_code == 200

        response = EmbeddingResponse.model_validate(raw_response.json())

        assert response.model == "intfloat/multilingual-e5-base"
        assert len(response.embedding) == 768
        assert response.dimention == 768
For reference, I used this page and this one.
I hadn't looked at Pydantic much, so that part took me a little while.
Run the tests.
    $ pytest
    ============================= test session starts ==============================
    platform linux -- Python 3.10.12, pytest-8.2.0, pluggy-1.5.0
    rootdir: /path/to
    plugins: anyio-4.3.0
    collected 2 items

    test_api.py ..                                                           [100%]

    ============================== 2 passed in 9.16s ===============================
Looks OK.
Wrapping up
I built a simple text embedding API using FastAPI and Sentence Transformers.
In particular, whenever I wanted to do text embedding from something other than Python, I was a bit stuck on how to go about it, so using something I built myself like this seems like a reasonable option.
It was also a nice bit of FastAPI practice.