SGTY 2 月之前
父節點
當前提交
789506d36e

+ 2 - 1
model/trunks_model.py

@@ -7,10 +7,11 @@ class Trunks(Base):
     __tablename__ = 'trunks'
 
     id = Column(Integer, primary_key=True, index=True)
-    embedding = Column(Vector(1024))  # 假设使用OpenAI的1536维向量
+    embedding = Column(Vector(1024))
     content = Column(Text)
     file_path = Column(String(255))
     content_tsvector = Column(TSVECTOR)
+    type = Column(String(255))
 
     def __repr__(self):
         return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 1 - 1
router/knowledge_dify.py

@@ -93,7 +93,7 @@ async def dify_retrieval(
         logger.info(f"Starting retrieval for knowledge base {payload.knowledge_id} with query: {payload.query}")
         
         trunks_service = TrunksService()
-        search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k, payload.metadata_condition if payload.metadata_condition else None)
+        search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k)
         
         if not search_results:
             logger.warning(f"No results found for query: {request.query}")

+ 36 - 1
router/knowledge_saas.py

@@ -32,7 +32,12 @@ class NodeUpdateRequest(BaseModel):
     status: Optional[int] = None
     embedding: Optional[List[float]] = None
 
-@router.post("/paginated_search", response_model=StandardResponse)
+class VectorSearchRequest(BaseModel):
+    text: str
+    limit: int = 10
+    type: Optional[str] = None
+
+@router.post("/nodes/paginated_search", response_model=StandardResponse)
 async def paginated_search(
     payload: PaginatedSearchRequest,
     db: Session = Depends(get_db)
@@ -118,4 +123,34 @@ async def delete_node(
         logger.error(f"删除节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
+@router.post('/trunks/vector_search', response_model=StandardResponse)
+async def vector_search(
+    payload: VectorSearchRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        result = service.search_by_vector(
+            payload.text,
+            payload.limit,
+            {'type': payload.type} if payload.type else None
+        )
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"向量搜索失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
+async def get_trunk(
+    trunk_id: int,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        result = service.get_trunk_by_id(trunk_id)
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"获取trunk详情失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
 saas_kb_router = router

+ 2 - 2
service/kg_node_service.py

@@ -32,7 +32,7 @@ class KGNodeService:
         offset = (page_no - 1) * limit
 
         try:
-            total_count = self.db.query(func.count(KGNode.id)).scalar()
+            total_count = self.db.query(func.count(KGNode.id)).filter(KGNode.version.in_(search_params['knowledge_ids'])).scalar() if search_params.get('knowledge_ids') else self.db.query(func.count(KGNode.id)).scalar()
 
             query = self.db.query(
                 KGNode.id,
@@ -50,7 +50,7 @@ class KGNodeService:
                     'name': r.name,
                     'category': r.category,
                     'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
-                    'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
+                    #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
                     'distance': r.distance
                 } for r in results],
                 'pagination': {

+ 20 - 5
service/trunks_service.py

@@ -20,6 +20,8 @@ class TrunksService:
         if 'embedding' in trunk_data and len(trunk_data['embedding']) != 1024:
             raise ValueError("向量维度必须为1024")
         trunk_data['embedding'] = Vectorizer.get_embedding(content)
+        if 'type' not in trunk_data:
+            trunk_data['type'] = 'default'
         print("embedding length:", len(trunk_data['embedding']))
         logger.debug(f"生成的embedding长度: {len(trunk_data['embedding'])}, 内容摘要: {content[:20]}")
         # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
@@ -39,14 +41,23 @@ class TrunksService:
         finally:
             db.close()
 
-    def get_trunk_by_id(self, trunk_id: int) -> Optional[Trunks]:
+    def get_trunk_by_id(self, trunk_id: int) -> Optional[dict]:
         db = SessionLocal()
         try:
-            return db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                return {
+                    'id': trunk.id,
+                    'file_path': trunk.file_path,
+                    'content': trunk.content,
+                    'embedding': trunk.embedding.tolist(),
+                    'type': trunk.type
+                }
+            return None
         finally:
             db.close()
 
-    def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None) -> List[dict]:
+    def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None) -> List[dict]:
         embedding = Vectorizer.get_embedding(text)
         db = SessionLocal()
         try:
@@ -58,6 +69,8 @@ class TrunksService:
             )
             if metadata_condition:
                 query = query.filter_by(**metadata_condition)
+            if type:
+                query = query.filter(Trunks.type == type)
             results = query.order_by('distance').limit(limit).all()
             return [{
                 'id': r.id,
@@ -81,6 +94,8 @@ class TrunksService:
         if 'content' in update_data:
             content = update_data['content']
             update_data['embedding'] = Vectorizer.get_embedding(content)
+            if 'type' not in update_data:
+                update_data['type'] = 'default'
             logger.debug(f"更新生成的embedding长度: {len(update_data['embedding'])}, 内容摘要: {content[:20]}")
             # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
         
@@ -137,7 +152,7 @@ class TrunksService:
         db = SessionLocal()
         try:
             # 获取总条数
-            total_count = db.query(func.count(Trunks.id)).scalar()
+            total_count = db.query(func.count(Trunks.id)).filter(Trunks.type == search_params.get('type')).scalar()
             
             # 执行向量搜索
             results = db.query(
@@ -145,7 +160,7 @@ class TrunksService:
                 Trunks.file_path,
                 Trunks.content,
                 Trunks.embedding.l2_distance(embedding).label('distance')
-            ).order_by('distance').offset(offset).limit(limit).all()
+            ).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
             
             return {
                 'data': [{

文件差異過大導致無法顯示
+ 4 - 2
tests/service/test_trunks_service.py