|
@@ -20,6 +20,8 @@ class TrunksService:
|
|
|
if 'embedding' in trunk_data and len(trunk_data['embedding']) != 1024:
|
|
|
raise ValueError("向量维度必须为1024")
|
|
|
trunk_data['embedding'] = Vectorizer.get_embedding(content)
|
|
|
+ if 'type' not in trunk_data:
|
|
|
+ trunk_data['type'] = 'default'
|
|
|
print("embedding length:", len(trunk_data['embedding']))
|
|
|
logger.debug(f"生成的embedding长度: {len(trunk_data['embedding'])}, 内容摘要: {content[:20]}")
|
|
|
# trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
|
|
@@ -39,14 +41,23 @@ class TrunksService:
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
|
- def get_trunk_by_id(self, trunk_id: int) -> Optional[Trunks]:
|
|
|
+ def get_trunk_by_id(self, trunk_id: int) -> Optional[dict]:
|
|
|
db = SessionLocal()
|
|
|
try:
|
|
|
- return db.query(Trunks).filter(Trunks.id == trunk_id).first()
|
|
|
+ trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
|
|
|
+ if trunk:
|
|
|
+ return {
|
|
|
+ 'id': trunk.id,
|
|
|
+ 'file_path': trunk.file_path,
|
|
|
+ 'content': trunk.content,
|
|
|
+ 'embedding': trunk.embedding.tolist(),
|
|
|
+ 'type': trunk.type
|
|
|
+ }
|
|
|
+ return None
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
|
- def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None) -> List[dict]:
|
|
|
+ def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None) -> List[dict]:
|
|
|
embedding = Vectorizer.get_embedding(text)
|
|
|
db = SessionLocal()
|
|
|
try:
|
|
@@ -58,6 +69,8 @@ class TrunksService:
|
|
|
)
|
|
|
if metadata_condition:
|
|
|
query = query.filter_by(**metadata_condition)
|
|
|
+ if type:
|
|
|
+ query = query.filter(Trunks.type == type)
|
|
|
results = query.order_by('distance').limit(limit).all()
|
|
|
return [{
|
|
|
'id': r.id,
|
|
@@ -81,6 +94,8 @@ class TrunksService:
|
|
|
if 'content' in update_data:
|
|
|
content = update_data['content']
|
|
|
update_data['embedding'] = Vectorizer.get_embedding(content)
|
|
|
+ if 'type' not in update_data:
|
|
|
+ update_data['type'] = 'default'
|
|
|
logger.debug(f"更新生成的embedding长度: {len(update_data['embedding'])}, 内容摘要: {content[:20]}")
|
|
|
# update_data['content_tsvector'] = func.to_tsvector('chinese', content)
|
|
|
|
|
@@ -137,7 +152,7 @@ class TrunksService:
|
|
|
db = SessionLocal()
|
|
|
try:
|
|
|
# 获取总条数
|
|
|
- total_count = db.query(func.count(Trunks.id)).scalar()
|
|
|
+ total_count = db.query(func.count(Trunks.id)).filter(Trunks.type == search_params.get('type')).scalar()
|
|
|
|
|
|
# 执行向量搜索
|
|
|
results = db.query(
|
|
@@ -145,7 +160,7 @@ class TrunksService:
|
|
|
Trunks.file_path,
|
|
|
Trunks.content,
|
|
|
Trunks.embedding.l2_distance(embedding).label('distance')
|
|
|
- ).order_by('distance').offset(offset).limit(limit).all()
|
|
|
+ ).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
|
|
|
|
|
|
return {
|
|
|
'data': [{
|