yuchengwei 2 месяцев назад
Родитель
Коммит
aff8d92e77
2 измененных файлов с 129 добавлено и 2 удалено
  1. 102 1
      router/text_search.py
  2. 27 1
      service/kg_node_service.py

+ 102 - 1
router/text_search.py

@@ -16,6 +16,7 @@ from db.session import get_db
 from sqlalchemy.orm import Session
 from service.kg_node_service import KGNodeService
 from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
 
 from cachetools import TTLCache
 
@@ -69,6 +70,7 @@ class TextCompareMultiRequest(BaseModel):
 class NodePropsSearchRequest(BaseModel):  # Request body accepted by the node_props_search endpoint.
     node_id: int  # Id of the node whose info node_props_search loads through _get_node_info.
     props_ids: List[int]  # Property ids; node_props_search fetches each one with prop_service.get_props_by_id.
+    symptoms: Optional[List[str]] = None  # Optional symptom names; when given, they and their synonyms are wrapped in a red <i> tag inside '临床表现' (clinical manifestations) text.
 
 @router.post("/kgrt_api/text/clear_cache", response_model=StandardResponse)
 async def clear_cache():
@@ -508,11 +510,35 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
         trunks_service = TrunksService()
         node_service = KGNodeService(db)
         prop_service = KGPropService(db)
+        edge_service = KGEdgeService(db)
 
         # 获取节点信息
         result = _get_node_info(node_service, request.node_id)
         node_name = result["name"]
 
+        # 处理症状列表
+        symptom_list = []
+        if request.symptoms:
+            for symptom in request.symptoms:
+                try:
+                    # 添加原始症状
+                    symptom_list.append(symptom)
+                    # 获取症状节点
+                    symptom_node = node_service.get_node_by_name_category(symptom, '症状')
+                    # 获取症状相关同义词
+                    edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category='症状同义词')
+                    if edges:
+                        # 添加同义词
+                        for edge in edges:
+                            if edge['dest_node'] and edge['dest_node'].get('name'):
+                                symptom_list.append(edge['dest_node']['name'])
+                except ValueError:
+                    # 如果找不到节点,只添加原始症状
+                    symptom_list.append(symptom)
+            
+            # 按照字符长度进行倒序排序
+            symptom_list.sort(key=len, reverse=True)
+
         # 遍历props_ids查询属性信息
         for prop_id in request.props_ids:
             prop = prop_service.get_props_by_id(prop_id)
@@ -593,8 +619,44 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                     
                     reference, _ = _process_search_result(search_result, 1)
                     prop_result["references"] = [reference]
+                    # 处理症状标记
+                    marked_sentence = prop_value
+                    if prop_title == '临床表现' and symptom_list:
+                        # 创建一个标记位置的列表,记录每个位置是否已被标记
+                        marked_positions = [False] * len(marked_sentence)
+                        
+                        # 创建一个列表来存储已处理的症状
+                        processed_symptoms = []
+                        
+                        for symptom in symptom_list:
+                            # 检查是否已处理过该症状或其子集
+                            if any(symptom in processed_sym or processed_sym in symptom for processed_sym in processed_symptoms):
+                                continue
+                                
+                            # 查找所有匹配位置
+                            start_pos = 0
+                            while True:
+                                pos = marked_sentence.find(symptom, start_pos)
+                                if pos == -1:
+                                    break
+                                    
+                                # 检查这个位置是否已被标记
+                                if not any(marked_positions[pos:pos + len(symptom)]):
+                                    # 标记这个范围的所有位置
+                                    for i in range(pos, pos + len(symptom)):
+                                        marked_positions[i] = True
+                                    # 替换文本
+                                    marked_sentence = marked_sentence[:pos] + f'<i style="color:red;">{symptom}</i>' + marked_sentence[pos + len(symptom):]
+                                    # 将成功标记的症状添加到已处理列表中
+                                    if symptom not in processed_symptoms:
+                                        processed_symptoms.append(symptom)
+                                    # 更新标记位置数组以适应新插入的标签
+                                    new_positions = [False] * (len('<i style="color:red;">') + len('</i>'))
+                                    marked_positions = marked_positions[:pos] + new_positions + marked_positions[pos:]
+                                
+                                start_pos = pos + len('<i style="color:red;">') + len(symptom) + len('</i>')
                     prop_result["answer"] = [{
-                        "sentence": prop_value,
+                        "sentence": marked_sentence,
                         "flag": "1"
                     }]
                 else:
