|
@@ -1,4 +1,4 @@
|
|
|
-from fastapi import APIRouter, HTTPException
|
|
|
+from fastapi import APIRouter, HTTPException, Depends
|
|
|
from pydantic import BaseModel, Field, validator
|
|
|
from typing import List, Optional
|
|
|
from service.trunks_service import TrunksService
|
|
@@ -9,6 +9,10 @@ from utils.vectorizer import Vectorizer
|
|
|
DISTANCE_THRESHOLD = 0.8
|
|
|
import logging
|
|
|
import time
|
|
|
+from db.session import get_db
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from service.kg_node_service import KGNodeService
|
|
|
+from service.kg_prop_service import KGPropService
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
@@ -51,6 +55,11 @@ class TextCompareMultiRequest(BaseModel):
|
|
|
origin: str
|
|
|
similar: str
|
|
|
|
|
|
+class NodePropsSearchRequest(BaseModel):
|
|
|
+ node_id: int
|
|
|
+ props_ids: List[int]
|
|
|
+ conversation_id: Optional[str] = None
|
|
|
+
|
|
|
@router.post("/search", response_model=StandardResponse)
|
|
|
async def search_text(request: TextSearchRequest):
|
|
|
try:
|
|
@@ -107,14 +116,14 @@ async def search_text(request: TextSearchRequest):
|
|
|
)
|
|
|
|
|
|
# 处理搜索结果
|
|
|
- for result in search_results:
|
|
|
- distance = result.get("distance", DISTANCE_THRESHOLD)
|
|
|
+ for search_result in search_results:
|
|
|
+ distance = search_result.get("distance", DISTANCE_THRESHOLD)
|
|
|
if distance >= DISTANCE_THRESHOLD:
|
|
|
result_sentences.append(sentence)
|
|
|
continue
|
|
|
|
|
|
# 检查是否已存在相同引用
|
|
|
- existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
|
|
|
+ existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
|
|
|
current_index = reference_index
|
|
|
if existing_ref:
|
|
|
current_index = int(existing_ref["index"])
|
|
@@ -122,12 +131,12 @@ async def search_text(request: TextSearchRequest):
|
|
|
# 添加到引用列表
|
|
|
reference = {
|
|
|
"index": str(reference_index),
|
|
|
- "id": result["id"],
|
|
|
- "content": result["content"],
|
|
|
- "file_path": result.get("file_path", ""),
|
|
|
- "title": result.get("title", ""),
|
|
|
+ "id": search_result["id"],
|
|
|
+ "content": search_result["content"],
|
|
|
+ "file_path": search_result.get("file_path", ""),
|
|
|
+ "title": search_result.get("title", ""),
|
|
|
"distance": distance,
|
|
|
- "referrence": result.get("referrence", "")
|
|
|
+ "referrence": search_result.get("referrence", "")
|
|
|
}
|
|
|
all_references.append(reference)
|
|
|
reference_index += 1
|
|
@@ -302,4 +311,150 @@ async def compare_text(request: TextCompareMultiRequest):
|
|
|
logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
+@router.post("/eb_search", response_model=StandardResponse)
|
|
|
+async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
|
|
|
+ try:
|
|
|
+ start_time = time.time()
|
|
|
+ # 初始化服务
|
|
|
+ trunks_service = TrunksService()
|
|
|
+ node_service = KGNodeService(db)
|
|
|
+ prop_service = KGPropService(db)
|
|
|
+
|
|
|
+ # 根据node_id查询节点信息
|
|
|
+ node = node_service.get_node(request.node_id)
|
|
|
+ if not node:
|
|
|
+ raise ValueError(f"节点不存在: {request.node_id}")
|
|
|
+
|
|
|
+ node_name = node.get('name', '')
|
|
|
+
|
|
|
+ # 初始化结果
|
|
|
+ result = {
|
|
|
+ "id": request.node_id,
|
|
|
+ "name": node_name,
|
|
|
+ "category": node.get('category', ''),
|
|
|
+ "props": [],
|
|
|
+ "distance": 0
|
|
|
+ }
|
|
|
+
|
|
|
+ # 缓存结果
|
|
|
+ cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
|
|
|
+
|
|
|
+ # 遍历props_ids查询属性信息
|
|
|
+ for prop_id in request.props_ids:
|
|
|
+ prop = prop_service.get_props_by_id(prop_id)
|
|
|
+
|
|
|
+ if not prop:
|
|
|
+ logger.warning(f"属性不存在: {prop_id}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ prop_title = prop.get('prop_title', '')
|
|
|
+ prop_value = prop.get('prop_value', '')
|
|
|
+
|
|
|
+ # 拆分属性值为句子
|
|
|
+ sentences = TextSplitter.split_text(prop_value)
|
|
|
+ prop_result = {
|
|
|
+ "id": prop_id,
|
|
|
+ "category": prop.get('category', 0),
|
|
|
+ "prop_name": prop.get('prop_name', ''),
|
|
|
+ "prop_value": prop_value,
|
|
|
+ "prop_title": prop_title,
|
|
|
+ "type": prop.get('type', 1)
|
|
|
+ }
|
|
|
+
|
|
|
+ # 添加到结果中
|
|
|
+ result["props"].append(prop_result)
|
|
|
+
|
|
|
+ # 处理属性值中的句子
|
|
|
+ result_sentences = []
|
|
|
+ all_references = []
|
|
|
+ reference_index = 1
|
|
|
+
|
|
|
+ # 对每个句子进行向量搜索
|
|
|
+ for sentence in sentences:
|
|
|
+ original_sentence = sentence
|
|
|
+ sentence = sentence.replace("\n", "<br>")
|
|
|
+ if len(sentence) < 10:
|
|
|
+ result_sentences.append(sentence)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 构建搜索文本
|
|
|
+ search_text = f"{node_name}:{prop_title}:{sentence}"
|
|
|
+
|
|
|
+ # 检查缓存
|
|
|
+ if cached_results:
|
|
|
+ min_distance = float('inf')
|
|
|
+ best_result = None
|
|
|
+ sentence_vector = Vectorizer.get_embedding(search_text)
|
|
|
+
|
|
|
+ for cached_result in cached_results:
|
|
|
+ content_vector = cached_result['embedding']
|
|
|
+ distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
|
|
|
+ if distance < min_distance:
|
|
|
+ min_distance = distance
|
|
|
+ best_result = {**cached_result, 'distance': distance}
|
|
|
+
|
|
|
+ if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
|
|
|
+ search_results = [best_result]
|
|
|
+ else:
|
|
|
+ search_results = []
|
|
|
+ else:
|
|
|
+ # 进行向量搜索
|
|
|
+ search_results = trunks_service.search_by_vector(
|
|
|
+ text=search_text,
|
|
|
+ limit=1,
|
|
|
+ type='trunk',
|
|
|
+ conversation_id=request.conversation_id
|
|
|
+ )
|
|
|
+
|
|
|
+ # 处理搜索结果
|
|
|
+ for search_result in search_results:
|
|
|
+ distance = search_result.get("distance", DISTANCE_THRESHOLD)
|
|
|
+ if distance >= DISTANCE_THRESHOLD:
|
|
|
+ result_sentences.append(sentence)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 检查是否已存在相同引用
|
|
|
+ existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
|
|
|
+ current_index = reference_index
|
|
|
+ if existing_ref:
|
|
|
+ current_index = int(existing_ref["index"])
|
|
|
+ else:
|
|
|
+ # 添加到引用列表
|
|
|
+ reference = {
|
|
|
+ "index": str(reference_index),
|
|
|
+ "id": search_result["id"],
|
|
|
+ "content": search_result["content"],
|
|
|
+ "file_path": search_result.get("file_path", ""),
|
|
|
+ "title": search_result.get("title", ""),
|
|
|
+ "distance": distance,
|
|
|
+ "referrence": search_result.get("referrence", "")
|
|
|
+ }
|
|
|
+ all_references.append(reference)
|
|
|
+ reference_index += 1
|
|
|
+
|
|
|
+ # 添加引用标记
|
|
|
+ if sentence.endswith('<br>'):
|
|
|
+ # 如果有多个<br>,在所有<br>前添加^[current_index]^
|
|
|
+ result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
|
|
|
+ else:
|
|
|
+ # 直接在句子末尾添加^[current_index]^
|
|
|
+ result_sentence = f'{sentence}^[{current_index}]^'
|
|
|
+
|
|
|
+ result_sentences.append(result_sentence)
|
|
|
+
|
|
|
+ # 更新属性值,添加引用信息
|
|
|
+ if all_references:
|
|
|
+ prop_result["references"] = all_references
|
|
|
+
|
|
|
+ # 将处理后的句子添加到结果中
|
|
|
+ if result_sentences:
|
|
|
+ prop_result["answer"] = result_sentences
|
|
|
+
|
|
|
+ end_time = time.time()
|
|
|
+ logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
+ return StandardResponse(success=True, data=result)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Node props search failed: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
+
|
|
|
text_search_router = router
|