# faiss_helper.py
  1. import faiss
  2. import os
  3. from sentence_transformers import SentenceTransformer
  4. from dotenv import load_dotenv
  5. from typing import List, Dict, AsyncGenerator
  6. import numpy as np
  7. # 加载环境变量
  8. load_dotenv()
  9. # DeepSeek API配置
  10. EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
  11. class FaissHelper:
  12. def __init__(self, dimension: int):
  13. self.model = SentenceTransformer(EMBEDDING_MODEL)
  14. def generate_embeddings(self,texts: List[str]) -> np.ndarray:
  15. return self.model.encode(texts)
  16. def create_index(self,dimension: int):
  17. index = faiss.IndexFlatL2(dimension)
  18. return index
  19. def get_index_size(self,index):
  20. return index.ntotal
  21. def index_add(self,index, embeddings):
  22. index.add(embeddings)
  23. def load_faiss_index(self,index_path):
  24. index = faiss.read_index(index_path)
  25. return index
  26. def search_faiss_index(self, index, query, top_k):
  27. query_vector = query.reshape(1, -1)
  28. _, indices = index.search(query_vector, top_k)
  29. return indices[0]