|
@@ -1,4 +1,4 @@
|
|
-from fastapi import APIRouter, HTTPException
|
|
|
|
|
|
+from fastapi import APIRouter, HTTPException, Depends
|
|
from pydantic import BaseModel, Field, validator
|
|
from pydantic import BaseModel, Field, validator
|
|
from typing import List, Optional
|
|
from typing import List, Optional
|
|
from service.trunks_service import TrunksService
|
|
from service.trunks_service import TrunksService
|
|
@@ -6,9 +6,15 @@ from utils.text_splitter import TextSplitter
|
|
from utils.vector_distance import VectorDistance
|
|
from utils.vector_distance import VectorDistance
|
|
from model.response import StandardResponse
|
|
from model.response import StandardResponse
|
|
from utils.vectorizer import Vectorizer
|
|
from utils.vectorizer import Vectorizer
|
|
|
|
+# from utils.find_text_in_pdf import find_text_in_pdf
|
|
|
|
+import os
|
|
DISTANCE_THRESHOLD = 0.8
|
|
DISTANCE_THRESHOLD = 0.8
|
|
import logging
|
|
import logging
|
|
import time
|
|
import time
|
|
|
|
+from db.session import get_db
|
|
|
|
+from sqlalchemy.orm import Session
|
|
|
|
+from service.kg_node_service import KGNodeService
|
|
|
|
+from service.kg_prop_service import KGPropService
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
@@ -51,6 +57,10 @@ class TextCompareMultiRequest(BaseModel):
|
|
origin: str
|
|
origin: str
|
|
similar: str
|
|
similar: str
|
|
|
|
|
|
|
|
class NodePropsSearchRequest(BaseModel):
    """Request payload for the POST /text/eb_search endpoint.

    Identifies a knowledge-graph node and the subset of its properties
    whose values should be sentence-split and vector-searched.
    """
    # Primary key of the KG node to look up (resolved via KGNodeService).
    node_id: int
    # IDs of the node's properties to process (resolved via KGPropService).
    props_ids: List[int]

