Browse Source

代码提交

SGTY 2 tháng trước cách đây
mục cha
commit
932010a704

+ 8 - 4
community/dump_graph_data.py

@@ -26,7 +26,8 @@ def get_props(ref_id):
     return props
 
 def get_entities():
-    COUNT_SQL = "select count(*) from kg_nodes where version=:version"
+    #COUNT_SQL = "select count(*) from kg_nodes where version=:version"
+    COUNT_SQL = "select count(*) from kg_nodes"
     result = db.execute(text(COUNT_SQL), {'version': version})
     count = result.scalar()
 
@@ -35,7 +36,8 @@ def get_entities():
     batch = 100
     start = 1
     while start < count:    
-        sql = """select id,name,category from kg_nodes where version=:version order by id limit :batch OFFSET :start"""
+        #sql = """select id,name,category from kg_nodes where version=:version order by id limit :batch OFFSET :start"""
+        sql = """select id,name,category from kg_nodes order by id limit :batch OFFSET :start"""
         result = db.execute(text(sql), {'start':start, 'batch':batch, 'version': version})
         #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
         row_count = 0
@@ -62,7 +64,8 @@ def get_names(src_id, dest_id):
     return (src_name, dest_name)
 
 def get_relationships():
-    COUNT_SQL = "select count(*) from kg_edges where version=:version"
+    #COUNT_SQL = "select count(*) from kg_edges where version=:version"
+    COUNT_SQL = "select count(*) from kg_edges"
     result = db.execute(text(COUNT_SQL), {'version': version})
     count = result.scalar()
 
@@ -72,7 +75,8 @@ def get_relationships():
     start = 1
     file_index = 1
     while start < count:    
-        sql = """select id,name,category,src_id,dest_id from kg_edges where version=:version order by id limit :batch OFFSET :start"""
+        #sql = """select id,name,category,src_id,dest_id from kg_edges where version=:version order by id limit :batch OFFSET :start"""
+        sql = """select id,name,category,src_id,dest_id from kg_edges order by id limit :batch OFFSET :start"""
         result = db.execute(text(sql), {'start':start, 'batch':batch, 'version': version})
         #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
         row_count = 0

+ 2 - 0
model/trunks_model.py

@@ -13,6 +13,8 @@ class Trunks(Base):
     content_tsvector = Column(TSVECTOR)
     type = Column(String(255))
     title = Column(String(255))
+    referrence = Column(String(255))
+    meta_header = Column(String(255))
 
     def __repr__(self):
         return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 67 - 21
router/text_search.py

@@ -6,6 +6,7 @@ from utils.text_splitter import TextSplitter
 from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
+DISTANCE_THRESHOLD = 0.65
 import logging
 
 logger = logging.getLogger(__name__)
@@ -16,6 +17,10 @@ class TextSearchRequest(BaseModel):
     conversation_id: Optional[str] = None
     need_convert: Optional[bool] = False
 
+class TextCompareRequest(BaseModel):
+    sentence: str
+    text: str
+
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
     try:
@@ -34,6 +39,8 @@ async def search_text(request: TextSearchRequest):
         cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
         
         for sentence in sentences:
+            # if request.need_convert:
+            sentence = sentence.replace("\n", "<br>")
             if len(sentence) < 10:
                 result_sentences.append(sentence)
                 continue
@@ -51,7 +58,7 @@ async def search_text(request: TextSearchRequest):
                         best_result = {**cached_result, 'distance': distance}
                         
                 
-                if best_result and best_result['distance'] < 0.75:
+                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                     search_results = [best_result]
                 else:
                     search_results = []
@@ -64,32 +71,43 @@ async def search_text(request: TextSearchRequest):
             
             # 处理搜索结果
             for result in search_results:
-                # 获取distance值,如果大于等于1则跳过
-                distance = result.get("distance", 1.0)
-                if distance >= 1:
+                distance = result.get("distance", DISTANCE_THRESHOLD)
+                if distance >= DISTANCE_THRESHOLD:
+                    result_sentences.append(sentence)
                     continue
                 
