Browse Source

智能体查询接口

yuchengwei 5 ngày trước cách đây
mục cha
commit
87c2a12423

+ 246 - 0
src/knowledge/router/medical_knowledge_api.py

@@ -0,0 +1,246 @@
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+from typing import List, Optional
+from ..model.response import StandardResponse
+from ..db.session import get_db
+from sqlalchemy.orm import Session
+import logging
+from ..service.search_service import SearchBusiness
+
+router = APIRouter(prefix="/medical", tags=["Medical Knowledge API"])
+logger = logging.getLogger(__name__)
+
+# 1. 疾病与症状相关接口
+class SymptomDiseasesRequest(BaseModel):
+    symptoms: List[str]
+
+class DiseaseSymptomsRequest(BaseModel):
+    disease_id: str
+
+class DiseaseInfoRequest(BaseModel):
+    query: str
+    type: Optional[str] = None
+
+@router.post("/symptom_diseases", response_model=StandardResponse)
+async def get_symptom_diseases(
+    request: SymptomDiseasesRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现获取症状相关疾病的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取症状相关疾病失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/disease_symptoms", response_model=StandardResponse)
+async def get_disease_symptoms(
+    request: DiseaseSymptomsRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现获取疾病症状的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取疾病症状失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/disease_info", response_model=StandardResponse)
+async def get_disease_info(
+    request: DiseaseInfoRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现搜索医学概念的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"搜索医学概念失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+# 2. 药物相关接口
+class DrugRequest(BaseModel):
+    drug_id: str
+
+@router.post("/drug_indications", response_model=StandardResponse)
+async def get_drug_indications(
+    request: DrugRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现获取药物适应症的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取药物适应症失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/drug_contraindications", response_model=StandardResponse)
+async def get_drug_contraindications(
+    request: DrugRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现获取药物禁忌症的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取药物禁忌症失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/drug_interactions", response_model=StandardResponse)
+async def get_drug_interactions(
+    request: DrugRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现获取药物相互作用的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取药物相互作用失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+# 3. 关系与概念处理接口
+class ConceptSearchRequest(BaseModel):
+    query: str
+    type: Optional[str] = None
+
+class ConceptRelationsRequest(BaseModel):
+    concept_id: str
+
+class SimilarConceptsRequest(BaseModel):
+    concept_id: str
+    top_k: int = 5
+
+@router.post("/search_concept", response_model=StandardResponse)
+async def search_concept(
+    request: ConceptSearchRequest
+):
+    try:
+        # 实现搜索医学概念的逻辑
+        search = SearchBusiness()
+        results = search.search_nodes(name=request.query, type=request.type)
+        return StandardResponse(success=True, data=results)
+    except Exception as e:
+        logger.error(f"搜索医学概念失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/get_relations", response_model=StandardResponse)
+async def get_relations(
+    request: ConceptRelationsRequest
+):
+    try:
+        # 实现获取概念关系的逻辑
+        search = SearchBusiness()
+        results = search.search_edges(name=None,src_id=request.concept_id)
+        return StandardResponse(success=True, data=results)
+    except Exception as e:
+        logger.error(f"获取概念关系失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/get_similar_concepts", response_model=StandardResponse)
+async def get_similar_concepts(
+    request: SimilarConceptsRequest
+):
+    try:
+        # 实现获取相似概念的逻辑
+        search = SearchBusiness()
+        results = search.search_nodes(id=request.concept_id, limit=request.top_k)
+        return StandardResponse(success=True, data=results)
+    except Exception as e:
+        logger.error(f"获取相似概念失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+# 4. 临床辅助决策接口
+class DiagnosisCheckRequest(BaseModel):
+    diagnosis: str
+    symptoms: List[str]
+    lab_results: Optional[Dict[str, Any]] = None
+
+class NextQuestionsRequest(BaseModel):
+    current_input: Dict[str, Any]
+
+class TriageRequest(BaseModel):
+    symptoms: List[str]
+    vital_signs: Dict[str, Any]
+
+class CodingRequest(BaseModel):
+    clinical_term: str
+
+class DepartmentRequest(BaseModel):
+    diagnosis: str
+    symptoms: Optional[List[str]] = None
+
+@router.post("/medical/check_diagnosis", response_model=StandardResponse)
+async def check_diagnosis_consistency(
+    request: DiagnosisCheckRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现诊断一致性检查的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"诊断一致性检查失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/medical/suggest_questions", response_model=StandardResponse)
+async def suggest_next_questions(
+    request: NextQuestionsRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现建议下一步问题的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"建议下一步问题失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/medical/triage_guidance", response_model=StandardResponse)
+async def get_triage_guidance(
+    request: TriageRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现分诊建议的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取分诊建议失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/medical/coding_recommendation", response_model=StandardResponse)
+async def get_coding_recommendation(
+    request: CodingRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现编码推荐的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"获取编码推荐失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.post("/medical/suggest_department", response_model=StandardResponse)
+async def suggest_appropriate_department(
+    request: DepartmentRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现科室推荐的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"科室推荐失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+# 5. 病历质控接口
+class MedicalRecordRequest(BaseModel):
+    medical_record: Dict[str, Any]
+    validate_type: str
+
+@router.post("/medical/validate_record", response_model=StandardResponse)
+async def validate_medical_record(
+    request: MedicalRecordRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 实现病历验证的逻辑
+        return StandardResponse(success=True, data=[])
+    except Exception as e:
+        logger.error(f"病历验证失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))

+ 101 - 0
src/knowledge/service/search_service.py

@@ -0,0 +1,101 @@
+import sys,os
+current_path = os.getcwd()
+sys.path.append(current_path)
+import logging
+logger = logging.getLogger(__name__)
+
+from agent.libs.schema import SchemaContent,SchemaData,SchemaDataItem
+from config.site import SiteConfig
+from elasticsearch import Elasticsearch, helpers, exceptions
+config = SiteConfig()
+ELASTICSEARCH_USER = config.get_config("ELASTICSEARCH_USER", "/tmp")
+ELASTICSEARCH_PWD = config.get_config("ELASTICSEARCH_PWD", "/tmp")
+ELASTICSEARCH_HOST = config.get_config("ELASTICSEARCH_HOST", "/tmp")
+
+class SearchBusiness:
+    def __init__(self):
+        self.es = Elasticsearch(hosts=[ELASTICSEARCH_HOST], verify_certs=False, http_auth=(ELASTICSEARCH_USER, ELASTICSEARCH_PWD))
+        pass
+    def search_nodes_and_edges(self, index, query):
+        try:
+            response = self.es.search(index=index, body=query)
+           
+            hits = response["hits"]["hits"]
+            results = []
+            for hit in hits:
+                source = hit["_source"]
+                source["score"] = hit["_score"]
+                results.append(source)
+
+            return results
+        except exceptions.NotFoundError as e:
+            logger.error(f"Index '{index}' not found: {e}")
+            return []
+
+    def search_nodes(self,name,type,id,limit=10,from_=0):
+
+        try:
+            query = {
+                "explain": "true",
+                "query": {
+                    "bool": {
+                        "must": [
+                            {"term": {"table": "kg_nodes"}}
+                        ]
+                    }
+                },
+                "sort": [{"_score": {"order": "desc"}}]
+            }
+            if limit:
+                query["size"] = limit
+            if from_:
+                query["from"] = from_
+
+            if name:
+                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": name}})
+            if id:
+                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_id": id}})
+            if type:
+                query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
+
+            results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
+            return results
+        except exceptions.NotFoundError as e:
+            logger.error(f"Index not found: {e}")
+            return None
+        except Exception as e:
+            logger.error(f"Search error: {e}")
+            return None
+
+    def search_edges(self,name,src_id):
+        try:
+            query = {
+                "query": {
+                    "bool": {
+                        "must": [
+                            {"term": {"table": "kg_edges"}}
+                        ]
+                    }
+                },
+                "sort": [{"_score": {"order": "desc"}}]
+            }
+            if name:
+                query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
+            if src_id:
+                query["query"]["bool"]["must"].append({"match": {"public_kg_edges_src_id": src_id}})
+
+            results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
+            return results
+        except exceptions.NotFoundError as e:
+            logger.error(f"Index not found: {e}")
+            return None
+        except Exception as e:
+            logger.error(f"Search error: {e}")
+            return None
+
+if __name__ == "__main__":
+    search_biz = SearchBusiness()
+    index=""
+    query=""
+    results = search_biz.search_nodes_and_edges(index=index, query=query)
+    print(results)