|
|
@router.post("/search", response_model=StandardResponse)
|
|
@router.post("/search", response_model=StandardResponse)
|
|
async def search_text(request: TextSearchRequest):
|
|
async def search_text(request: TextSearchRequest):
|
|
try:
|
|
try:
|
|
@@ -107,14 +117,14 @@ async def search_text(request: TextSearchRequest):
|
|
)
|
|
)
|
|
|
|
|
|
# 处理搜索结果
|
|
# 处理搜索结果
|
|
- for result in search_results:
|
|
|
|
- distance = result.get("distance", DISTANCE_THRESHOLD)
|
|
|
|
|
|
+ for search_result in search_results:
|
|
|
|
+ distance = search_result.get("distance", DISTANCE_THRESHOLD)
|
|
if distance >= DISTANCE_THRESHOLD:
|
|
if distance >= DISTANCE_THRESHOLD:
|
|
result_sentences.append(sentence)
|
|
result_sentences.append(sentence)
|
|
continue
|
|
continue
|
|
|
|
|
|
# 检查是否已存在相同引用
|
|
# 检查是否已存在相同引用
|
|
- existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
|
|
|
|
|
|
+ existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
|
|
current_index = reference_index
|
|
current_index = reference_index
|
|
if existing_ref:
|
|
if existing_ref:
|
|
current_index = int(existing_ref["index"])
|
|
current_index = int(existing_ref["index"])
|
|
@@ -122,13 +132,15 @@ async def search_text(request: TextSearchRequest):
|
|
# 添加到引用列表
|
|
# 添加到引用列表
|
|
reference = {
|
|
reference = {
|
|
"index": str(reference_index),
|
|
"index": str(reference_index),
|
|
- "id": result["id"],
|
|
|
|
- "content": result["content"],
|
|
|
|
- "file_path": result.get("file_path", ""),
|
|
|
|
- "title": result.get("title", ""),
|
|
|
|
|
|
+ "id": search_result["id"],
|
|
|
|
+ "content": search_result["content"],
|
|
|
|
+ "file_path": search_result.get("file_path", ""),
|
|
|
|
+ "title": search_result.get("title", ""),
|
|
"distance": distance,
|
|
"distance": distance,
|
|
- "referrence": result.get("referrence", "")
|
|
|
|
|
|
+ "referrence": search_result.get("referrence", "")
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+
|
|
all_references.append(reference)
|
|
all_references.append(reference)
|
|
reference_index += 1
|
|
reference_index += 1
|
|
|
|
|
|
@@ -302,4 +314,191 @@ async def compare_text(request: TextCompareMultiRequest):
|
|
logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
def _file_type(file_name: str) -> str:
    """Map a file name to a coarse type label by its extension."""
    lower = file_name.lower()
    if lower.endswith(".pdf"):
        return "pdf"
    if lower.endswith((".doc", ".docx")):
        return "doc"
    if lower.endswith((".xls", ".xlsx")):
        return "excel"
    if lower.endswith((".ppt", ".pptx")):
        return "ppt"
    return "other"


@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    """Search reference material for a KG node's property values.

    For each requested property of the node, split the property value into
    sentences, vector-search each sentence (prefixed with the node name and
    property title) against 'trunk' records, and attach the matched
    references plus per-sentence citation flags to the result.

    Args:
        request: node_id plus the property ids to process.
        db: SQLAlchemy session injected by FastAPI via ``get_db``.

    Returns:
        StandardResponse whose ``data`` contains the node info, the enriched
        props (each with an "answer" sentence list and "references"), and the
        distinct files referenced.

    Raises:
        HTTPException: 500 on any failure, including a missing node
            (the internal ValueError is converted by the catch-all handler).
    """
    try:
        start_time = time.time()
        # 初始化服务
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # 根据node_id查询节点信息
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"节点不存在: {request.node_id}")

        node_name = node.get('name', '')

        # 初始化结果
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "files": [],
            "distance": 0
        }

        # 遍历props_ids查询属性信息
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"属性不存在: {prop_id}")
                continue

            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # 拆分属性值为句子
            sentences = TextSplitter.split_text(prop_value)
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            result["props"].append(prop_result)

            # Per-property accumulators for answer sentences and citations.
            result_sentences = []
            all_references = []
            reference_index = 1

            # 对每个句子进行向量搜索
            i = 0
            while i < len(sentences):
                sentence = sentences[i]

                if len(sentence) < 10 and i + 1 < len(sentences):
                    # 句子过短:与下一句合并后再搜索。短句本身记入结果且
                    # flag 为空;合并搜索得到的引用标记归属于下一句。
                    # (bug fix: 原实现把短句追加了两次——一次无 flag、一次带
                    # flag——而下一句从未出现在 answer 中。)
                    next_sentence = sentences[i + 1]
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
                    sentence = next_sentence
                    i += 1  # 跳过下一句,因为已经合并使用
                elif len(sentence) < 10:
                    # 最后一句且长度小于10:直接添加到结果,flag为空,不搜索
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    i += 1
                    continue
                else:
                    # 句子长度足够,直接使用
                    search_text = f"{node_name}:{prop_title}:{sentence}"

                i += 1

                # 进行向量搜索
                search_results = trunks_service.search_by_vector(
                    text=search_text,
                    limit=1,
                    type='trunk'
                )

                if not search_results:
                    # 没有搜索结果,添加原句子,flag为空
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    continue

                for search_result in search_results:
                    # Missing distance is treated as at-threshold, i.e. no match.
                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
                    if distance >= DISTANCE_THRESHOLD:
                        # 距离过大,添加原句子,flag为空
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        continue

                    # 检查是否已存在相同引用,复用其编号
                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                    current_index = reference_index
                    if existing_ref:
                        current_index = int(existing_ref["index"])
                    else:
                        # 添加到引用列表
                        reference = {
                            "index": str(reference_index),
                            "id": search_result["id"],
                            "content": search_result["content"],
                            "file_path": search_result.get("file_path", ""),
                            "title": search_result.get("title", ""),
                            "distance": distance,
                            "page_no": search_result.get("page_no", ""),
                            "referrence": search_result.get("referrence", "")
                        }
                        all_references.append(reference)
                        reference_index += 1

                    # 添加句子和引用标记(作为单独的flag字段)
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": str(current_index)
                    })

            # 更新属性值,添加引用信息
            if all_references:
                prop_result["references"] = all_references

            # 将处理后的句子添加到结果中
            if result_sentences:
                prop_result["answer"] = result_sentences

        # 处理所有引用中的文件信息(按 (文件名, 类型) 去重)
        all_files = set()
        for prop_item in result["props"]:
            for ref in prop_item.get("references", []):
                referrence = ref.get("referrence", "")
                if referrence and "/books/" in referrence:
                    # 提取/books/后面的文件名
                    file_name = referrence.split("/books/")[-1]
                    if file_name:
                        all_files.add((file_name, _file_type(file_name)))

        # 将文件信息添加到结果中
        result["files"] = [{
            "file_name": file_name,
            "file_type": file_type
        } for file_name, file_type in all_files]

        end_time = time.time()
        logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

|
|
text_search_router = router
|
|
text_search_router = router
|