浏览代码

代码提交

SGTY 1 月之前
父节点
当前提交
2df823cb65
共有 3 个文件被更改,包括 351 次插入和 48 次删除
  1. +344 -0
      app.log
  2. +3 -3
      service/kg_node_service.py
  3. +4 -45
      utils/vectorizer.py

文件差异内容过多而无法显示
+ 344 - 0
app.log


+ 3 - 3
service/kg_node_service.py

@@ -25,7 +25,7 @@ class KGNodeService:
         if cache_key in self._cache:
             return self._cache[cache_key]
 
-        query_embedding = Vectorizer.get_embedding(title)
+        query_embedding = Vectorizer.get_instance().get_embedding(title)
         db = next(get_db())
         # 执行向量搜索
         results = (
@@ -69,7 +69,7 @@ class KGNodeService:
         if limit < 1:
             limit = 10
 
-        embedding = Vectorizer.get_embedding(keyword)
+        embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
 
         try:
@@ -209,7 +209,7 @@ class KGNodeService:
                 updated_nodes = []
                 for node in nodes:
                     if not node.embedding:
-                        embedding = Vectorizer.get_embedding(node.name)
+                        embedding = Vectorizer.get_instance().get_embedding(node.name)
                         node.embedding = embedding
                         updated_nodes.append(node)
                 if updated_nodes:

+ 4 - 45
utils/vectorizer.py

@@ -14,18 +14,10 @@ class Vectorizer:
     _instance = None
     
     def __init__(self):
-        self.api_url = "http://172.16.8.98:11434/api/embeddings"
-        self.model_name = "bge-m3"
-        self.session = requests.Session()
-        retries = Retry(total=3, backoff_factor=1)
-        self.session.mount('http://', HTTPAdapter(max_retries=retries))
+        self.embedHelper = EmbedHelper()
 
-    @staticmethod
-    def get_embedding(text: str) -> List[float]:
-        return Vectorizer()._call_ollama_api(text)
-
-    def _initialize_model(self):
-        logger.info("Initialized Ollama API client for bge-m3 model")
+    def get_embedding(self, text: str) -> List[float]:
+       return self.embedHelper.embed_text(text)
 
     @classmethod
     def get_instance(cls):
@@ -38,43 +30,10 @@ class Vectorizer:
         return [self.tokenizer.convert_tokens_to_string(tokens[i:i+chunk_size]) 
                for i in range(0, len(tokens), chunk_size)]
 
-    def encode(self, text: str, batch_size: int = 32) -> List[float]:
-        if not text:
-            return [0.0] * 1024  # 调整为1024维以匹配数据库
-        return self._call_ollama_api(text)
-
-    def batch_encode(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
-        return [self._call_ollama_api(text) for text in texts]
-
-    def _call_ollama_api(self, text: str) -> List[float]:
-        try:
-            response = self.session.post(
-                self.api_url,
-                json={
-                    "model": self.model_name,
-                    "prompt": text,
-                    "options": {"embedding_only": True}
-                },
-                timeout=30
-            )
-            response.raise_for_status()
-            embedding_array = np.array(response.json()["embedding"])
-            l2_norm = np.linalg.norm(embedding_array)
-            if l2_norm > 0:
-                embedding_array = embedding_array / l2_norm
-            return embedding_array.tolist()
-        except Exception as e:
-            logger.error(f"API请求失败: {str(e)}")
-            raise
 
 if __name__ == '__main__':
 
     text = '你好'
     print(text)
-    
-    embedding1 = EmbedHelper().embed_text(text)
-    print(len(embedding1))
 
-    embedding2 = Vectorizer.get_embedding(text)
-    print(len(embedding2))
-    print(VectorDistance().calculate_distance(embedding1, embedding2))
+    embedding2 = Vectorizer.get_instance().get_embedding(text)