+                # 检查是否已存在相同引用
+                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
+                current_index = reference_index
+                if existing_ref:
+                    current_index = int(existing_ref["index"])
+                else:
+                    # 添加到引用列表
+                    reference = {
+                        "index": str(reference_index),
+                        "id": result["id"],
+                        "content": result["content"],
+                        "file_path": result.get("file_path", ""),
+                        "title": result.get("title", ""),
+                        "distance": distance,
+                        "referrence": result.get("referrence", "")
+                    }
+                    all_references.append(reference)
+                    reference_index += 1
+                
                 # 添加引用标记
-                result_sentence = sentence + f"^[{reference_index}]^"
-                result_sentences.append(result_sentence)
+                if sentence.endswith('<br>'):
+                    # 如果有多个<br>,在所有<br>前添加^[current_index]^
+                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
+                else:
+                    # 直接在句子末尾添加^[current_index]^
+                    result_sentence = f'{sentence}^[{current_index}]^'
                 
-                # 添加到引用列表
-                reference = {
-                    "index": str(reference_index),
-                    "content": result["content"],
-                    "file_path": result.get("file_path", ""),
-                    "title": result.get("title", ""),
-                    "distance": distance
-                }
-                all_references.append(reference)
-                reference_index += 1
-
-        answer = "\n".join(result_sentences)
-        if request.need_convert:
-            answer = answer.replace("\n", "</br>")
+                result_sentences.append(result_sentence)
+     
         # 组装返回数据
         response_data = {
-            "answer": answer,
+            "answer": result_sentences,
             "references": all_references
         }
         
@@ -99,4 +117,32 @@ async def search_text(request: TextSearchRequest):
         logger.error(f"Text search failed: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.post("/match", response_model=StandardResponse)
+async def match_text(request: TextCompareRequest):
+    try:
+        sentences = TextSplitter.split_text(request.text)
+        sentence_vector = Vectorizer.get_embedding(request.sentence)
+        min_distance = float('inf')
+        best_sentence = ""
+        result_sentences = []
+        for temp in sentences:
+            result_sentences.append(temp)
+            if len(temp) < 10:
+                continue
+            temp_vector = Vectorizer.get_embedding(temp)
+            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
+            if distance < min_distance and distance < DISTANCE_THRESHOLD:
+                min_distance = distance
+                best_sentence = temp
+
+        for i in range(len(result_sentences)):
+            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
+            if result_sentences[i]["sentence"] == best_sentence:
+                result_sentences[i]["matched"] = True    
+                
+        return StandardResponse(success=True, records=result_sentences)
+    except Exception as e:
+        logger.error(f"Text comparison failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 text_search_router = router

+ 9 - 5
service/kg_node_service.py

@@ -10,7 +10,7 @@ from service.kg_prop_service import KGPropService
 from service.kg_edge_service import KGEdgeService
 
 logger = logging.getLogger(__name__)
-
+DISTANCE_THRESHOLD = 0.65
 class KGNodeService:
     def __init__(self, db: Session):
         self.db = db
@@ -32,7 +32,7 @@ class KGNodeService:
         offset = (page_no - 1) * limit
 
         try:
-            total_count = self.db.query(func.count(KGNode.id)).filter(KGNode.version.in_(search_params['knowledge_ids'])).scalar() if search_params.get('knowledge_ids') else self.db.query(func.count(KGNode.id)).scalar()
+            total_count = self.db.query(func.count(KGNode.id)).filter(KGNode.version.in_(search_params['knowledge_ids']), KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar() if search_params.get('knowledge_ids') else self.db.query(func.count(KGNode.id)).filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar()
 
             query = self.db.query(
                 KGNode.id,
@@ -42,6 +42,7 @@ class KGNodeService:
             )
             if search_params.get('knowledge_ids'):
                 query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
+            query = query.filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD)
             results = query.order_by('distance').offset(offset).limit(limit).all()
 
             return {
@@ -51,7 +52,7 @@ class KGNodeService:
                     'category': r.category,
                     'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                     #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
-                    'distance': r.distance
+                    'distance': round(r.distance, 3)
                 } for r in results],
                 'pagination': {
                     'total': total_count,
@@ -132,18 +133,21 @@ class KGNodeService:
         while True:
             try:
                 nodes = self.db.query(KGNode).filter(
-                    KGNode.version == 'ER',
+                    #KGNode.version == 'ER',
                     KGNode.embedding == None
                 ).offset(offset).limit(batch_size).all()
 
                 if not nodes:
                     break
 
+                updated_nodes = []
                 for node in nodes:
                     if not node.embedding:
                         embedding = Vectorizer.get_embedding(node.name)
                         node.embedding = embedding
-                        self.db.commit()
+                        updated_nodes.append(node)
+                if updated_nodes:
+                    self.db.commit()
 
                 offset += batch_size
             except Exception as e:

+ 8 - 4
service/trunks_service.py

@@ -72,7 +72,8 @@ class TrunksService:
                 Trunks.content,
                 Trunks.embedding.l2_distance(embedding).label('distance'),
                 Trunks.title,
-                Trunks.embedding
+                Trunks.embedding,
+                Trunks.referrence
             )
             if metadata_condition:
                 query = query.filter_by(**metadata_condition)
@@ -83,9 +84,11 @@ class TrunksService:
                 'id': r.id,
                 'file_path': r.file_path,
                 'content': r.content,
-                'distance': r.distance,
+                #保留小数点后三位   
+                'distance': round(r.distance, 3),
                 'title': r.title,
-                'embedding': r.embedding.tolist()
+                'embedding': r.embedding.tolist(),
+                'referrence': r.referrence
             } for r in results]
 
             if conversation_id:
