|
@@ -22,6 +22,9 @@ class TrunksService:
|
|
|
trunk_data['embedding'] = Vectorizer.get_embedding(content)
|
|
|
if 'type' not in trunk_data:
|
|
|
trunk_data['type'] = 'default'
|
|
|
+ if 'title' not in trunk_data:
|
|
|
+ from pathlib import Path
|
|
|
+ trunk_data['title'] = Path(trunk_data['file_path']).stem
|
|
|
print("embedding length:", len(trunk_data['embedding']))
|
|
|
logger.debug(f"生成的embedding长度: {len(trunk_data['embedding'])}, 内容摘要: {content[:20]}")
|
|
|
# trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
|
|
@@ -51,13 +54,20 @@ class TrunksService:
|
|
|
'file_path': trunk.file_path,
|
|
|
'content': trunk.content,
|
|
|
'embedding': trunk.embedding.tolist(),
|
|
|
- 'type': trunk.type
|
|
|
+ 'type': trunk.type,
|
|
|
+ 'title':trunk.title
|
|
|
}
|
|
|
return None
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
|
- def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None) -> List[dict]:
|
|
|
+ def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
|
|
|
+ if conversation_id:
|
|
|
+ from cache import get_cache, set_cache
|
|
|
+ cached_result = get_cache(conversation_id)
|
|
|
+ if cached_result:
|
|
|
+ return cached_result
|
|
|
+
|
|
|
embedding = Vectorizer.get_embedding(text)
|
|
|
db = SessionLocal()
|
|
|
try:
|
|
@@ -65,19 +75,26 @@ class TrunksService:
|
|
|
Trunks.id,
|
|
|
Trunks.file_path,
|
|
|
Trunks.content,
|
|
|
- Trunks.embedding.l2_distance(embedding).label('distance')
|
|
|
+ Trunks.embedding.l2_distance(embedding).label('distance'),
|
|
|
+ Trunks.title
|
|
|
)
|
|
|
if metadata_condition:
|
|
|
query = query.filter_by(**metadata_condition)
|
|
|
if type:
|
|
|
query = query.filter(Trunks.type == type)
|
|
|
results = query.order_by('distance').limit(limit).all()
|
|
|
- return [{
|
|
|
+ result_list = [{
|
|
|
'id': r.id,
|
|
|
'file_path': r.file_path,
|
|
|
'content': r.content,
|
|
|
- 'distance': r.distance
|
|
|
+ 'distance': r.distance,
|
|
|
+ 'title': r.title
|
|
|
} for r in results]
|
|
|
+
|
|
|
+ if conversation_id:
|
|
|
+ set_cache(conversation_id, result_list)
|
|
|
+
|
|
|
+ return result_list
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
@@ -159,7 +176,8 @@ class TrunksService:
|
|
|
Trunks.id,
|
|
|
Trunks.file_path,
|
|
|
Trunks.content,
|
|
|
- Trunks.embedding.l2_distance(embedding).label('distance')
|
|
|
+ Trunks.embedding.l2_distance(embedding).label('distance'),
|
|
|
+ Trunks.title
|
|
|
).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
|
|
|
|
|
|
return {
|
|
@@ -167,7 +185,8 @@ class TrunksService:
|
|
|
'id': r.id,
|
|
|
'file_path': r.file_path,
|
|
|
'content': r.content,
|
|
|
- 'distance': r.distance
|
|
|
+ 'distance': r.distance,
|
|
|
+ 'title': r.title
|
|
|
} for r in results],
|
|
|
'pagination': {
|
|
|
'total': total_count,
|