@@ -0,0 +1,319 @@
+import logging
+from typing import List, Optional
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from ..db.session import SessionLocal, get_db
+from ..model.trunks_model import Trunks
+from ..utils.sentence_util import SentenceUtil
+from ..utils.vectorizer import Vectorizer
+
+logger = logging.getLogger(__name__)
+
+
+class TrunksService:
+    def __init__(self):
+        # Hold a session taken from the get_db() dependency generator; the
+        # methods below open their own short-lived sessions via SessionLocal().
+        self.db = next(get_db())
+
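+    # Creates a trunk record: ensures an embedding and default 'type'/'title'
+    # fields are present, then inserts the row in its own session.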
+    def create_trunk(self, trunk_data: dict) -> Trunks:
+        # Automatically generate the embedding and full-text search fields
+        content = trunk_data.get('content')
+        if 'embedding' in trunk_data:
+            # Validate a caller-supplied embedding instead of regenerating it
+            if len(trunk_data['embedding']) != 1024:
+                raise ValueError("Embedding dimension must be 1024")
+        else:
+            trunk_data['embedding'] = Vectorizer.get_embedding(content)
+        if 'type' not in trunk_data:
+            trunk_data['type'] = 'default'
+        if 'title' not in trunk_data:
+            from pathlib import Path
+            trunk_data['title'] = Path(trunk_data['file_path']).stem
+        logger.debug(f"Generated embedding length: {len(trunk_data['embedding'])}, content preview: {(content or '')[:20]}")
+        # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
+
+        db = SessionLocal()
+        try:
+            trunk = Trunks(**trunk_data)
+            db.add(trunk)
+            db.commit()
+            db.refresh(trunk)
+            return trunk
+        except Exception as e:
+            db.rollback()
+            logger.error(f"Failed to create trunk: {str(e)}")
+            raise
+        finally:
+            db.close()
+
+    def get_trunk_by_id(self, trunk_id: int) -> Optional[dict]:
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                return {
+                    'id': trunk.id,
+                    'file_path': trunk.file_path,
+                    'content': trunk.content,
+                    'embedding': trunk.embedding.tolist(),
+                    'type': trunk.type,
+                    'title': trunk.title
+                }
+            return None
+        finally:
+            db.close()
+
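+    # Nearest-neighbour search over the embedding column ordered by L2
+    # distance, with optional distance/type/file_path filters; results are
+    # cached per conversation_id when one is provided.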
+    def search_by_vector(self, text: str, limit: int = 10, file_path: Optional[str] = None, distance: Optional[float] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
+
+        embedding = Vectorizer.get_instance().get_embedding(text)
+        db = SessionLocal()
+        try:
+            query = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title,
+                Trunks.embedding,
+                Trunks.page_no,
+                Trunks.referrence,
+                Trunks.meta_header
+            )
+            if distance is not None:
+                query = query.filter(Trunks.embedding.l2_distance(embedding) <= distance)
+            if type:
+                query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%' + file_path + '%'))
+            results = query.order_by('distance').limit(limit).all()
+            result_list = [{
+                'id': r.id,
+                'file_path': r.file_path,
+                'content': r.content,
+                # Keep three decimal places
+                'distance': round(r.distance, 3),
+                'title': r.title,
+                'embedding': r.embedding.tolist(),
+                'page_no': r.page_no,
+                'referrence': r.referrence,
+                'meta_header': r.meta_header
+            } for r in results]
+
+            if conversation_id:
+                self.set_cache(conversation_id, result_list)
+
+            return result_list
+        finally:
+            db.close()
+
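+    # Full-text search against the precomputed content_tsvector column.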
+    def fulltext_search(self, query: str) -> List[Trunks]:
+        db = SessionLocal()
+        try:
+            return db.query(Trunks).filter(
+                Trunks.content_tsvector.match(query)
+            ).all()
+        finally:
+            db.close()
+
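+    # Updates a trunk by id; when 'content' changes, the embedding is
+    # regenerated and a default 'type' is applied.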
+    def update_trunk(self, trunk_id: int, update_data: dict) -> Optional[Trunks]:
+        if 'content' in update_data:
+            content = update_data['content']
+            update_data['embedding'] = Vectorizer.get_embedding(content)
+            if 'type' not in update_data:
+                update_data['type'] = 'default'
+            logger.debug(f"Updated embedding length: {len(update_data['embedding'])}, content preview: {content[:20]}")
+            # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
+
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                for key, value in update_data.items():
+                    setattr(trunk, key, value)
+                db.commit()
+                db.refresh(trunk)
+            return trunk
+        except Exception as e:
+            db.rollback()
+            logger.error(f"Failed to update trunk: {str(e)}")
+            raise
+        finally:
+            db.close()
+
+    def delete_trunk(self, trunk_id: int) -> bool:
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                db.delete(trunk)
+                db.commit()
+                return True
+            return False
+        except Exception as e:
+            db.rollback()
+            logger.error(f"Failed to delete trunk: {str(e)}")
+            raise
+        finally:
+            db.close()
+
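+    # Returns the indices of targetSentences whose cleaned text appears in the
+    # trunk's cleaned content; when two matched indices are at most 2 apart,
+    # the index between them is also included so highlights form contiguous runs.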
+    def highlight(self, trunk_id: int, targetSentences: List[str]) -> List[int]:
+        trunk = self.get_trunk_by_id(trunk_id)
+        if not trunk:
+            return []
+
+        content = trunk['content']
+        sentence_util = SentenceUtil()
+        cleanedContent = sentence_util.clean_text(content)
+
+        result = []
+        for i, targetSentence in enumerate(targetSentences):
+            cleanedTarget = sentence_util.clean_text(targetSentence)
+            # Skip targets shorter than 5 characters after cleaning
+            if len(cleanedTarget) < 5:
+                continue
+            if cleanedTarget in cleanedContent:
+                result.append(i)
+
+        # Fill in the gaps between nearby matched indices
+        if result:
+            result.sort()
+            filled_result = []
+            prev = result[0]
+            filled_result.append(prev)
+            for current in result[1:]:
+                if current - prev <= 2:
+                    for i in range(prev + 1, current):
+                        filled_result.append(i)
+                filled_result.append(current)
+                prev = current
+            return filled_result
+        return result
+
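+    # In-process cache of search results keyed by conversation_id, shared by
+    # all TrunksService instances (class attribute).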
+    _cache = {}
+
+    def get_cache(self, conversation_id: str) -> List[dict]:
+        """
+        Get the cached results for a conversation ID.
+        :param conversation_id: conversation ID
+        :return: list of cached results
+        """
+        return self._cache.get(conversation_id, [])
+
+    def set_cache(self, conversation_id: str, result: List[dict]) -> None:
+        """
+        Cache the results for a conversation ID.
+        :param conversation_id: conversation ID
+        :param result: results to cache
+        """
+        self._cache[conversation_id] = result
+
+    def get_cached_result(self, conversation_id: str) -> List[dict]:
+        """
+        Get the cached results for a conversation ID.
+        :param conversation_id: conversation ID
+        :return: list of cached results
+        """
+        return self.get_cache(conversation_id)
+
+    def paginated_search_by_type_and_filepath(self, search_params: dict) -> dict:
+        """
+        Paginated query filtered by type and file_path.
+        :param search_params: dict containing pageNo, limit and optional type/file_path
+        :return: dict with a 'data' list of matching trunks
+        """
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+        file_path = search_params.get('file_path', None)
+        type = search_params.get('type', None)
+
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+
+        offset = (page_no - 1) * limit
+
+        db = SessionLocal()
+        try:
+            query = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.type,
+                Trunks.title,
+                Trunks.meta_header
+            )
+            if type:
+                query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%' + file_path + '%'))
+
+            # Only rows without a page number
+            query = query.filter(Trunks.page_no.is_(None))
+            results = query.offset(offset).limit(limit).all()
+
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    'type': r.type,
+                    'title': r.title,
+                    'meta_header': r.meta_header
+                } for r in results]
+            }
+        finally:
+            db.close()
+
+    def paginated_search(self, search_params: dict) -> dict:
+        """
+        Paginated vector search.
+        :param search_params: dict containing keyword, pageNo, limit and type
+        :return: dict with the result list and pagination info
+        """
+        keyword = search_params.get('keyword', '')
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+
+        embedding = Vectorizer.get_embedding(keyword)
+        offset = (page_no - 1) * limit
+
+        db = SessionLocal()
+        try:
+            # Total number of rows of the requested type
+            total_count = db.query(func.count(Trunks.id)).filter(Trunks.type == search_params.get('type')).scalar()
+
+            # Run the vector search ordered by L2 distance
+            results = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title
+            ).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
+
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    # Keep three decimal places
+                    'distance': round(r.distance, 3),
+                    'title': r.title
+                } for r in results],
+                'pagination': {
+                    'total': total_count,
+                    'pageNo': page_no,
+                    'limit': limit,
+                    'totalPages': (total_count + limit - 1) // limit
+                }
+            }
+        finally:
+            db.close()