소스 검색

代码提交

SGTY 1 개월 전
부모
커밋
0b192936b7
6개의 변경된 파일, 159개의 추가 및 20개의 삭제
  1. 1 1
      .env
  2. 57 0
      config/site.py
  3. 19 0
      tests/download_bge_model.py
  4. 11 0
      utils/embed_helper.py
  5. 57 0
      utils/site.py
  6. 14 19
      utils/vectorizer.py

+ 1 - 1
.env

@@ -5,4 +5,4 @@ DB_USER = knowledge
 DB_PASSWORD = qwer1234.
 
 license=E:\project\knowledge\license_issued
-EMBEDDING_MODEL=C:\Users\jiyua\.cache\modelscope\hub\models\BAAI\bge-m3
+EMBEDDING_MODEL=E:\project\knowledge\bge-m3

+ 57 - 0
config/site.py

@@ -0,0 +1,57 @@
+import os
+from dotenv import load_dotenv
+from urllib.parse import quote
+
+load_dotenv()
+
+
+class SiteConfig:
+    """Access point for site settings that are supplied as environment variables.
+
+    A snapshot of every known setting is taken once, when the instance is
+    created. Lookups go through ``get_config``, which still prefers the live
+    environment over that snapshot.
+    """
+
+    def __init__(self):
+        self.load_config()
+
+    def load_config(self):
+        """Snapshot all known settings from the environment into ``self.config``.
+
+        Settings without a built-in fallback are stored as ``None`` when the
+        corresponding environment variable is missing.
+        """
+        self.config = {
+            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
+            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
+            "SITE_URL": os.getenv("SITE_URL", ""),
+            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
+            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
+            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
+            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
+            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
+            'WORD_INDEX': os.getenv("WORD_INDEX"),
+            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
+            # Spelling kept as-is: this is the name of the environment variable.
+            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
+            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
+            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
+            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
+            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
+            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
+            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
+            'DOC_PATH': os.getenv("DOC_PATH"),
+            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
+            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
+            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
+            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
+            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
+            'POSTGRESQL_HOST': os.getenv("POSTGRESQL_HOST", "localhost"),
+            'POSTGRESQL_DATABASE': os.getenv("POSTGRESQL_DATABASE", "kg"),
+            'POSTGRESQL_USER': os.getenv("POSTGRESQL_USER", "dify"),
+            # Only the built-in fallback is passed through quote(); a value
+            # taken from the environment is stored unchanged.
+            'POSTGRESQL_PASSWORD': os.getenv("POSTGRESQL_PASSWORD", quote("difyai123456")),
+        }
+
+    def get_config(self, config_name, default=None):
+        """Return the value of one setting.
+
+        Lookup order: a non-empty value in the live environment, then the
+        snapshot taken by ``load_config``, then ``default``.
+
+        Args:
+            config_name: Setting name; matched case-insensitively (upper-cased).
+            default: Returned when the setting is unknown or was never
+                configured (its snapshot value is ``None``).
+
+        Returns:
+            The configured value, or ``default``.
+        """
+        config_name = config_name.upper()
+        value = os.getenv(config_name, None)
+        if value:
+            return value
+
+        # A known key whose snapshot is None was never configured, so the
+        # caller's default must win instead of being silently discarded.
+        value = self.config.get(config_name)
+        return default if value is None else value
+
+    def check_config(self, config_list):
+        """Verify that every named setting has a non-empty value.
+
+        Args:
+            config_list: Iterable of setting names to check.
+
+        Raises:
+            ValueError: For the first setting whose value is missing or empty.
+        """
+        for item in config_list:
+            if not self.get_config(item):
+                raise ValueError(f"Configuration '{item}' is not set.")
+      

+ 19 - 0
tests/download_bge_model.py

