Browse Source

代码提交

SGTY 1 month ago
parent
commit
000fadb354

+ 4 - 2
agent/cdss/libs/cdss_helper2.py

@@ -737,8 +737,10 @@ class CDSSHelper(GraphHelper):
             # symtoms中matched=true的排在前面,matched=false的排在后面
             symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
 
-
-            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])*0.1,"symptoms":symptoms}
+            start_nodes_size = len(start_nodes)
+            # if start_nodes_size > 1:
+            #     start_nodes_size = start_nodes_size*0.5
+            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])/start_nodes_size/2*0.1,"symptoms":symptoms}
             diags[disease] = new_item
         sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
 

+ 47 - 1
service/trunks_service.py

@@ -117,7 +117,7 @@ class TrunksService:
                 update_data['type'] = 'default'
             logger.debug(f"更新生成的embedding长度: {len(update_data['embedding'])}, 内容摘要: {content[:20]}")
             # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
-        
+
         db = SessionLocal()
         try:
             trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
@@ -175,6 +175,52 @@ class TrunksService:
         :return: 结果列表
         """
         return self.get_cache(conversation_id)
+
+    def paginated_search_by_type_and_filepath(self, search_params: dict) -> dict:
+        """
+        Paginated query. NOTE: despite the name, the type ('trunk') and
+        file_path ('%内科学 第10版%') filters are hard-coded in the query
+        below; only pageNo and limit are read from search_params.
+        :param search_params: dict containing pageNo, limit
+        :return: dict with the result list under 'data'
+        """
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+        
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+            
+        offset = (page_no - 1) * limit
+        
+        db = SessionLocal()
+        try:
+
+            # 执行查询
+            results = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.type,
+                Trunks.title
+            )\
+            .filter(Trunks.type == 'trunk')\
+            .filter(Trunks.file_path.like('%内科学 第10版%'))\
+            .filter(Trunks.page_no == None)\
+            .offset(offset)\
+            .limit(limit)\
+            .all()
+            
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    'type': r.type,
+                    'title': r.title
+                } for r in results]
+            }
+        finally:
+            db.close()
         
         
 

+ 67 - 11
tests/service/test_trunks_service.py

@@ -34,18 +34,74 @@ class TestTrunksServiceCRUD:
         assert trunks_service.get_trunk_by_id(trunk.id) is None
 
 class TestSearchOperations:
-    def test_vector_search(self, trunks_service, test_trunk_data):
-        results = trunks_service.search_by_vector("急性胰腺炎是常见的急腹症之一,以突发上腹部剧痛伴恶心呕吐为特征。轻症预后良好,重症可并发多器官衰竭,死亡率高达20-30%。",10,conversation_id="1111111aaaa")
-        print("搜索结果:", results[0])
-        results = trunks_service.get_cache("1111111aaaa")
-        print("搜索结果:", results)
-        assert len(results) > 0
-
-    # def test_fulltext_search(self, trunks_service, test_trunk_data):
-    #     trunks_service.create_trunk(test_trunk_data)
-    #     results = trunks_service.fulltext_search("测试")
-    #     assert len(results) > 0
 
+
+    def test_vector_search(self, trunks_service):
+        page = 1
+        limit = 100
+        while True:
+            results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': 'test_path.pdf'})
+            if not results['data']:
+                break
+            for record in results['data']:
+                print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
+                if record['type'] != 'trunk' or '内科学 第10版' not in record['file_path']:
+                    print('出现异常数据')
+                    break
+
+                page_no = self.get_page_no(record['content'],trunks_service)
+                if page_no is None:
+                    print(f"{record['id']}未找到page_no")
+                    continue
+                trunks_service.update_trunk(record['id'], {'page_no': page_no})
+            page += 1
+
+    def get_page_no(self, text: str, trunks_service) -> "int | None":
+        results = trunks_service.search_by_vector(text,1000,type='page',conversation_id="1111111aaaa")
+        sentences = self.split_text(text)
+        count = 0
+        for r in results:
+            #将r["content"]的所有空白字符去掉
+            content = re.sub(r'\s+', '', r["content"])
+            count+=1
+            match_count = 0
+            length = len(sentences)/2
+            for sentence in sentences:
+                sentence = re.sub(r'\s+', '', sentence)
+                if sentence in content:
+                    match_count += 1
+                    if match_count >= length:
+                        return r["page_no"]
+           
+    def split_text(self, text):
+        """将文本分割成句子"""
+        print(text)
+        # 使用常见的标点符号作为分隔符
+        delimiters = ['!', '?', '。', '!', '?', '\n', ';', '。', ';']
+        sentences = [text]
+        for delimiter in delimiters:
+            new_sentences = []
+            for sentence in sentences:
+                parts = sentence.split(delimiter)
+                new_sentences.extend([part + delimiter if i < len(parts) - 1 else part for i, part in enumerate(parts)])
+            sentences = [s.strip() for s in new_sentences if s.strip()]
+        
+        # 合并短句子
+        merged_sentences = []
+        buffer = ""
+        for sentence in sentences:
+            buffer += " " + sentence if buffer else sentence
+            if len(buffer) >= 10:
+                merged_sentences.append(buffer)
+                buffer = ""
+        if buffer:
+            merged_sentences.append(buffer)
+        
+        # 打印最终句子
+        for i, sentence in enumerate(merged_sentences):
+            # hoisted: backslash escapes and reused quotes inside an
+            # f-string expression are a SyntaxError on Python < 3.12
+            cleaned = "".join(sentence.split())
+            print(f"句子{i+1}: {cleaned}")
+        
+        return merged_sentences
 class TestExceptionCases:
     def test_duplicate_id(self, trunks_service, test_trunk_data):
         with pytest.raises(IntegrityError):

+ 1 - 1
tests/test.py

@@ -6,7 +6,7 @@ capability = CDSSCapability()
 record = CDSSInput(
     pat_age=CDSSInt(type="month", value=24),
     pat_sex=CDSSText(type="sex", value="男"),
-    chief_complaint=["右下腹痛","恶心","呕吐"],
+    chief_complaint=["右下腹痛"],
     #chief_complaint=["呕血", "黑便", "头晕", "心悸"],
     #chief_complaint=["流鼻涕"],
 

+ 5 - 4
tests/trunck_files.py

@@ -4,9 +4,10 @@ import re
 def rename_doc_files(directory):
     for root, dirs, files in os.walk(directory):
         for file in files:
-            if file.endswith('.doc'):
+            if file.endswith('.pdf'):
                 old_path = os.path.join(root, file)
-                new_name = re.sub(r'^d+.', '', file)
+                #文件名前面的数字及.去掉
+                new_name = re.sub(r'^\d+\.', '', file).strip()
                 if new_name != file:
                     new_path = os.path.join(root, new_name)
                     os.rename(old_path, new_path)
@@ -36,6 +37,6 @@ def move_doc_to_ocr(directory):
                     print(f'Renamed: {old_path} -> {new_path}')
 
 if __name__ == '__main__':
-    directory = 'E:\\医学所有资料\\中华医学期刊数据库'
+    directory = 'E:\\急诊科资料\\pdf'
     #rename_doc_files(directory)
-    move_doc_to_ocr(directory)
+    rename_doc_files(directory)

+ 10 - 3
utils/file_reader2.py

@@ -1,4 +1,6 @@
 import os
+import urllib
+
 from service.trunks_service import TrunksService
 
 class FileReader2:
@@ -10,12 +12,17 @@ class FileReader2:
                 #if file.endswith('.md'):
                     try:
                         file_path = os.path.join(root, file)
-                        relative_path = '\\report\\' + os.path.relpath(file_path, directory)
+                        relative_path = os.path.relpath(file_path, directory)
+                        #relative_path的\trunk\前面的部分去除掉
+                        relative_path = relative_path.split('\\trunk\\')[1]
+                        relative_path='\\report\\trunk\\'+relative_path
                         with open(file_path, 'r', encoding='utf-8') as f:
                             lines = f.readlines()
                         meta_header = lines[0]
                         content = ''.join(lines[1:])
-                        TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk','meta_header':meta_header,'referrence':'http://173.18.12.205:8001/books/%E5%86%85%E7%A7%91%E5%AD%A6%20%E7%AC%AC10%E7%89%88.pdf'})
+                        filename = file.split('_split_')[0]
+                        newfilename = urllib.parse.quote(filename)
+                        TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk','meta_header':meta_header,'referrence':'http://173.18.12.205:8001/books/'+newfilename+'.pdf'})
                     except Exception as e:
                         print(f'Error processing file {file_path}: {str(e)}')
     @staticmethod
@@ -30,5 +37,5 @@ class FileReader2:
                     TrunksService().create_trunk({'file_path': file_path, 'content': content, 'type': 'trunk', 'title': title})
 
 if __name__ == '__main__':
-    directory = 'E:\\project\\vscode\\《内科学 第10版》'
+    directory = 'E:\\急诊科资料\\中华医学期刊数据库'
     FileReader2.find_and_print_split_files(directory)