@@ -606,6 +668,45 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                 result_sentences, references = _process_sentence_search(
                     node_name, prop_title, sentences, trunks_service
                 )
+                # 处理症状标记
+                if prop_title == '临床表现' and symptom_list and result_sentences:
+                    for sentence in result_sentences:
+                        marked_sentence = sentence["sentence"]
+                        # 创建一个标记位置的列表,记录每个位置是否已被标记
+                        marked_positions = [False] * len(marked_sentence)
+
+                        # 创建一个列表来存储已处理的症状
+                        processed_symptoms = []
+                        
+                        for symptom in symptom_list:
+                            # 检查是否已处理过该症状或其子集
+                            if any(symptom in processed_sym or processed_sym in symptom for processed_sym in processed_symptoms):
+                                continue
+                                
+                            # 查找所有匹配位置
+                            start_pos = 0
+                            while True:
+                                pos = marked_sentence.find(symptom, start_pos)
+                                if pos == -1:
+                                    break
+
+                                # 检查这个位置是否已被标记
+                                if not any(marked_positions[pos:pos + len(symptom)]):
+                                    # 标记这个范围的所有位置
+                                    for i in range(pos, pos + len(symptom)):
+                                        marked_positions[i] = True
+                                    # 替换文本
+                                    marked_sentence = marked_sentence[:pos] + f'<i style="color:red;">{symptom}</i>' + marked_sentence[pos + len(symptom):]
+                                    # 将成功标记的症状添加到已处理列表中
+                                    if symptom not in processed_symptoms:
+                                        processed_symptoms.append(symptom)
+                                    # 更新标记位置数组以适应新插入的标签
+                                    new_positions = [False] * (len('<i style="color:red;">') + len('</i>'))
+                                    marked_positions = marked_positions[:pos] + new_positions + marked_positions[pos:]
+
+                                start_pos = pos + len('<i style="color:red;">') + len(symptom) + len('</i>')
+
+                        sentence["sentence"] = marked_sentence
                 if references:
                     prop_result["references"] = references
                 if result_sentences:

+ 27 - 1
service/kg_node_service.py

@@ -229,4 +229,30 @@ class KGNodeService:
             except Exception as e:
                 self.db.rollback()
                 print(f"批量处理ER节点失败: {str(e)}")
-                raise ValueError("Batch process failed")
+                raise ValueError("Batch process failed")
+
+    def get_node_by_name_category(self, name: str, category: str):  # Look up one KGNode by exact name and category; returns a plain dict, raises ValueError otherwise.
+        if not name or not category:  # Reject empty/None arguments before touching the cache or the database.
+            raise ValueError("Name and category are required")
+        
+        cache_key = f"get_node_by_name_category_{name}:{category}"  # Key is built from both arguments, so each (name, category) pair is cached separately.
+        if cache_key in self._cache:  # self._cache is created outside this view; presumably a dict-like per-service cache — confirm.
+            return self._cache[cache_key]
+        
+        node = self.db.query(KGNode).filter(  # Exact-match filters; .first() keeps only the first matching row.
+            KGNode.name == name,
+            KGNode.category == category,
+            KGNode.status == 0  # NOTE(review): status 0 presumably marks an active/valid node — confirm against the model.
+        ).first()
+        
+        if not node:  # A miss raises instead of returning None; it is not cached, so the next call queries again.
+            raise ValueError("Node not found")  # node_props_search catches this ValueError and keeps only the original symptom.
+        
+        node_data = {  # Copy the needed columns into a dict rather than returning the ORM object.
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+        self._cache[cache_key] = node_data  # Only successful lookups are stored.
+        return node_data