from sqlalchemy import func
from sqlalchemy.orm import Session
from db.session import get_db
from typing import List, Optional
from model.trunks_model import Trunks
from db.session import SessionLocal
import logging
from utils.vectorizer import Vectorizer

logger = logging.getLogger(__name__)

# Expected dimensionality of the Trunks.embedding vector column.
EMBEDDING_DIM = 1024


class TrunksService:
    """Service layer for ``Trunks`` rows: CRUD plus vector and full-text search.

    Every public method opens and closes its own ``SessionLocal`` session,
    so a single service instance is safe to reuse across requests.
    """

    def __init__(self):
        # NOTE(review): this session is acquired but never used by the
        # methods below (each opens its own SessionLocal). It is kept only
        # so external code that reads ``service.db`` keeps working; the
        # generator from get_db() is never advanced to its cleanup step,
        # so this session is effectively leaked — consider removing.
        self.db = next(get_db())

    def create_trunk(self, trunk_data: dict) -> Trunks:
        """Insert a new ``Trunks`` row, generating the embedding when absent.

        :param trunk_data: column -> value mapping for the new row; must
            contain ``content``. A caller-supplied ``embedding`` is
            dimension-checked and kept as-is; otherwise one is generated
            from the content via ``Vectorizer``.
        :return: the persisted ``Trunks`` instance (refreshed, with id).
        :raises ValueError: if a supplied embedding is not 1024-dimensional.
        """
        content = trunk_data.get('content')
        if 'embedding' in trunk_data:
            # Honor a caller-provided vector after validating its size.
            # (The previous revision validated and then overwrote it.)
            if len(trunk_data['embedding']) != EMBEDDING_DIM:
                raise ValueError("向量维度必须为1024")
        else:
            trunk_data['embedding'] = Vectorizer.get_embedding(content)
        # Lazy %-args keep formatting off the hot path; guard content[:20]
        # against a None content when an embedding was supplied directly.
        logger.debug("生成的embedding长度: %s, 内容摘要: %s",
                     len(trunk_data['embedding']), (content or '')[:20])
        # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
        db = SessionLocal()
        try:
            trunk = Trunks(**trunk_data)
            db.add(trunk)
            db.commit()
            db.refresh(trunk)
            return trunk
        except Exception as e:
            db.rollback()
            logger.error(f"创建trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def get_trunk_by_id(self, trunk_id: int) -> Optional[Trunks]:
        """Return the ``Trunks`` row with the given id, or ``None``."""
        db = SessionLocal()
        try:
            return db.query(Trunks).filter(Trunks.id == trunk_id).first()
        finally:
            db.close()

    def search_by_vector(self, text: str, limit: int = 10,
                         metadata_condition: Optional[dict] = None) -> List[dict]:
        """Nearest-neighbour search by L2 distance on the embedding column.

        :param text: query text, vectorized with the same model as storage.
        :param limit: maximum number of rows returned.
        :param metadata_condition: optional equality filters applied via
            ``filter_by`` (column name -> required value).
        :return: list of dicts with id, file_path, content and distance,
            ordered by ascending distance.
        """
        embedding = Vectorizer.get_embedding(text)
        db = SessionLocal()
        try:
            query = db.query(
                Trunks.id,
                Trunks.file_path,
                Trunks.content,
                Trunks.embedding.l2_distance(embedding).label('distance'),
            )
            if metadata_condition:
                query = query.filter_by(**metadata_condition)
            results = query.order_by('distance').limit(limit).all()
            return [{
                'id': r.id,
                'file_path': r.file_path,
                'content': r.content,
                'distance': r.distance,
            } for r in results]
        finally:
            db.close()

    def fulltext_search(self, query: str) -> List[Trunks]:
        """Full-text search against the ``content_tsvector`` column."""
        db = SessionLocal()
        try:
            return db.query(Trunks).filter(
                Trunks.content_tsvector.match(query)
            ).all()
        finally:
            db.close()

    def update_trunk(self, trunk_id: int, update_data: dict) -> Optional[Trunks]:
        """Update the given row; regenerates the embedding when content changes.

        :param trunk_id: primary key of the row to update.
        :param update_data: column -> new value mapping; if ``content`` is
            present, a fresh embedding is derived from it.
        :return: the refreshed row, or ``None`` when no row matched.
        """
        if 'content' in update_data:
            content = update_data['content']
            update_data['embedding'] = Vectorizer.get_embedding(content)
            logger.debug("更新生成的embedding长度: %s, 内容摘要: %s",
                         len(update_data['embedding']), content[:20])
            # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
        db = SessionLocal()
        try:
            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
            if trunk:
                for key, value in update_data.items():
                    setattr(trunk, key, value)
                db.commit()
                db.refresh(trunk)
            return trunk
        except Exception as e:
            db.rollback()
            logger.error(f"更新trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def delete_trunk(self, trunk_id: int) -> bool:
        """Delete the given row; return ``True`` if a row was removed."""
        db = SessionLocal()
        try:
            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
            if trunk:
                db.delete(trunk)
                db.commit()
                return True
            return False
        except Exception as e:
            db.rollback()
            logger.error(f"删除trunk失败: {str(e)}")
            raise
        finally:
            db.close()

    def paginated_search(self, search_params: dict) -> dict:
        """Paginated vector search.

        :param search_params: dict with ``keyword`` (query text), ``pageNo``
            (1-based page index) and ``limit`` (page size); out-of-range
            values fall back to page 1 / limit 10.
        :return: dict with ``data`` (list of row dicts with distance) and
            ``pagination`` (total, pageNo, limit, totalPages).
        """
        keyword = search_params.get('keyword', '')
        page_no = search_params.get('pageNo', 1)
        limit = search_params.get('limit', 10)
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10
        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit
        db = SessionLocal()
        try:
            # Total row count (unfiltered — pagination covers the whole table).
            total_count = db.query(func.count(Trunks.id)).scalar()
            results = db.query(
                Trunks.id,
                Trunks.file_path,
                Trunks.content,
                Trunks.embedding.l2_distance(embedding).label('distance'),
            ).order_by('distance').offset(offset).limit(limit).all()
            return {
                'data': [{
                    'id': r.id,
                    'file_path': r.file_path,
                    'content': r.content,
                    'distance': r.distance,
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    # Ceiling division without floats.
                    'totalPages': (total_count + limit - 1) // limit,
                },
            }
        finally:
            db.close()