
Code commit

SGTY · 2 months ago
commit 9e4d9d082c

+ 20 - 5
agent/cdss/libs/cdss_helper2.py

@@ -286,7 +286,8 @@ class CDSSHelper(GraphHelper):
         symptom_same_edge = ['症状同义词', '症状同义词2.0']
         department_edge = ['belongs_to','所属科室']
         allowed_links = symptom_edge+department_edge+symptom_same_edge
-        #allowed_links = symptom_edge + department_edge
+        # allowed_links = symptom_edge + department_edge
+
         # 将输入的症状名称转换为节点ID
         # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
         node_ids = []
@@ -313,11 +314,25 @@ class CDSSHelper(GraphHelper):
                 logger.debug(f"node {node} not found")
         node_ids = node_ids_filtered
 
+        # out_edges = self.graph.out_edges(disease, data=True)
+        # for edge in out_edges:
+        #     src, dest, edge_data = edge
+        #     if edge_data["type"] not in department_edge:
+        #         continue
+        #     dest_data = self.entity_data[self.entity_data.index == dest]
+        #     if dest_data.empty:
+        #         continue
+        #     department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+        #     department_data.extend([department_name] * results[disease]["count"])
+
+
         results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
 
         #self.validDisease(results, start_nodes)
         results = self.validDisease(results, start_nodes)
 
+        sorted_score_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+
         # 调用step2方法处理科室、检查和药品信息
         results = self.step2(results,department_edge)
 
@@ -479,7 +494,7 @@ class CDSSHelper(GraphHelper):
                 # todo 目前是取入边,出边是不是也有用?
                 for edge in self.graph.in_edges(temp_node, data=True):
                     src, dest, edge_data = edge
-                    if src not in visited and depth + 1 < max_hops:
+                    if src not in visited and depth + 1 < max_hops and edge_data['type'] in allowed_links:
                         # print(f"put into queue travel from {src} to {dest}")
                         weight = edge_data['weight']
                         try :
@@ -495,6 +510,7 @@ class CDSSHelper(GraphHelper):
                                 weight = 10
                         except Exception as e:
                             print(f'Error processing file {weight}: {str(e)}')
+
                         queue.append((src, depth + 1, path, int(weight), data))
                     # else:
                     # print(f"skip travel from {src} to {dest}")
@@ -596,9 +612,8 @@ class CDSSHelper(GraphHelper):
             else:
                 edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
                 #edges有可能为空,这里需要做一下处理
-                if len(edges) == 0:
-                    continue
-                departments = [edge['dest_node']['name'] for edge in edges]
+                if len(edges) > 0:
+                    departments = [edge['dest_node']['name'] for edge in edges]
+                else:
+                    departments = []  # reset so an empty result does not reuse the previous disease's departments
             # 处理查询结果
             for department in departments:
                 total += 1
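
The traversal change above whitelists edge types during step1's expansion. A minimal sketch of that restricted breadth-first walk, assuming the graph is a networkx DiGraph whose edges carry 'type' and 'weight' attributes (as the in_edges/out_edges calls suggest); this is a sketch, not the class's actual method:

from collections import deque
import networkx as nx

def bfs_allowed_links(graph: nx.DiGraph, start_ids, allowed_links, max_hops=2):
    """Expand over in-edges, but only follow whitelisted edge types."""
    visited = set(start_ids)
    queue = deque((node_id, 0) for node_id in start_ids)
    reached = []
    while queue:
        node, depth = queue.popleft()
        reached.append(node)
        for src, _dest, edge_data in graph.in_edges(node, data=True):
            # mirrors the new check: respect the hop limit and skip visited
            # nodes and edge types outside the whitelist
            if src in visited or depth + 1 >= max_hops or edge_data.get('type') not in allowed_links:
                continue
            visited.add(src)
            queue.append((src, depth + 1))
    return reached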

+ 16 - 0
router/knowledge_saas.py

@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
 
 class PaginatedSearchRequest(BaseModel):
     keyword: Optional[str] = None
+    category: Optional[str] = None
     pageNo: int = 1
     limit: int = 10
     knowledge_ids: Optional[List[str]] = None
@@ -58,6 +59,7 @@ async def paginated_search(
         service = KGNodeService(db)
         search_params = {
             'keyword': payload.keyword,
+            'category': payload.category,
             'pageNo': payload.pageNo,
             'limit': payload.limit,
             'knowledge_ids': payload.knowledge_ids,
@@ -187,4 +189,18 @@ async def get_trunk(
         logger.error(f"获取trunk详情失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
+@router.post('/trunks/{trunk_id}/highlight', response_model=StandardResponse)
+async def highlight(
+    trunk_id: int,
+    targetSentences: List[str],
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        result = service.highlight(trunk_id, targetSentences)
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"获取trunk高亮信息失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
 saas_kb_router = router
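
The new highlight route reads the target sentences as the raw JSON body, since targetSentences is the only body parameter. A hypothetical call, assuming the router is mounted at the application root on localhost:8000; the host, port and trunk id 123 are illustrative values only:

import requests

resp = requests.post(
    "http://localhost:8000/trunks/123/highlight",
    json=["患者出现胸痛", "伴冷汗"],  # sentences to look for in the trunk content
)
print(resp.json())  # StandardResponse wrapper: {"success": true, "data": [...matched sentence indices...]}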

+ 7 - 6
router/text_search.py

@@ -2,7 +2,8 @@ from fastapi import APIRouter, HTTPException, Depends
 from pydantic import BaseModel, Field, validator
 from typing import List, Optional
 from service.trunks_service import TrunksService
-from utils.text_splitter import TextSplitter
+from utils.sentence_util import SentenceUtil
+
 from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
@@ -86,7 +87,7 @@ async def search_text(request: TextSearchRequest):
             request.text = converter.convert(request.text)
 
         # 使用TextSplitter拆分文本
-        sentences = TextSplitter.split_text(request.text)
+        sentences = SentenceUtil.split_text(request.text)
         if not sentences:
             return StandardResponse(success=True, data={"answer": "", "references": []})
         
@@ -193,7 +194,7 @@ async def search_text(request: TextSearchRequest):
 @router.post("/match", response_model=StandardResponse)
 async def match_text(request: TextCompareRequest):
     try:
-        sentences = TextSplitter.split_text(request.text)
+        sentences = SentenceUtil.split_text(request.text)
         sentence_vector = Vectorizer.get_embedding(request.sentence)
         min_distance = float('inf')
         best_sentence = ""
@@ -263,8 +264,8 @@ async def compare_text(request: TextCompareMultiRequest):
     start_time = time.time()
     try:
         # 拆分两段文本
-        origin_sentences = TextSplitter.split_text(request.origin)
-        similar_sentences = TextSplitter.split_text(request.similar)
+        origin_sentences = SentenceUtil.split_text(request.origin)
+        similar_sentences = SentenceUtil.split_text(request.similar)
         end_time = time.time()
         logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")
         
@@ -501,7 +502,7 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                 }]
             else:
                 # 如果整体搜索没有找到匹配结果,则进行句子拆分搜索
-                sentences = TextSplitter.split_text(prop_value)
+                sentences = SentenceUtil.split_text(prop_value)
                 result_sentences, references = _process_sentence_search(
                     node_name, prop_title, sentences, trunks_service
                 )
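
The TextSplitter to SentenceUtil switch is a drop-in rename: split_text keeps the same static signature, so only the import line changes. A small usage sketch; the sample sentence and its split are illustrative:

from utils.sentence_util import SentenceUtil  # previously: from utils.text_splitter import TextSplitter

sentences = SentenceUtil.split_text("患者主诉胸痛3小时。伴出汗、恶心。")
# e.g. ["患者主诉胸痛3小时。", "伴出汗、恶心。"]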

+ 1 - 1
service/kg_node_service.py

@@ -59,7 +59,7 @@ class KGNodeService:
         prop_service = KGPropService(self.db)
         edge_service = KGEdgeService(self.db)
         keyword = search_params.get('keyword', '')
-        category = search_params.get('category', '')
+        category = search_params.get('category', None)
         page_no = search_params.get('pageNo', 1)
         distance = search_params.get('distance',DISTANCE_THRESHOLD)
         limit = search_params.get('limit', 10)
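
Switching the category default from '' to None mainly documents intent: both values are falsy, so a truthiness check downstream behaves the same, but None reads as "no category filter" and matches the Optional[str] field added to PaginatedSearchRequest. A tiny illustration:

search_params = {'keyword': '心肌梗死'}       # a request that omits category
print(search_params.get('category', ''))     # '' with the old default
print(search_params.get('category', None))   # None with the new default, i.e. "do not filter by category"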

+ 38 - 10
service/trunks_service.py

@@ -5,6 +5,8 @@ from typing import List, Optional
 from model.trunks_model import Trunks
 from db.session import SessionLocal
 import logging
+
+from utils.sentence_util import SentenceUtil
 from utils.vectorizer import Vectorizer
 
 logger = logging.getLogger(__name__)
@@ -61,7 +63,7 @@ class TrunksService:
         finally:
             db.close()
 
-    def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
+    def search_by_vector(self, text: str, limit: int = 10, file_path: Optional[str] = None, metadata_condition: Optional[dict] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
        
         embedding = Vectorizer.get_embedding(text)
         db = SessionLocal()
@@ -80,6 +82,8 @@ class TrunksService:
                 query = query.filter_by(**metadata_condition)
             if type:
                 query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%' + file_path + '%'))
             results = query.order_by('distance').limit(limit).all()
             result_list = [{
                 'id': r.id,
@@ -150,6 +154,28 @@ class TrunksService:
         finally:
             db.close()
 
+    def highlight(self, trunk_id: int, targetSentences: List[str]) -> List[int]:
+        trunk = self.get_trunk_by_id(trunk_id)
+        if not trunk:
+            return []
+
+        content = trunk['content']
+        sentence_util = SentenceUtil()
+        contentSentences = sentence_util.split_text(content)
+
+        result = []
+        for targetSentence in targetSentences:
+            cleanedTarget = sentence_util.clean_text(targetSentence)
+            # record the index of the first content sentence that matches this target
+            for index, contentSentence in enumerate(contentSentences):
+                if sentence_util.clean_text(contentSentence) == cleanedTarget:
+                    result.append(index)
+                    break
+
+        return result
+
     _cache = {}
 
     def get_cache(self, conversation_id: str) -> List[dict]:
@@ -184,6 +210,8 @@ class TrunksService:
         """
         page_no = search_params.get('pageNo', 1)
         limit = search_params.get('limit', 10)
+        file_path = search_params.get('file_path', None)
+        type = search_params.get('type', None)
         
         if page_no < 1:
             page_no = 1
@@ -195,20 +223,20 @@ class TrunksService:
         db = SessionLocal()
         try:
 
-            # 执行查询
-            results = db.query(
+            query = db.query(
                 Trunks.id,
                 Trunks.file_path,
                 Trunks.content,
                 Trunks.type,
                 Trunks.title
-            )\
-            .filter(Trunks.type == 'trunk')\
-            .filter(Trunks.file_path.like('%内科学 第10版%'))\
-            .filter(Trunks.page_no == None)\
-            .offset(offset)\
-            .limit(limit)\
-            .all()
+            )
+            if type:
+                query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%' + file_path + '%'))
+
+            query = query.filter(Trunks.page_no == None)
+            results = query.offset(offset).limit(limit).all()
             
             return {
                 'data': [{
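
A usage sketch for the pieces added to TrunksService above: the parameterised pagination (file_path and type now come from search_params instead of being hard-coded) and the new highlight helper. The trunk id and the file_path substring are illustrative values only:

from service.trunks_service import TrunksService

service = TrunksService()

# page through the trunks of one book instead of the previously hard-coded '内科学 第10版'
page = service.paginated_search_by_type_and_filepath(
    {'pageNo': 1, 'limit': 100, 'type': 'trunk', 'file_path': '急诊医学(第2版'})
for row in page['data']:
    print(row['id'], row['file_path'])

# indices of the content sentences of trunk 123 that match the given sentences after clean_text
matched = service.highlight(123, ["患者出现胸痛", "伴冷汗"])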

+ 1 - 1
tests/pdf_to_txt_mupdf.py

@@ -179,7 +179,7 @@ def process_pdf_files(input_path, log_dir='logs'):
 
 
 if __name__ == "__main__":
-    directory = 'E:\急诊科资料\pdf'
+    directory = 'C:\\Users\\17664\\Desktop\\test'
 
     # 设置日志记录器
     setup_logging(directory)

+ 32 - 8
tests/service/test_trunks_service.py

@@ -1,3 +1,4 @@
+import regex
 from pathlib import Path
 
 import pytest
@@ -34,44 +35,67 @@ class TestTrunksServiceCRUD:
         assert trunks_service.get_trunk_by_id(trunk.id) is None
 
 class TestSearchOperations:
+    def test_vector_search2(self, trunks_service):
+        page = 1
+        limit = 100
+        file_path = '急诊医学(第2版'
+        while True:
+            results = trunks_service.paginated_search_by_type_and_filepath(
+                {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
+            if not results['data']:
+                break
+            for record in results['data']:
+                print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
+                if record['type'] != 'trunk' or file_path not in record['file_path']:
+                    print('出现异常数据')
+                    break
 
+                page_no = self.get_page_no(record['content'], trunks_service, file_path)
+                if page_no is None:
+                    print(f"{record['id']}找到page_no: {page_no}")
+                    continue
+                trunks_service.update_trunk(record['id'], {'page_no': page_no})
+            page += 1
 
     def test_vector_search(self, trunks_service):
         page = 1
         limit = 100
+        file_path='trunk2'
         while True:
-            results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': 'test_path.pdf'})
+            results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
             if not results['data']:
                 break
             for record in results['data']:
                 print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
-                if record['type'] != 'trunk' or '内科学 第10版' not in record['file_path']:
+                if record['type'] != 'trunk' or file_path not in record['file_path']:
                     print('出现异常数据')
                     break
 
-                page_no = self.get_page_no(record['content'],trunks_service)
+                page_no = self.get_page_no(record['content'],trunks_service,file_path)
                 if page_no is None:
                     print(f"{record['id']}找到page_no: {page_no}")
                     continue
                 trunks_service.update_trunk(record['id'], {'page_no': page_no})
             page += 1
 
-    def get_page_no(self, text: str, trunks_service) -> int:
-        results = trunks_service.search_by_vector(text,1000,type='page',conversation_id="1111111aaaa")
+    def get_page_no(self, text: str, trunks_service,file_path:str) -> int:
+        results = trunks_service.search_by_vector(text,1000,type='page',file_path=file_path,conversation_id="1111111aaaa")
         sentences = self.split_text(text)
         count = 0
         for r in results:
             #将r["content"]的所有空白字符去掉
-            content = re.sub(r'[^\w\d\p{P}\p{L}]', '', r["content"])
+            content = regex.sub(r'[^\w\d\p{L}]', '', r["content"])
             count+=1
             match_count = 0
             length = len(sentences)/2
             for sentence in sentences:
-                sentence = re.sub(r'[^\w\d\p{P}\p{L}]', '', sentence)
+                sentence = regex.sub(r'[^\w\d\p{L}]', '', sentence)
                 if sentence in content:
                     match_count += 1
-                    if match_count >= length:
+                    if match_count >= 2:
                         return r["page_no"]
+
+
            
     def split_text(self, text):
         """将文本分割成句子"""

+ 3 - 2
tests/test.py

@@ -6,9 +6,10 @@ capability = CDSSCapability()
 record = CDSSInput(
     pat_age=CDSSInt(type="month", value=24),
     pat_sex=CDSSText(type="sex", value="男"),
-    chief_complaint=["右下腹痛"],
+    #chief_complaint=["右下腹痛","恶心","呕吐"],
+    #chief_complaint=["胸痛","左肩放射痛","下颌放射痛","呼吸困难","冷汗","恶心"],
     #chief_complaint=["呕血", "黑便", "头晕", "心悸"],
-    #chief_complaint=["流鼻涕"],
+    chief_complaint=["大汗"],
 
     department=CDSSText(type='department',value="急诊医学科")
     )

+ 3 - 3
tests/trunck_files.py

@@ -4,7 +4,7 @@ import re
 def rename_doc_files(directory):
     for root, dirs, files in os.walk(directory):
         for file in files:
-            if file.endswith('.pdf'):
+            if file.endswith('.doc'):
                 old_path = os.path.join(root, file)
                 #文件名前面的数字及.去掉
                 new_name = re.sub(r'^\d+\.', '', file).strip()
@@ -37,6 +37,6 @@ def move_doc_to_ocr(directory):
                     print(f'Renamed: {old_path} -> {new_path}')
 
 if __name__ == '__main__':
-    directory = 'E:\\急诊科资料\\pdf'
+    directory = r'E:\急诊科资料\中华医学期刊数据库'  # raw string keeps the backslashes literal
     #rename_doc_files(directory)
-    rename_doc_files(directory)
+    move_doc_to_ocr(directory)

+ 5 - 2
utils/file_reader2.py

@@ -15,12 +15,15 @@ class FileReader2:
                         relative_path = os.path.relpath(file_path, directory)
                         #relative_path的\trunk\前面的部分去除掉
                         relative_path = relative_path.split('\\trunk\\')[1]
-                        relative_path='\\report\\trunk\\'+relative_path
+                        relative_path='\\report\\trunk2\\'+relative_path
                         with open(file_path, 'r', encoding='utf-8') as f:
                             lines = f.readlines()
                         meta_header = lines[0]
                         content = ''.join(lines[1:])
-                        filename = file.split('_split_')[0]
+                  
+                        filename = os.path.dirname(file_path).split('\\trunk\\')[0]
+                        #filename取到最后一个\后面的部分
+                        filename = filename.split('\\')[-1]
                         newfilename = urllib.parse.quote(filename)
                         TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk','meta_header':meta_header,'referrence':'http://173.18.12.205:8001/books/'+newfilename+'.pdf'})
                     except Exception as e:
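
The new filename extraction above derives the book name from the directory above '\trunk\' rather than from the part file's name. A worked example with a hypothetical Windows layout (book folder, then trunk, then a sub-folder holding the split files), which is the layout this split assumes; ntpath is used here only so the sketch also runs on non-Windows machines, while the reader itself uses os.path on Windows:

import ntpath

file_path = 'E:\\report\\急诊医学(第2版)\\trunk\\sub\\part_001.txt'
head = ntpath.dirname(file_path).split('\\trunk\\')[0]  # 'E:\\report\\急诊医学(第2版)'
filename = head.split('\\')[-1]                         # '急诊医学(第2版)'

Note that if a part file sits directly under ...\trunk\ with no sub-folder, dirname ends in '\trunk', the split leaves the string unchanged, and filename comes out as 'trunk'.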

+ 35 - 33
utils/text_splitter.py

@@ -6,7 +6,7 @@ import sys
 
 logger = logging.getLogger(__name__)
 
-class TextSplitter:
+class SentenceUtil:
     """中文文本句子拆分工具类
     
     用于将中文文本按照标点符号拆分成句子列表
@@ -28,7 +28,7 @@ class TextSplitter:
         Returns:
             拆分后的句子列表
         """
-        return TextSplitter()._split(text)
+        return SentenceUtil()._split(text)
     
     def _split(self, text: str) -> List[str]:
         """内部拆分方法
@@ -128,6 +128,20 @@ class TextSplitter:
             # 如果不是特定测试用例,返回原文本作为一个句子
             return [text]
     
+    @staticmethod
+    def clean_text(text: str) -> str:
+        """去除除中英文和数字以外的所有字符
+        
+        Args:
+            text: 输入的文本字符串
+            
+        Returns:
+            处理后的字符串
+        """
+        if not text:
+            return text
+        return re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', '', text)
+
     def split_by_regex(self, text: str) -> List[str]:
         """使用正则表达式拆分文本
         
@@ -162,36 +176,24 @@ class TextSplitter:
             logger.error(f"使用正则表达式拆分文本时发生错误: {str(e)}")
             return [text] if text else []
 
-def main():
-    parser = argparse.ArgumentParser(description='文本句子拆分工具')
-    group = parser.add_mutually_exclusive_group(required=True)
-    group.add_argument('-t', '--text', help='直接输入要拆分的文本')
-    group.add_argument('-f', '--file', help='输入文本文件的路径')
-    
-    args = parser.parse_args()
-    
-    try:
-        # 获取输入文本
-        if args.text:
-            input_text = args.text
-        else:
-            with open(args.file, 'r', encoding='utf-8') as f:
-                input_text = f.read()
-        
-        # 执行文本拆分
-        sentences = TextSplitter.split_text(input_text)
-        
-        # 输出结果
-        print('\n拆分结果:')
-        for i, sentence in enumerate(sentences, 1):
-            print(f'{i}. {sentence}')
-            
-    except FileNotFoundError:
-        print(f'错误:找不到文件 {args.file}')
-        sys.exit(1)
-    except Exception as e:
-        print(f'错误:{str(e)}')
-        sys.exit(1)
 
 if __name__ == '__main__':
-    main()
+    # Test cases for clean_text method
+    test_cases = [
+        ('Hello! 你好?', 'Hello你好'),
+        ('123abc!@#', '123abc'),
+        ('测试-中文+标点', '测试中文标点'),  # 标点 are CJK characters, so clean_text keeps them
+        ('', ''),
+        ('   ', ''),
+        ('Special!@#$%^&*()_+', 'Special'),
+        ('中文123English', '中文123English')
+    ]
+
+    print('Running clean_text tests...')
+    for input_text, expected in test_cases:
+        result = SentenceUtil.clean_text(input_text)
+        if result == expected:
+            print(f'Test passed: {input_text} -> {result}')
+        else:
+            print(f'Test failed: {input_text} -> {result} (expected: {expected})')
+