
API changes

yuchengwei, 1 month ago
parent commit f7f59c655b
4 changed files with 27 additions and 33 deletions
  1. main.py (+12 -4)
  2. model/trunks_model.py (+1 -0)
  3. router/text_search.py (+12 -29)
  4. service/trunks_service.py (+2 -0)

main.py (+12 -4)

@@ -7,13 +7,14 @@ from typing import Optional, Set
 import os
 import uvicorn
 from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware

-from agent.cdss.capbility import CDSSCapability
+# from agent.cdss.capbility import CDSSCapability
 from router.knowledge_dify import dify_kb_router
 from router.knowledge_saas import saas_kb_router
 from router.text_search import text_search_router
 from router.graph_router import graph_router
-from router.knowledge_nodes_api import knowledge_nodes_api_router
+# from router.knowledge_nodes_api import knowledge_nodes_api_router

 # Configure logging
 logging.basicConfig(
@@ -33,11 +34,18 @@ app.include_router(dify_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(text_search_router)
 app.include_router(graph_router)
-app.include_router(knowledge_nodes_api_router)
+# app.include_router(knowledge_nodes_api_router)

 # Mount the static file directory, mapping the /books path to the local books folder
 app.mount("/books", StaticFiles(directory="books"), name="books")
-
+# Allow all origins (testing only; restrict in production)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # allow all origins (or specify ["http://localhost:3000"])
+    allow_credentials=True,  # allow cookies to be sent
+    allow_methods=["*"],  # allow all methods (or specify ["GET", "POST"])
+    allow_headers=["*"],  # allow all request headers
+)

 # URLs to intercept (wildcards supported)
 INTERCEPT_URLS = {
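The inline comment already flags the wildcard CORS policy as test-only. For reference, a more restrictive setup could look like the sketch below; the origin, methods, and headers are illustrative assumptions, and in the real main.py the existing app instance would be reused rather than recreated.

```python
# Hedged sketch of a production-leaning CORS policy; the origin and header
# values are placeholders, not taken from this repo.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()  # in main.py the existing app instance would be used instead

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://kb.example.com"],  # assumed front-end origin
    allow_credentials=True,                    # cookies allowed only for that origin
    allow_methods=["GET", "POST"],             # restrict to the verbs actually exposed
    allow_headers=["Content-Type", "Authorization"],
)
```

Note that browsers refuse to send credentials when the allowed origin is the "*" wildcard, so the combination committed here effectively only serves requests without cookies.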

model/trunks_model.py (+1 -0)

@@ -15,6 +15,7 @@ class Trunks(Base):
     title = Column(String(255))
     referrence = Column(String(255))
     meta_header = Column(String(255))
+    page_no = Column(Integer)

     def __repr__(self):
         return f"<Trunks(id={self.id}, file_path={self.file_path})>"

router/text_search.py (+12 -29)

@@ -6,6 +6,8 @@ from utils.text_splitter import TextSplitter
 from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
+# from utils.find_text_in_pdf import find_text_in_pdf
+import os
 DISTANCE_THRESHOLD = 0.8
 import logging
 import time
@@ -58,7 +60,6 @@ class TextCompareMultiRequest(BaseModel):
 class NodePropsSearchRequest(BaseModel):
     node_id: int
     props_ids: List[int]
-    conversation_id: Optional[str] = None

 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
@@ -138,6 +139,8 @@ async def search_text(request: TextSearchRequest):
                         "distance": distance,
                         "referrence": search_result.get("referrence", "")
                     }
+
+
                     all_references.append(reference)
                     reference_index += 1

@@ -337,9 +340,6 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
             "distance": 0
         }

-        # Cached results
-        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
-
         # Iterate over props_ids and query property information
         for prop_id in request.props_ids:
             prop = prop_service.get_props_by_id(prop_id)
@@ -374,11 +374,11 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
             i = 0
             while i < len(sentences):
                 original_sentence = sentences[i]
-                sentence = original_sentence.replace("\n", "<br>")
+                sentence = original_sentence

                 # If the current sentence is shorter than 10 characters and is not the last one, merge it with the next sentence
                 if len(sentence) < 10 and i + 1 < len(sentences):
-                    next_sentence = sentences[i + 1].replace("\n", "<br>")
+                    next_sentence = sentences[i + 1]
                     combined_sentence = sentence + " " + next_sentence
                     # Add the original short sentence to the results with an empty flag
                     result_sentences.append({
@@ -402,31 +402,12 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen

                 i += 1

-                # Check the cache
-                if cached_results:
-                    min_distance = float('inf')
-                    best_result = None
-                    sentence_vector = Vectorizer.get_embedding(search_text)
-
-                    for cached_result in cached_results:
-                        content_vector = cached_result['embedding']
-                        distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
-                        if distance < min_distance:
-                            min_distance = distance
-                            best_result = {**cached_result, 'distance': distance}
-
-                    if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
-                        search_results = [best_result]
-                    else:
-                        search_results = []
-                else:
-                    # Do the vector search
-                    search_results = trunks_service.search_by_vector(
+                # Do the vector search
+                search_results = trunks_service.search_by_vector(
                         text=search_text,
                         limit=1,
-                        type='trunk',
-                        conversation_id=request.conversation_id
-                    )
+                        type='trunk'
+                )

                 # Process the search results
                 if not search_results:
@@ -461,8 +442,10 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                             "file_path": search_result.get("file_path", ""),
                             "title": search_result.get("title", ""),
                             "distance": distance,
+                            "page_no": search_result.get("page_no", ""),
                             "referrence": search_result.get("referrence", "")
                         }
+
                         all_references.append(reference)
                         reference_index += 1
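With conversation_id removed from NodePropsSearchRequest, callers now send only node_id and props_ids, and each returned reference carries the new page_no field. A rough client-side sketch; the host, port, route prefix, and IDs are all placeholders, not confirmed by this commit.

```python
# Hypothetical request against the updated endpoint; values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/text_search/node_props_search",  # route prefix is a guess; check text_search_router
    json={"node_id": 1, "props_ids": [10, 11]},  # conversation_id is no longer part of the schema
)
resp.raise_for_status()
print(resp.json())  # each reference entry now includes a page_no field
```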
 
 

service/trunks_service.py (+2 -0)

@@ -73,6 +73,7 @@ class TrunksService:
                 Trunks.embedding.l2_distance(embedding).label('distance'),
                 Trunks.title,
                 Trunks.embedding,
+                Trunks.page_no,
                 Trunks.referrence
             )
             if metadata_condition:
@@ -88,6 +89,7 @@ class TrunksService:
                 'distance': round(r.distance, 3),
                 'title': r.title,
                 'embedding': r.embedding.tolist(),
+                'page_no': r.page_no,
                 'referrence': r.referrence
             } for r in results]
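Since search_by_vector now selects and returns page_no, callers can read it directly from each result dict. A short illustrative sketch, assuming TrunksService can be constructed without arguments (not shown in this diff); the query string is made up, and only the page_no field is new here.

```python
# Illustrative use of TrunksService.search_by_vector after this commit.
from service.trunks_service import TrunksService

service = TrunksService()  # assumed no-arg constructor
results = service.search_by_vector(text="appendicitis diagnostic criteria", limit=1, type='trunk')
for r in results:
    # page_no may be None for rows ingested before the column was added
    print(r['title'], r['page_no'], r['distance'])
```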