浏览代码

智能体查询接口1

yuchengwei 4 天之前
父节点
当前提交
ac54246d25

+ 3 - 0
src/knowledge/.env

@@ -6,3 +6,6 @@ DB_PASSWORD=qwer1234.
 LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3
 EMBEDDING_MODEL=E:\project\bge-m3
 BOOKS=E:\project\books
 BOOKS=E:\project\books
+ELASTICSEARCH_HOST=http://173.18.12.203:9200
+ELASTICSEARCH_USER=elastic
+ELASTICSEARCH_PWD=cRiX2SRrSMPkMe0gQJUn

+ 3 - 0
src/knowledge/config/site.py

@@ -19,6 +19,9 @@ class SiteConfig:
             'DB_USER': os.getenv("DB_USER", ""),
             'DB_USER': os.getenv("DB_USER", ""),
             'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
             'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
             'BOOKS': os.getenv("BOOKS", ""),
             'BOOKS': os.getenv("BOOKS", ""),
+            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
+            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
+            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD", quote("qwer1234.")),
         }
         }
     def get_config(self, config_name, default=None): 
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     
         config_name = config_name.upper()     

+ 7 - 6
src/knowledge/router/medical_knowledge_api.py

@@ -1,6 +1,6 @@
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Dict, Any
 from ..model.response import StandardResponse
 from ..model.response import StandardResponse
 from ..db.session import get_db
 from ..db.session import get_db
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
@@ -37,8 +37,7 @@ async def get_symptom_diseases(
 
 
 @router.post("/disease_symptoms", response_model=StandardResponse)
 @router.post("/disease_symptoms", response_model=StandardResponse)
 async def get_disease_symptoms(
 async def get_disease_symptoms(
-    request: DiseaseSymptomsRequest,
-    db: Session = Depends(get_db)
+    request: DiseaseSymptomsRequest
 ):
 ):
     try:
     try:
         # 实现获取疾病症状的逻辑
         # 实现获取疾病症状的逻辑
@@ -118,7 +117,7 @@ async def search_concept(
     try:
     try:
         # 实现搜索医学概念的逻辑
         # 实现搜索医学概念的逻辑
         search = SearchBusiness()
         search = SearchBusiness()
-        results = search.search_nodes(name=request.query, type=request.type)
+        results = search.search_nodes(name=request.query, type=request.type,id=None)
         return StandardResponse(success=True, data=results)
         return StandardResponse(success=True, data=results)
     except Exception as e:
     except Exception as e:
         logger.error(f"搜索医学概念失败: {str(e)}")
         logger.error(f"搜索医学概念失败: {str(e)}")
@@ -144,7 +143,7 @@ async def get_similar_concepts(
     try:
     try:
         # 实现获取相似概念的逻辑
         # 实现获取相似概念的逻辑
         search = SearchBusiness()
         search = SearchBusiness()
-        results = search.search_nodes(id=request.concept_id, limit=request.top_k)
+        results = search.search_nodes(name=None,type=None,id=request.concept_id, limit=request.top_k)
         return StandardResponse(success=True, data=results)
         return StandardResponse(success=True, data=results)
     except Exception as e:
     except Exception as e:
         logger.error(f"获取相似概念失败: {str(e)}")
         logger.error(f"获取相似概念失败: {str(e)}")
@@ -245,4 +244,6 @@ async def validate_medical_record(
         return StandardResponse(success=True, data=[])
         return StandardResponse(success=True, data=[])
     except Exception as e:
     except Exception as e:
         logger.error(f"病历验证失败: {str(e)}")
         logger.error(f"病历验证失败: {str(e)}")
-        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+medical_knowledge_router=router

+ 36 - 0
src/knowledge/server.py

@@ -15,9 +15,16 @@ from .router.graph_api import graph_router
 from .router.knowledge_nodes_api import knowledge_nodes_api_router
 from .router.knowledge_nodes_api import knowledge_nodes_api_router
 from .router.knowledge_saas import saas_kb_router
 from .router.knowledge_saas import saas_kb_router
 from .router.text_search import text_search_router
 from .router.text_search import text_search_router
+from .router.medical_knowledge_api import medical_knowledge_router
 
 
 from .utils import log_util
 from .utils import log_util
 
 
+from fastapi.openapi.docs import (
+    get_redoc_html,
+    get_swagger_ui_html,
+    get_swagger_ui_oauth2_redirect_html,
+)
+
 
 
 
 
 @asynccontextmanager
 @asynccontextmanager
@@ -31,6 +38,8 @@ app = FastAPI(
     description="知识图谱开放平台",
     description="知识图谱开放平台",
     lifespan=lifespan,
     lifespan=lifespan,
     middleware=register_middlewares(),  # 注册web中间件
     middleware=register_middlewares(),  # 注册web中间件
+    docs_url=None,
+    redoc_url=None
 )
 )
 @app.get("/health")
 @app.get("/health")
 async def health_check():
 async def health_check():
@@ -58,6 +67,7 @@ async def startup():
     app.include_router(text_search_router)
     app.include_router(text_search_router)
     app.include_router(graph_router)
     app.include_router(graph_router)
     app.include_router(saas_kb_router)
     app.include_router(saas_kb_router)
+    app.include_router(medical_knowledge_router)
 
 
     # 挂载静态文件目录,将/books路径映射到本地books文件夹
     # 挂载静态文件目录,将/books路径映射到本地books文件夹
 
 
@@ -66,6 +76,32 @@ async def startup():
 
 
     app.mount("/books", StaticFiles(directory=books_path), name="books")
     app.mount("/books", StaticFiles(directory=books_path), name="books")
 
 
+    app.mount("/static", StaticFiles(directory=Path("/Users/ycw/PycharmProjects/knowledge/src/knowledge/static")), name="static")
+
+    # 定义一个异步函数,用于返回自定义的Swagger UI HTML页面
+    @app.get("/docs", include_in_schema=False)
+    async def custom_swagger_ui_html():
+        # 调用get_swagger_ui_html函数,传入openapi_url、title、oauth2_redirect_url、swagger_js_url和swagger_css_url参数
+        return get_swagger_ui_html(
+            openapi_url=app.openapi_url,
+            title=app.title + " - Swagger UI",
+            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
+            swagger_js_url="/static/swagger-ui-bundle.js",
+            swagger_css_url="/static/swagger-ui.css",
+        )
+
+    @app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
+    async def swagger_ui_redirect():
+        return get_swagger_ui_oauth2_redirect_html()
+
+    @app.get("/redoc", include_in_schema=False)
+    async def redoc_html():
+        return get_redoc_html(
+            openapi_url=app.openapi_url,
+            title=app.title + " - ReDoc",
+            redoc_js_url="/static/redoc.standalone.js",
+        )
+
     # 需要拦截的 URL 列表(支持通配符)
     # 需要拦截的 URL 列表(支持通配符)
     INTERCEPT_URLS = {
     INTERCEPT_URLS = {
         "/v1/knowledge/*"
         "/v1/knowledge/*"

+ 45 - 14
src/knowledge/service/search_service.py

@@ -51,13 +51,13 @@ class SearchBusiness:
                 query["from"] = from_
                 query["from"] = from_
 
 
             if name:
             if name:
-                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": name}})
+                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_name": name}})
             if id:
             if id:
-                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_id": id}})
+                query["query"]["bool"]["must"].append({"term": {"public_kg_nodes_id": id}})
             if type:
             if type:
                 query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
                 query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
 
 
-            results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
+            results = self.search_nodes_and_edges(index="connector-postgresql-all", query=query)
             return results
             return results
         except exceptions.NotFoundError as e:
         except exceptions.NotFoundError as e:
             logger.error(f"Index not found: {e}")
             logger.error(f"Index not found: {e}")
@@ -81,9 +81,9 @@ class SearchBusiness:
             if name:
             if name:
                 query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
                 query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
             if src_id:
             if src_id:
-                query["query"]["bool"]["must"].append({"match": {"public_kg_edges_src_id": src_id}})
+                query["query"]["bool"]["must"].append({"term": {"public_kg_edges_src_id": src_id}})
 
 
-            results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
+            results = self.search_nodes_and_edges(index="connector-postgresql-all", query=query)
             return results
             return results
         except exceptions.NotFoundError as e:
         except exceptions.NotFoundError as e:
             logger.error(f"Index not found: {e}")
             logger.error(f"Index not found: {e}")
@@ -91,7 +91,7 @@ class SearchBusiness:
         except Exception as e:
         except Exception as e:
             logger.error(f"Search error: {e}")
             logger.error(f"Search error: {e}")
             return None
             return None
-            
+
     def get_symptom_diseases(self, symptom_names):
     def get_symptom_diseases(self, symptom_names):
         """
         """
         根据症状名称列表获取相关疾病列表
         根据症状名称列表获取相关疾病列表
@@ -99,25 +99,25 @@ class SearchBusiness:
         :return: 疾病节点列表(按命中症状次数降序排列)
         :return: 疾病节点列表(按命中症状次数降序排列)
         """
         """
         disease_dict = {}
         disease_dict = {}
-        
+
         # 对症状名称去重
         # 对症状名称去重
         unique_symptoms = list(set(symptom_names))
         unique_symptoms = list(set(symptom_names))
         threshold = len(unique_symptoms) / 2
         threshold = len(unique_symptoms) / 2
-        
+
         # 存储已处理的疾病ID,避免同义词重复计算
         # 存储已处理的疾病ID,避免同义词重复计算
         processed_disease_ids = set()
         processed_disease_ids = set()
-        
+
         for symptom_name in unique_symptoms:
         for symptom_name in unique_symptoms:
             # 获取症状节点ID
             # 获取症状节点ID
             symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
             symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
             if not symptom_nodes:
             if not symptom_nodes:
                 continue
                 continue
-                
+
             for symptom_node in symptom_nodes:
             for symptom_node in symptom_nodes:
                 symptom_id = symptom_node.get("public_kg_nodes_id")
                 symptom_id = symptom_node.get("public_kg_nodes_id")
                 if not symptom_id:
                 if not symptom_id:
                     continue
                     continue
-                    
+
                 # 查询该症状的'症状同义词'关系获取同义词节点
                 # 查询该症状的'症状同义词'关系获取同义词节点
                 synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
                 synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
                 synonym_ids = {symptom_id}  # 包含原始症状ID
                 synonym_ids = {symptom_id}  # 包含原始症状ID
@@ -126,13 +126,13 @@ class SearchBusiness:
                         synonym_id = edge.get("public_kg_edges_dest_id")
                         synonym_id = edge.get("public_kg_edges_dest_id")
                         if synonym_id:
                         if synonym_id:
                             synonym_ids.add(synonym_id)
                             synonym_ids.add(synonym_id)
-                
+
                 # 对每个症状ID(包括同义词)查询'常见疾病'关系
                 # 对每个症状ID(包括同义词)查询'常见疾病'关系
                 for symptom_id in synonym_ids:
                 for symptom_id in synonym_ids:
                     disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
                     disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
                     if not disease_edges:
                     if not disease_edges:
                         continue
                         continue
-                        
+
                     # 收集所有疾病节点
                     # 收集所有疾病节点
                     for edge in disease_edges:
                     for edge in disease_edges:
                         disease_id = edge.get("public_kg_edges_dest_id")
                         disease_id = edge.get("public_kg_edges_dest_id")
@@ -147,12 +147,43 @@ class SearchBusiness:
                                     else:
                                     else:
                                         node["count"] = 1
                                         node["count"] = 1
                                         disease_dict[node_id] = node
                                         disease_dict[node_id] = node
-                            
+
         # 按count降序排列并过滤掉count小于阈值的疾病
         # 按count降序排列并过滤掉count小于阈值的疾病
         sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
         sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
         filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
         filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
         return filtered_diseases
         return filtered_diseases
 
 
+    def search_related_nodes(self, src_id, relation_name):
+        """
+        查询与指定节点有特定关系的所有节点
+        :param src_id: 源节点ID
+        :param relation_name: 关系名称
+        :return: 相关节点列表
+        """
+        try:
+            # 1. 查询关系表获取所有dest_id
+            edges = self.search_edges(name=relation_name, src_id=src_id)
+            if not edges:
+                return []
+
+            # 2. 收集所有目标节点ID
+            dest_ids = [edge["public_kg_edges_dest_id"] for edge in edges if "public_kg_edges_dest_id" in edge]
+            if not dest_ids:
+                return []
+
+            # 3. 查询节点表获取所有目标节点信息
+            nodes = []
+            for dest_id in dest_ids:
+                node_results = self.search_nodes(id=dest_id)
+                if node_results:
+                    nodes.extend(node_results)
+
+            return nodes
+
+        except Exception as e:
+            logger.error(f"Search related nodes error: {e}")
+            return None
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     search_biz = SearchBusiness()
     search_biz = SearchBusiness()
     index=""
     index=""

+ 0 - 0
src/knowledge/static/__init__.py


文件差异内容过多而无法显示
+ 1782 - 0
src/knowledge/static/redoc.standalone.js


文件差异内容过多而无法显示
+ 2 - 0
src/knowledge/static/swagger-ui-bundle.js


文件差异内容过多而无法显示
+ 3 - 0
src/knowledge/static/swagger-ui.css


文件差异内容过多而无法显示
+ 1 - 0
src/knowledge/static/swagger-ui.css.map