Browse Source

代码提交

SGTY 2 months ago
parent
commit
d262f94510
6 changed files with 59 additions and 13 deletions
  1. 7 1
      community/community_report.py
  2. 1 1
      main.py
  3. 1 0
      model/trunks_model.py
  4. 20 1
      router/knowledge_dify.py
  5. 26 7
      service/trunks_service.py
  6. 4 3
      utils/file_reader.py

+ 7 - 1
community/community_report.py

@@ -121,6 +121,7 @@ def generate_report(G, partition):
             'name': node[0],
             **node[1]
         })
+       
 
 
     print("generate_report community structure finished")
@@ -138,7 +139,12 @@ def generate_report(G, partition):
 
 
         com_report.append("\n**成员节点**:")
+        member_names = ''
+        member_count = 0
         for member in members:
+            if member_count < 8:
+                member_names += member['name'] + '_'
+                member_count += 1
             com_report.append(f"- {member['name']} ({member['type']})")
             if REPORT_INCLUDE_DETAILS == False:
                 continue
@@ -166,7 +172,7 @@ def generate_report(G, partition):
         if density < DENSITY:
             com_report.append("**社区内部连接相对稀疏**\n")
         else:
-            with open(f"{REPORT_PATH}\community_{comm_id}.md", "w", encoding="utf-8") as f:
+            with open(f"{REPORT_PATH}\社区_{member_names}{comm_id}.md", "w", encoding="utf-8") as f:
                 f.write("\n".join(com_report))
         print(f"社区 {comm_id+1} 报告文件大小:{len(''.join(com_report).encode('utf-8'))} 字节")  # 添加文件生成验证
     

+ 1 - 1
main.py

@@ -30,5 +30,5 @@ app.include_router(text_search_router)
 if __name__ == "__main__":
     logger.info('Starting uvicorn server...2222')
     #uvicorn main:app --host 0.0.0.0 --port 8000 --reload
-    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
+    uvicorn.run("main:app", host="0.0.0.0", port=8001, reload=True)
 

+ 1 - 0
model/trunks_model.py

@@ -12,6 +12,7 @@ class Trunks(Base):
     file_path = Column(String(255))
     content_tsvector = Column(TSVECTOR)
     type = Column(String(255))
+    title = Column(String(255))
 
     def __repr__(self):
         return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 20 - 1
router/knowledge_dify.py

@@ -84,7 +84,8 @@ async def dify_retrieval(
     payload: DifyRetrievalRequest,
     request: Request,
     authorization: str = Depends(verify_api_key),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    conversation_id: Optional[str] = None
 ):
     logger.info(f"All headers: {dict(request.headers)}")
     logger.info(f"Request body: {payload.model_dump()}")
@@ -133,5 +134,23 @@ async def dify_retrieval(
             )
         )
 
+@router.post("/chatflow_retrieval", response_model=StandardResponse)
+async def dify_chatflow_retrieval(
+    knowledge_id: str,
+    query: str,
+    top_k: int,
+    score_threshold: float,
+    conversation_id: str,
+    request: Request,
+    authorization: str = Depends(verify_api_key),
+    db: Session = Depends(get_db)
+):
+    """Build a DifyRetrievalRequest from flat parameters and delegate to dify_retrieval."""
+    # Pass conversation_id through: dify_retrieval defaults it to None, so
+    # leaving it out would silently discard the value this endpoint requires.
+    setting = RetrievalSetting(top_k=top_k, score_threshold=score_threshold)
+    payload = DifyRetrievalRequest(knowledge_id=knowledge_id, query=query, retrieval_setting=setting)
+    return await dify_retrieval(payload, request, authorization, db, conversation_id)
+
 dify_kb_router = router
 

+ 26 - 7
service/trunks_service.py

@@ -22,6 +22,9 @@ class TrunksService:
         trunk_data['embedding'] = Vectorizer.get_embedding(content)
         if 'type' not in trunk_data:
             trunk_data['type'] = 'default'
+        if 'title' not in trunk_data:
+            from pathlib import Path
+            trunk_data['title'] = Path(trunk_data['file_path']).stem
         print("embedding length:", len(trunk_data['embedding']))
         logger.debug(f"生成的embedding长度: {len(trunk_data['embedding'])}, 内容摘要: {content[:20]}")
         # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
@@ -51,13 +54,20 @@ class TrunksService:
                     'file_path': trunk.file_path,
                     'content': trunk.content,
                     'embedding': trunk.embedding.tolist(),
-                    'type': trunk.type
+                    'type': trunk.type,
+                    'title':trunk.title
                 }
             return None
         finally:
             db.close()
 
-    def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None) -> List[dict]:
+    def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
+        if conversation_id:
+            from cache import get_cache, set_cache
+            cached_result = get_cache(conversation_id)
+            if cached_result:
+                return cached_result
+
         embedding = Vectorizer.get_embedding(text)
         db = SessionLocal()
         try:
@@ -65,19 +75,26 @@ class TrunksService:
                 Trunks.id,
                 Trunks.file_path,
                 Trunks.content,
-                Trunks.embedding.l2_distance(embedding).label('distance')
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title
             )
             if metadata_condition:
                 query = query.filter_by(**metadata_condition)
             if type:
                 query = query.filter(Trunks.type == type)
             results = query.order_by('distance').limit(limit).all()
-            return [{
+            result_list = [{
                 'id': r.id,
                 'file_path': r.file_path,
                 'content': r.content,
-                'distance': r.distance
+                'distance': r.distance,
+                'title': r.title
             } for r in results]
+
+            if conversation_id:
+                set_cache(conversation_id, result_list)
+
+            return result_list
         finally:
             db.close()
 
@@ -159,7 +176,8 @@ class TrunksService:
                 Trunks.id,
                 Trunks.file_path,
                 Trunks.content,
-                Trunks.embedding.l2_distance(embedding).label('distance')
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title
             ).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
             
             return {
@@ -167,7 +185,8 @@ class TrunksService:
                     'id': r.id,
                     'file_path': r.file_path,
                     'content': r.content,
-                    'distance': r.distance
+                    'distance': r.distance,
+                    'title': r.title
                 } for r in results],
                 'pagination': {
                     'total': total_count,

+ 4 - 3
utils/file_reader.py

@@ -7,13 +7,14 @@ class FileReader:
     def find_and_print_split_files(directory):
         for root, dirs, files in os.walk(directory):
             for file in files:
-                if file.endswith('.md'):
+                if '_split_' in file and file.endswith('.txt'):
+                #if file.endswith('.md'):
                     file_path = os.path.join(root, file)
                     relative_path = '\\report\\' + os.path.relpath(file_path, directory)
                     with open(file_path, 'r', encoding='utf-8') as f:
                         content = f.read()
-                    TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'community_report'})
+                    TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk'})
 
 if __name__ == '__main__':
-    directory = 'e:\\project\\knowledge\\utils\\report'
+    directory = 'e:\\project\\knowledge\\utils\\files'
     FileReader.find_and_print_split_files(directory)