1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- import logging
- from typing import List
- import numpy as np
- import requests
- from requests.adapters import HTTPAdapter
- from urllib3.util.retry import Retry
- logger = logging.getLogger(__name__)
class Vectorizer:
    """Thin client around an Ollama embeddings HTTP endpoint.

    Each text is POSTed to ``api_url`` and the ``embedding`` list found in the
    JSON response is scaled to unit L2 norm before being returned. A shared
    process-wide instance is available through :meth:`get_instance`.
    """

    DEFAULT_API_URL = "http://172.16.8.98:11434/api/embeddings"
    DEFAULT_MODEL_NAME = "bge-m3"
    DEFAULT_TIMEOUT = 30  # seconds, passed to every POST issued by this class

    # Length of the all-zero vector returned for empty text; set to 1024
    # dimensions to match the database.
    EMBEDDING_DIM = 1024

    # Optional tokenizer used by chunk_text. Nothing in this class assigns it;
    # a caller may attach an object exposing ``tokenize(text)`` and
    # ``convert_tokens_to_string(tokens)``. While it is None, chunk_text
    # falls back to fixed-size character windows.
    tokenizer = None

    _instance = None

    def __init__(
        self,
        api_url: str = DEFAULT_API_URL,
        model_name: str = DEFAULT_MODEL_NAME,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Create a client with its own retrying HTTP session.

        Args:
            api_url: Full URL of the embeddings endpoint.
            model_name: Value sent as ``model`` in the request payload.
            timeout: Per-request timeout in seconds.
        """
        self.api_url = api_url
        self.model_name = model_name
        self.timeout = timeout
        self.session = requests.Session()
        retries = Retry(total=3, backoff_factor=1)
        adapter = HTTPAdapter(max_retries=retries)
        # Mount for both schemes so an overridden https:// URL gets the same
        # retry policy as the default http:// endpoint.
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

    @staticmethod
    def get_embedding(text: str) -> List[float]:
        """Embed ``text`` using the shared instance.

        The shared instance is reused so repeated calls do not build a new
        HTTP session each time. Unlike :meth:`encode`, the text is sent to
        the API as-is, including when it is empty.

        Args:
            text: Text to embed.

        Returns:
            The L2-normalised embedding as a list of floats.

        Raises:
            Exception: Whatever :meth:`_call_ollama_api` re-raises.
        """
        return Vectorizer.get_instance()._call_ollama_api(text)

    def _initialize_model(self):
        """Log that the API client is ready; performs no other work."""
        logger.info(
            "Initialized Ollama API client for %s model", self.model_name
        )

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it with defaults on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def chunk_text(self, text: str, chunk_size: int = 500) -> List[str]:
        """Split ``text`` into consecutive chunks of at most ``chunk_size`` units.

        When a tokenizer is attached the unit is a token and each chunk is
        rebuilt with ``convert_tokens_to_string``; otherwise the unit is a
        character.

        Args:
            text: Text to split.
            chunk_size: Maximum number of units per chunk; must be positive.

        Returns:
            The chunks in order; an empty list for empty text.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        tokenizer = self.tokenizer
        if tokenizer is None:
            # No tokenizer available: slice the raw string instead of failing
            # on a missing attribute.
            return [
                text[start:start + chunk_size]
                for start in range(0, len(text), chunk_size)
            ]
        tokens = tokenizer.tokenize(text)
        return [
            tokenizer.convert_tokens_to_string(tokens[start:start + chunk_size])
            for start in range(0, len(tokens), chunk_size)
        ]

    def encode(self, text: str, batch_size: int = 32) -> List[float]:
        """Embed one text, short-circuiting empty input.

        Args:
            text: Text to embed.
            batch_size: Accepted for interface compatibility; not used.

        Returns:
            A zero vector of length ``EMBEDDING_DIM`` when ``text`` is empty,
            otherwise the L2-normalised embedding from the API.

        Raises:
            Exception: Whatever :meth:`_call_ollama_api` re-raises.
        """
        if not text:
            # No API call for empty input; the fixed length keeps the result
            # compatible with the database dimension.
            return [0.0] * self.EMBEDDING_DIM
        return self._call_ollama_api(text)

    def batch_encode(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
        """Embed several texts, one API request per non-empty text.

        Every item goes through :meth:`encode`, so empty items yield the same
        zero vector that :meth:`encode` returns.

        Args:
            texts: Texts to embed.
            batch_size: Accepted for interface compatibility; not used.

        Returns:
            One embedding per input text, in input order.

        Raises:
            Exception: Whatever :meth:`_call_ollama_api` re-raises.
        """
        return [self.encode(text) for text in texts]

    def _call_ollama_api(self, text: str) -> List[float]:
        """POST ``text`` to the embeddings endpoint and return a unit vector.

        Args:
            text: Prompt sent to the model.

        Returns:
            The response's ``embedding`` divided by its L2 norm; returned
            unscaled when the norm is zero.

        Raises:
            Exception: Any error raised while sending the request, checking
                the HTTP status, or reading the response is logged and then
                re-raised unchanged.
        """
        try:
            response = self.session.post(
                self.api_url,
                json={
                    "model": self.model_name,
                    "prompt": text,
                    # NOTE(review): confirm the server honours this option.
                    "options": {"embedding_only": True}
                },
                timeout=self.timeout
            )
            response.raise_for_status()
            embedding_array = np.array(response.json()["embedding"])
            l2_norm = np.linalg.norm(embedding_array)
            # Guard against dividing by zero for an all-zero or empty vector.
            if l2_norm > 0:
                embedding_array = embedding_array / l2_norm
            return embedding_array.tolist()
        except Exception as e:
            logger.error(f"API请求失败: {str(e)}")
            raise
if __name__ == "__main__":
    # Manual smoke run: embed one sample query, then print the vector
    # followed by its length.
    sample_query = '腹痛'
    vector = Vectorizer.get_embedding(sample_query)
    for shown in (vector, len(vector)):
        print(shown)
|