|
@@ -9,6 +9,10 @@ from utils.vectorizer import Vectorizer
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrunksService:
|
|
|
def __init__(self):
    """Attach a database session to this service instance.

    The session is the first value produced by ``get_db()``; it is stored
    on ``self.db``.
    """
    # NOTE(review): get_db is presumably a generator-style session provider.
    # It is advanced exactly once here, so any cleanup code it has after its
    # yield is not driven by this method -- confirm who closes self.db.
    session_source = get_db()
    self.db = next(session_source)
|
|
|
+
|
|
|
+
|
|
|
def create_trunk(self, trunk_data: dict) -> Trunks:
|
|
|
# 自动生成向量和全文检索字段
|
|
|
content = trunk_data.get('content')
|
|
@@ -108,5 +112,53 @@ class TrunksService:
|
|
|
db.rollback()
|
|
|
logger.error(f"删除trunk失败: {str(e)}")
|
|
|
raise
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
def paginated_search(self, search_params: dict) -> dict:
    """Run a paginated vector-similarity search over ``Trunks`` rows.

    The keyword is converted to an embedding with
    ``Vectorizer.get_embedding`` and rows are returned ordered by their L2
    distance to that embedding, nearest first.

    :param search_params: dict with optional ``keyword``, ``pageNo``
        (1-based page number) and ``limit`` (page size). ``None`` is treated
        as an empty dict. A missing, ``None``, non-numeric or < 1 value for
        ``pageNo`` / ``limit`` falls back to 1 / 10 respectively.
    :return: dict with ``data`` (list of dicts holding ``id``,
        ``file_path``, ``content``, ``distance``) and ``pagination``
        (``total``, ``pageNo``, ``limit``, ``totalPages``). ``total`` is the
        count of all ``Trunks`` rows: the keyword only affects ordering, not
        which rows are counted.
    """

    def _positive_int(value, default):
        # Request parameters often arrive as strings or None; comparing those
        # with ``<`` raises TypeError, so coerce first and fall back to the
        # default instead of failing.
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return number if number >= 1 else default

    if search_params is None:
        search_params = {}

    keyword = search_params.get('keyword', '')
    if keyword is None:
        # An explicit None bypasses the .get() default above.
        keyword = ''
    page_no = _positive_int(search_params.get('pageNo', 1), 1)
    limit = _positive_int(search_params.get('limit', 10), 10)

    embedding = Vectorizer.get_embedding(keyword)
    offset = (page_no - 1) * limit

    db = SessionLocal()
    try:
        # Total number of rows, used to derive the page count.
        total_count = db.query(func.count(Trunks.id)).scalar()

        # Vector search: order by distance to the keyword embedding and
        # fetch only the requested page.
        results = db.query(
            Trunks.id,
            Trunks.file_path,
            Trunks.content,
            Trunks.embedding.l2_distance(embedding).label('distance')
        ).order_by('distance').offset(offset).limit(limit).all()

        return {
            'data': [{
                'id': r.id,
                'file_path': r.file_path,
                'content': r.content,
                'distance': r.distance
            } for r in results],
            'pagination': {
                'total': total_count,
                'pageNo': page_no,
                'limit': limit,
                # Ceiling division without floats.
                'totalPages': (total_count + limit - 1) // limit
            }
        }
    finally:
        # Always release the session, including when the query raises.
        db.close()
|