@@ -210,7 +213,8 @@ class TrunksService:
                     'id': r.id,
                     'file_path': r.file_path,
                     'content': r.content,
-                    'distance': r.distance,
+                    #保留小数点后三位   
+                    'distance': round(r.distance, 3),
                     'title': r.title
                 } for r in results],
                 'pagination': {

+ 3 - 2
tests/service/test_kg_node_service.py

@@ -1,11 +1,12 @@
 import pytest
 from service.kg_node_service import KGNodeService
-from model.trunks_model import KGNode
+from model.kg_node import KGNode
 from sqlalchemy.exc import IntegrityError
 
 @pytest.fixture(scope="module")
 def kg_node_service():
-    return KGNodeService()
+    from db.session import get_db
+    return KGNodeService(next(get_db()))
 
 @pytest.fixture
 def test_node_data():

+ 4 - 2
utils/file_reader.py

@@ -12,8 +12,10 @@ class FileReader:
                     file_path = os.path.join(root, file)
                     relative_path = '\\report\\' + os.path.relpath(file_path, directory)
                     with open(file_path, 'r', encoding='utf-8') as f:
-                        content = f.read()
-                    TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk'})
+                        lines = f.readlines()
+                    meta_header = lines[0]
+                    content = ''.join(lines[1:])
+                    TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk','meta_header':meta_header})
 
 if __name__ == '__main__':
     directory = 'e:\\project\\knowledge\\utils\\files'

+ 6 - 3
utils/text_splitter.py

@@ -88,9 +88,12 @@ class TextSplitter:
                     if current_sentence.strip():
                         # 保留句子末尾的换行符
                         if char == '\n':
-                            current_sentence = current_sentence.rstrip('\n') + '\n'
-                        sentences.append(current_sentence)
-                    current_sentence = ""
+                            current_sentence = current_sentence.rstrip('\n')
+                            sentences.append(current_sentence)
+                            current_sentence = '\n'
+                        else:
+                            sentences.append(current_sentence)
+                            current_sentence = ""
                     
                     # 处理空格 - 保留空格在下一个句子的开头
                     if i + 1 < len(text) and text[i + 1].isspace() and text[i + 1] != '\n':