@@ -0,0 +1,19 @@
+
+from modelscope import snapshot_download
+import shutil
+import os
+if __name__ == '__main__':
+
+    # 1. Download to the default cache location first.
+    model_dir = snapshot_download('BAAI/bge-m3')
+    print(model_dir)
+
+    # 2. Define the destination. The model is placed in <target_dir>/bge-m3,
+    #    which is the directory EMBEDDING_MODEL in .env points at.
+    target_dir = 'E:/project/knowledge'
+    model_target = os.path.join(target_dir, 'bge-m3')
+
+    # 3. Copy to the destination. dirs_exist_ok lets the script be re-run
+    #    without failing because the directory from a previous run exists.
+    os.makedirs(target_dir, exist_ok=True)
+    shutil.copytree(model_dir, model_target, dirs_exist_ok=True)
+
+    print(f"模型已复制到:{model_target}")

+ 11 - 0
utils/embed_helper.py

@@ -0,0 +1,11 @@
+# Load environment variables
+from config.site import SiteConfig
+from sentence_transformers import SentenceTransformer
+config = SiteConfig()
+
+class EmbedHelper:
+    """Produce embedding vectors for text with a SentenceTransformer model."""
+
+    def __init__(self):
+        """Load the model named by the ``EMBEDDING_MODEL`` setting.
+
+        Raises:
+            ValueError: If ``EMBEDDING_MODEL`` is not configured.
+        """
+        # Fail early with a clear message instead of passing None on to
+        # SentenceTransformer when the setting is missing.
+        config.check_config(["EMBEDDING_MODEL"])
+        self.embedding_model_name = config.get_config("EMBEDDING_MODEL")
+        self.embedding_model = SentenceTransformer(model_name_or_path=self.embedding_model_name)
+
+    def embed_text(self, text):
+        """Encode ``text`` with the loaded model and return ``encode(text).tolist()``."""
+        return self.embedding_model.encode(text).tolist()

+ 57 - 0
utils/site.py

@@ -0,0 +1,57 @@
+import os
+from dotenv import load_dotenv
+from urllib.parse import quote
+
+load_dotenv()
+
+
+class SiteConfig:
+    """Access point for site settings that are supplied as environment variables.
+
+    A snapshot of every known setting is taken once, when the instance is
+    created. Lookups go through ``get_config``, which still prefers the live
+    environment over that snapshot.
+
+    NOTE(review): config/site.py defines an identical class; consider keeping
+    a single definition and importing it here.
+    """
+
+    def __init__(self):
+        self.load_config()
+
+    def load_config(self):
+        """Snapshot all known settings from the environment into ``self.config``.
+
+        Settings without a built-in fallback are stored as ``None`` when the
+        corresponding environment variable is missing.
+        """
+        self.config = {
+            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
+            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
+            "SITE_URL": os.getenv("SITE_URL", ""),
+            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
+            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
+            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
+            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
+            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
+            'WORD_INDEX': os.getenv("WORD_INDEX"),
+            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
+            # Spelling kept as-is: this is the name of the environment variable.
+            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
+            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
+            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
+            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
+            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
+            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
+            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
+            'DOC_PATH': os.getenv("DOC_PATH"),
+            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
+            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
+            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
+            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
+            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
+            'POSTGRESQL_HOST': os.getenv("POSTGRESQL_HOST", "localhost"),
+            'POSTGRESQL_DATABASE': os.getenv("POSTGRESQL_DATABASE", "kg"),
+            'POSTGRESQL_USER': os.getenv("POSTGRESQL_USER", "dify"),
+            # Only the built-in fallback is passed through quote(); a value
+            # taken from the environment is stored unchanged.
+            'POSTGRESQL_PASSWORD': os.getenv("POSTGRESQL_PASSWORD", quote("difyai123456")),
+        }
+
+    def get_config(self, config_name, default=None):
+        """Return the value of one setting.
+
+        Lookup order: a non-empty value in the live environment, then the
+        snapshot taken by ``load_config``, then ``default``.
+
+        Args:
+            config_name: Setting name; matched case-insensitively (upper-cased).
+            default: Returned when the setting is unknown or was never
+                configured (its snapshot value is ``None``).
+
+        Returns:
+            The configured value, or ``default``.
+        """
+        config_name = config_name.upper()
+        value = os.getenv(config_name, None)
+        if value:
+            return value
+
+        # A known key whose snapshot is None was never configured, so the
+        # caller's default must win instead of being silently discarded.
+        value = self.config.get(config_name)
+        return default if value is None else value
+
+    def check_config(self, config_list):
+        """Verify that every named setting has a non-empty value.
+
+        Args:
+            config_list: Iterable of setting names to check.
+
+        Raises:
+            ValueError: For the first setting whose value is missing or empty.
+        """
+        for item in config_list:
+            if not self.get_config(item):
+                raise ValueError(f"Configuration '{item}' is not set.")
+      

