vectorizer.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import logging
  2. from typing import List
  3. import numpy as np
  4. import requests
  5. from requests.adapters import HTTPAdapter
  6. from urllib3.util.retry import Retry
  7. from utils.embed_helper import EmbedHelper
  8. from utils.vector_distance import VectorDistance
  9. logger = logging.getLogger(__name__)
  10. class Vectorizer:
  11. _instance = None
  12. def __init__(self):
  13. self.api_url = "http://172.16.8.98:11434/api/embeddings"
  14. self.model_name = "bge-m3"
  15. self.session = requests.Session()
  16. retries = Retry(total=3, backoff_factor=1)
  17. self.session.mount('http://', HTTPAdapter(max_retries=retries))
  18. @staticmethod
  19. def get_embedding(text: str) -> List[float]:
  20. return Vectorizer()._call_ollama_api(text)
  21. def _initialize_model(self):
  22. logger.info("Initialized Ollama API client for bge-m3 model")
  23. @classmethod
  24. def get_instance(cls):
  25. if cls._instance is None:
  26. cls._instance = cls()
  27. return cls._instance
  28. def chunk_text(self, text: str, chunk_size: int = 500) -> List[str]:
  29. tokens = self.tokenizer.tokenize(text)
  30. return [self.tokenizer.convert_tokens_to_string(tokens[i:i+chunk_size])
  31. for i in range(0, len(tokens), chunk_size)]
  32. def encode(self, text: str, batch_size: int = 32) -> List[float]:
  33. if not text:
  34. return [0.0] * 1024 # 调整为1024维以匹配数据库
  35. return self._call_ollama_api(text)
  36. def batch_encode(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
  37. return [self._call_ollama_api(text) for text in texts]
  38. def _call_ollama_api(self, text: str) -> List[float]:
  39. try:
  40. response = self.session.post(
  41. self.api_url,
  42. json={
  43. "model": self.model_name,
  44. "prompt": text,
  45. "options": {"embedding_only": True}
  46. },
  47. timeout=30
  48. )
  49. response.raise_for_status()
  50. embedding_array = np.array(response.json()["embedding"])
  51. l2_norm = np.linalg.norm(embedding_array)
  52. if l2_norm > 0:
  53. embedding_array = embedding_array / l2_norm
  54. return embedding_array.tolist()
  55. except Exception as e:
  56. logger.error(f"API请求失败: {str(e)}")
  57. raise
  58. if __name__ == '__main__':
  59. text = '你好'
  60. print(text)
  61. embedding1 = EmbedHelper().embed_text(text)
  62. print(len(embedding1))
  63. embedding2 = Vectorizer.get_embedding(text)
  64. print(len(embedding2))
  65. print(VectorDistance().calculate_distance(embedding1, embedding2))