|
@@ -14,18 +14,10 @@ class Vectorizer:
|
|
|
_instance = None
|
|
|
|
|
|
def __init__(self):
|
|
|
- self.api_url = "http://172.16.8.98:11434/api/embeddings"
|
|
|
- self.model_name = "bge-m3"
|
|
|
- self.session = requests.Session()
|
|
|
- retries = Retry(total=3, backoff_factor=1)
|
|
|
- self.session.mount('http://', HTTPAdapter(max_retries=retries))
|
|
|
+ self.embedHelper = EmbedHelper()
|
|
|
|
|
|
- @staticmethod
|
|
|
- def get_embedding(text: str) -> List[float]:
|
|
|
- return Vectorizer()._call_ollama_api(text)
|
|
|
-
|
|
|
- def _initialize_model(self):
|
|
|
- logger.info("Initialized Ollama API client for bge-m3 model")
|
|
|
+ def get_embedding(self, text: str) -> List[float]:
|
|
|
+ return self.embedHelper.embed_text(text)
|
|
|
|
|
|
@classmethod
|
|
|
def get_instance(cls):
|
|
@@ -38,43 +30,10 @@ class Vectorizer:
|
|
|
return [self.tokenizer.convert_tokens_to_string(tokens[i:i+chunk_size])
|
|
|
for i in range(0, len(tokens), chunk_size)]
|
|
|
|
|
|
- def encode(self, text: str, batch_size: int = 32) -> List[float]:
|
|
|
- if not text:
|
|
|
- return [0.0] * 1024 # 调整为1024维以匹配数据库
|
|
|
- return self._call_ollama_api(text)
|
|
|
-
|
|
|
- def batch_encode(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
|
|
|
- return [self._call_ollama_api(text) for text in texts]
|
|
|
-
|
|
|
- def _call_ollama_api(self, text: str) -> List[float]:
|
|
|
- try:
|
|
|
- response = self.session.post(
|
|
|
- self.api_url,
|
|
|
- json={
|
|
|
- "model": self.model_name,
|
|
|
- "prompt": text,
|
|
|
- "options": {"embedding_only": True}
|
|
|
- },
|
|
|
- timeout=30
|
|
|
- )
|
|
|
- response.raise_for_status()
|
|
|
- embedding_array = np.array(response.json()["embedding"])
|
|
|
- l2_norm = np.linalg.norm(embedding_array)
|
|
|
- if l2_norm > 0:
|
|
|
- embedding_array = embedding_array / l2_norm
|
|
|
- return embedding_array.tolist()
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"API请求失败: {str(e)}")
|
|
|
- raise
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
text = '你好'
|
|
|
print(text)
|
|
|
-
|
|
|
- embedding1 = EmbedHelper().embed_text(text)
|
|
|
- print(len(embedding1))
|
|
|
|
|
|
- embedding2 = Vectorizer.get_embedding(text)
|
|
|
- print(len(embedding2))
|
|
|
- print(VectorDistance().calculate_distance(embedding1, embedding2))
|
|
|
+ embedding2 = Vectorizer.get_instance().get_embedding(text)
|