+ 14 - 19
utils/vectorizer.py

@@ -5,6 +5,9 @@ import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
+from utils.embed_helper import EmbedHelper
+from utils.vector_distance import VectorDistance
+
 logger = logging.getLogger(__name__)
 
 class Vectorizer:
@@ -64,22 +67,14 @@ class Vectorizer:
             logger.error(f"API请求失败: {str(e)}")
             raise
 
-    if __name__ == '__main__':
-        text ='''姓名:李XX  
-            性别:女  
-            年龄:55岁  
-            住院号:NJZY20231102  
-            主诉:突发胸痛伴呼吸困难2小时  
-            现病史:患者于下午3时许突发胸痛,位于心前区,呈压榨性疼痛,伴呼吸困难,持续不缓解,无恶心呕吐及二便失禁。急诊测血压150/90mmHg,心率100次/分。心电图示II、III、aVF导联ST段弓背向上抬高0.3-0.5mV。发病前2周曾诉间断性胸闷,每次持续数分钟自行缓解。  
-            既往史:高血压8年(间断服用降压药),无糖尿病史,无手术史。长期吸烟(20包/年),饮酒史5年(红酒约150g/日)。  
-            体格检查:BP 155/92mmHg,心率102次/分。神志清楚,痛苦面容,双肺呼吸音清,未闻及干湿性啰音。心界不大,心率102次/分,律齐,心尖区可闻及2/6级收缩期杂音。  
-            辅助检查:心电图示II、III、aVF导联ST段弓背向上抬高0.3-0.5mV;心肌酶谱:肌酸激酶同工酶(CK-MB)100U/L,肌钙蛋白T(cTnT)0.4ng/ml;心脏超声示左室下壁运动减弱。  
-            诊断:急性下壁心肌梗死  
-            治疗计划:  
-            1. 抗血小板治疗:阿司匹林300mg嚼服,氯吡格雷300mg负荷剂量后75mg qd。  
-            2. 抗凝治疗:低分子量肝素4000U皮下注射,每12小时1次。  
-            3. 冠状动脉介入治疗(PCI):急诊行冠状动脉造影,必要时行支架植入术。  
-            4. 调脂治疗:阿托伐他汀20mg qn。  
-            5. 血管扩张剂:硝酸甘油静脉泵入。'''
-        embedding = get_embedding(text)
-        print(f'生成的embedding向量:\n{embedding}')
+if __name__ == '__main__':
+    # Manual smoke check: embed one sample string through EmbedHelper and
+    # through Vectorizer, print both vector lengths, then print the distance
+    # that VectorDistance reports between the two vectors.
+    sample = '你好'
+    print(sample)
+
+    helper_vector = EmbedHelper().embed_text(sample)
+    print(len(helper_vector))
+
+    # NOTE(review): get_embedding is called on the class, not an instance;
+    # confirm it is a static/class method.
+    vectorizer_vector = Vectorizer.get_embedding(sample)
+    print(len(vectorizer_vector))
+
+    distance = VectorDistance().calculate_distance(helper_vector, vectorizer_vector)
+    print(distance)