# trunks_service.py
  1. from sqlalchemy import func
  2. from sqlalchemy.orm import Session
  3. from typing import List, Optional
  4. from model.trunks_model import Trunks
  5. from db.session import SessionLocal
  6. import logging
  7. from utils.vectorizer import Vectorizer
# Module-level logger named after this module; the service methods log through it.
logger = logging.getLogger(__name__)
  9. class TrunksService:
  10. def __init__(self):
  11. self.db = next(get_db())
  12. def create_trunk(self, trunk_data: dict) -> Trunks:
  13. # 自动生成向量和全文检索字段
  14. content = trunk_data.get('content')
  15. if 'embedding' in trunk_data and len(trunk_data['embedding']) != 1024:
  16. raise ValueError("向量维度必须为1024")
  17. trunk_data['embedding'] = Vectorizer.get_embedding(content)
  18. print("embedding length:", len(trunk_data['embedding']))
  19. logger.debug(f"生成的embedding长度: {len(trunk_data['embedding'])}, 内容摘要: {content[:20]}")
  20. # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
  21. db = SessionLocal()
  22. try:
  23. trunk = Trunks(**trunk_data)
  24. db.add(trunk)
  25. db.commit()
  26. db.refresh(trunk)
  27. return trunk
  28. except Exception as e:
  29. db.rollback()
  30. logger.error(f"创建trunk失败: {str(e)}")
  31. raise
  32. finally:
  33. db.close()
  34. def get_trunk_by_id(self, trunk_id: int) -> Optional[Trunks]:
  35. db = SessionLocal()
  36. try:
  37. return db.query(Trunks).filter(Trunks.id == trunk_id).first()
  38. finally:
  39. db.close()
  40. def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None) -> List[dict]:
  41. embedding = Vectorizer.get_embedding(text)
  42. db = SessionLocal()
  43. try:
  44. query = db.query(
  45. Trunks.id,
  46. Trunks.file_path,
  47. Trunks.content,
  48. Trunks.embedding.l2_distance(embedding).label('distance')
  49. )
  50. if metadata_condition:
  51. query = query.filter_by(**metadata_condition)
  52. results = query.order_by('distance').limit(limit).all()
  53. return [{
  54. 'id': r.id,
  55. 'file_path': r.file_path,
  56. 'content': r.content,
  57. 'distance': r.distance
  58. } for r in results]
  59. finally:
  60. db.close()
  61. def fulltext_search(self, query: str) -> List[Trunks]:
  62. db = SessionLocal()
  63. try:
  64. return db.query(Trunks).filter(
  65. Trunks.content_tsvector.match(query)
  66. ).all()
  67. finally:
  68. db.close()
  69. def update_trunk(self, trunk_id: int, update_data: dict) -> Optional[Trunks]:
  70. if 'content' in update_data:
  71. content = update_data['content']
  72. update_data['embedding'] = Vectorizer.get_embedding(content)
  73. logger.debug(f"更新生成的embedding长度: {len(update_data['embedding'])}, 内容摘要: {content[:20]}")
  74. # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
  75. db = SessionLocal()
  76. try:
  77. trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
  78. if trunk:
  79. for key, value in update_data.items():
  80. setattr(trunk, key, value)
  81. db.commit()
  82. db.refresh(trunk)
  83. return trunk
  84. except Exception as e:
  85. db.rollback()
  86. logger.error(f"更新trunk失败: {str(e)}")
  87. raise
  88. finally:
  89. db.close()
  90. def delete_trunk(self, trunk_id: int) -> bool:
  91. db = SessionLocal()
  92. try:
  93. trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
  94. if trunk:
  95. db.delete(trunk)
  96. db.commit()
  97. return True
  98. return False
  99. except Exception as e:
  100. db.rollback()
  101. logger.error(f"删除trunk失败: {str(e)}")
  102. raise
  103. finally:
  104. db.close()
  105. def paginated_search(self, search_params: dict) -> dict:
  106. """
  107. 分页查询方法
  108. :param search_params: 包含keyword, pageNo, limit的字典
  109. :return: 包含结果列表和分页信息的字典
  110. """
  111. keyword = search_params.get('keyword', '')
  112. page_no = search_params.get('pageNo', 1)
  113. limit = search_params.get('limit', 10)
  114. if page_no < 1:
  115. page_no = 1
  116. if limit < 1:
  117. limit = 10
  118. embedding = Vectorizer.get_embedding(keyword)
  119. offset = (page_no - 1) * limit
  120. db = SessionLocal()
  121. try:
  122. # 获取总条数
  123. total_count = db.query(func.count(Trunks.id)).scalar()
  124. # 执行向量搜索
  125. results = db.query(
  126. Trunks.id,
  127. Trunks.file_path,
  128. Trunks.content,
  129. Trunks.embedding.l2_distance(embedding).label('distance')
  130. ).order_by('distance').offset(offset).limit(limit).all()
  131. return {
  132. 'data': [{
  133. 'id': r.id,
  134. 'file_path': r.file_path,
  135. 'content': r.content,
  136. 'distance': r.distance
  137. } for r in results],
  138. 'pagination': {
  139. 'total': total_count,
  140. 'pageNo': page_no,
  141. 'limit': limit,
  142. 'totalPages': (total_count + limit - 1) // limit
  143. }
  144. }
  145. finally:
  146. db.close()