import logging
from pathlib import Path
from typing import List, Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from db.session import SessionLocal, get_db
from model.trunks_model import Trunks
from utils.vectorizer import Vectorizer

logger = logging.getLogger(__name__)

# Dimension the pgvector column expects; Vectorizer output must match it.
EMBEDDING_DIM = 1024


class TrunksService:
    """CRUD and vector-similarity search over the ``trunks`` table.

    Every public method opens and closes its own short-lived
    ``SessionLocal`` session, so instances are safe to reuse across
    requests.
    """

    def __init__(self):
        # NOTE(review): this session is never used by any method below
        # (each opens its own SessionLocal), and the generator from
        # get_db() is never closed, so it leaks a connection per
        # instance. Kept only because external callers may rely on
        # ``service.db`` existing — confirm and remove if unused.
        self.db = next(get_db())

    def create_trunk(self, trunk_data: dict) -> Trunks:
        """Insert a new Trunks row, auto-filling embedding/type/title.

        :param trunk_data: column values; ``content`` is used to generate
            the embedding (any caller-supplied ``embedding`` is replaced).
        :raises ValueError: if a supplied or generated embedding does not
            have EMBEDDING_DIM components.
        :return: the persisted Trunks instance (refreshed, with id).
        """
        content = trunk_data.get('content')
        # Reject an obviously wrong caller-supplied vector up front
        # (kept for backward compatibility with callers expecting this
        # ValueError), even though it is regenerated below.
        if 'embedding' in trunk_data and len(trunk_data['embedding']) != EMBEDDING_DIM:
            raise ValueError("向量维度必须为1024")

        # The embedding is always derived from content, never trusted
        # from the caller.
        trunk_data['embedding'] = Vectorizer.get_embedding(content)
        if len(trunk_data['embedding']) != EMBEDDING_DIM:
            # BUGFIX: previously only the (discarded) caller-supplied
            # vector was validated; now the stored one is too.
            raise ValueError("向量维度必须为1024")

        trunk_data.setdefault('type', 'default')
        if 'title' not in trunk_data:
            # Default the title to the file name without its extension.
            trunk_data['title'] = Path(trunk_data['file_path']).stem

        # BUGFIX: replaced a leftover debug print() with lazy logging;
        # also guard the content preview so None content cannot raise.
        preview = content[:20] if content else content
        logger.debug("生成的embedding长度: %s, 内容摘要: %s",
                     len(trunk_data['embedding']), preview)
        # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)

        db = SessionLocal()
        try:
            trunk = Trunks(**trunk_data)
            db.add(trunk)
            db.commit()
            db.refresh(trunk)
            return trunk
        except Exception as e:
            db.rollback()
            logger.error(f"创建trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def get_trunk_by_id(self, trunk_id: int) -> Optional[dict]:
        """Return one trunk as a plain dict, or None if not found.

        The embedding is converted via ``.tolist()`` (assumes the driver
        returns a numpy-like array — TODO confirm for all backends).
        """
        db = SessionLocal()
        try:
            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
            if trunk is None:
                return None
            return {
                'id': trunk.id,
                'file_path': trunk.file_path,
                'content': trunk.content,
                'embedding': trunk.embedding.tolist(),
                'type': trunk.type,
                'title': trunk.title,
            }
        finally:
            db.close()

    def search_by_vector(self, text: str, limit: int = 10,
                         metadata_condition: Optional[dict] = None,
                         type: Optional[str] = None,
                         conversation_id: Optional[str] = None) -> List[dict]:
        """L2-nearest-neighbour search for ``text`` over trunk embeddings.

        :param text: query text, embedded via Vectorizer.
        :param limit: max rows returned.
        :param metadata_condition: extra equality filters (filter_by kwargs).
        :param type: optional Trunks.type filter (name kept for API compat,
            although it shadows the builtin).
        :param conversation_id: when given, the result list is cached
            under this id for later retrieval by get_cached_result().
        :return: list of dicts with id/file_path/content/distance/title,
            ordered by ascending distance.
        """
        embedding = Vectorizer.get_embedding(text)
        db = SessionLocal()
        try:
            query = db.query(
                Trunks.id,
                Trunks.file_path,
                Trunks.content,
                Trunks.embedding.l2_distance(embedding).label('distance'),
                Trunks.title,
            )
            if metadata_condition:
                query = query.filter_by(**metadata_condition)
            if type:
                query = query.filter(Trunks.type == type)
            results = query.order_by('distance').limit(limit).all()
            result_list = [{
                'id': r.id,
                'file_path': r.file_path,
                'content': r.content,
                'distance': r.distance,
                'title': r.title,
            } for r in results]
            if conversation_id:
                # BUGFIX: set_cache was referenced without being in scope
                # (it was only imported inside get_cached_result), so this
                # branch always raised NameError.
                from cache import set_cache
                set_cache(conversation_id, result_list)
            return result_list
        finally:
            db.close()

    def fulltext_search(self, query: str) -> List[Trunks]:
        """Full-text search via the content_tsvector column."""
        db = SessionLocal()
        try:
            return db.query(Trunks).filter(
                Trunks.content_tsvector.match(query)
            ).all()
        finally:
            db.close()

    def update_trunk(self, trunk_id: int, update_data: dict) -> Optional[Trunks]:
        """Update a trunk; regenerates the embedding when content changes.

        :return: the refreshed Trunks row, or None if trunk_id is unknown.
        """
        if 'content' in update_data:
            content = update_data['content']
            update_data['embedding'] = Vectorizer.get_embedding(content)
            update_data.setdefault('type', 'default')
            # Guard the preview so None content cannot raise in logging.
            preview = content[:20] if content else content
            logger.debug("更新生成的embedding长度: %s, 内容摘要: %s",
                         len(update_data['embedding']), preview)
            # update_data['content_tsvector'] = func.to_tsvector('chinese', content)

        db = SessionLocal()
        try:
            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
            if trunk:
                for key, value in update_data.items():
                    setattr(trunk, key, value)
                db.commit()
                db.refresh(trunk)
            return trunk
        except Exception as e:
            db.rollback()
            logger.error(f"更新trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def delete_trunk(self, trunk_id: int) -> bool:
        """Delete a trunk by id; returns True if a row was removed."""
        db = SessionLocal()
        try:
            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
            if trunk is None:
                return False
            db.delete(trunk)
            db.commit()
            return True
        except Exception as e:
            db.rollback()
            logger.error(f"删除trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def get_cached_result(self, conversation_id: str) -> List[dict]:
        """Return the cached search result for a conversation, or [].

        :param conversation_id: conversation/session id used as cache key.
        :return: the cached result list, or an empty list when missing.
        """
        if not conversation_id:
            return []
        from cache import get_cache
        cached_result = get_cache(conversation_id)
        return cached_result if cached_result else []

    def paginated_search(self, search_params: dict) -> dict:
        """Paginated vector search filtered by type.

        :param search_params: dict with ``keyword``, ``pageNo``, ``limit``
            and ``type`` keys (pageNo/limit fall back to 1/10 when invalid).
        :return: {'data': [...rows...], 'pagination': {total, pageNo,
            limit, totalPages}}.
        """
        keyword = search_params.get('keyword', '')
        page_no = search_params.get('pageNo', 1)
        limit = search_params.get('limit', 10)
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit
        # Hoisted: the type filter is applied to both count and data queries.
        trunk_type = search_params.get('type')

        db = SessionLocal()
        try:
            # Total row count for the pagination block.
            total_count = db.query(func.count(Trunks.id)) \
                .filter(Trunks.type == trunk_type).scalar()

            # Vector search ordered by L2 distance, paged via offset/limit.
            results = db.query(
                Trunks.id,
                Trunks.file_path,
                Trunks.content,
                Trunks.embedding.l2_distance(embedding).label('distance'),
                Trunks.title,
            ).filter(Trunks.type == trunk_type) \
             .order_by('distance').offset(offset).limit(limit).all()

            return {
                'data': [{
                    'id': r.id,
                    'file_path': r.file_path,
                    'content': r.content,
                    'distance': r.distance,
                    'title': r.title,
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    # Ceiling division without importing math.
                    'totalPages': (total_count + limit - 1) // limit,
                },
            }
        finally:
            db.close()