vectorizer.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import logging
  2. from typing import List
  3. import numpy as np
  4. import requests
  5. from requests.adapters import HTTPAdapter
  6. from urllib3.util.retry import Retry
  7. logger = logging.getLogger(__name__)
  8. class Vectorizer:
  9. _instance = None
  10. def __init__(self):
  11. self.api_url = "http://172.16.8.98:11434/api/embeddings"
  12. self.model_name = "bge-m3"
  13. self.session = requests.Session()
  14. retries = Retry(total=3, backoff_factor=1)
  15. self.session.mount('http://', HTTPAdapter(max_retries=retries))
  16. @staticmethod
  17. def get_embedding(text: str) -> List[float]:
  18. return Vectorizer()._call_ollama_api(text)
  19. def _initialize_model(self):
  20. logger.info("Initialized Ollama API client for bge-m3 model")
  21. @classmethod
  22. def get_instance(cls):
  23. if cls._instance is None:
  24. cls._instance = cls()
  25. return cls._instance
  26. def chunk_text(self, text: str, chunk_size: int = 500) -> List[str]:
  27. tokens = self.tokenizer.tokenize(text)
  28. return [self.tokenizer.convert_tokens_to_string(tokens[i:i+chunk_size])
  29. for i in range(0, len(tokens), chunk_size)]
  30. def encode(self, text: str, batch_size: int = 32) -> List[float]:
  31. if not text:
  32. return [0.0] * 1024 # 调整为1024维以匹配数据库
  33. return self._call_ollama_api(text)
  34. def batch_encode(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
  35. return [self._call_ollama_api(text) for text in texts]
  36. def _call_ollama_api(self, text: str) -> List[float]:
  37. try:
  38. response = self.session.post(
  39. self.api_url,
  40. json={
  41. "model": self.model_name,
  42. "prompt": text,
  43. "options": {"embedding_only": True}
  44. },
  45. timeout=30
  46. )
  47. response.raise_for_status()
  48. embedding_array = np.array(response.json()["embedding"])
  49. l2_norm = np.linalg.norm(embedding_array)
  50. if l2_norm > 0:
  51. embedding_array = embedding_array / l2_norm
  52. return embedding_array.tolist()
  53. except Exception as e:
  54. logger.error(f"API请求失败: {str(e)}")
  55. raise
  56. if __name__ == "__main__":
  57. text ='腹痛'
  58. embedding = Vectorizer.get_embedding(text)
  59. print(embedding)
  60. print(len(embedding))