weaviate_db.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from typing import List, Dict, Any, Optional, Union
  2. import numpy as np
  3. import weaviate
  4. from weaviate import WeaviateClient
  5. from weaviate.collections import Collection
  6. import weaviate.classes.config as wc
  7. from weaviate.classes.config import Property, DataType
  8. from trustrag.modules.retrieval.embedding import EmbeddingGenerator
  9. from weaviate.classes.query import MetadataQuery
  10. class WeaviateEngine:
  11. def __init__(
  12. self,
  13. collection_name: str,
  14. embedding_generator: EmbeddingGenerator,
  15. client_params: Dict[str, Any] = {
  16. "http_host": "localhost",
  17. "http_port": 8080,
  18. "http_secure": False,
  19. "grpc_host": "localhost",
  20. "grpc_port": 50051,
  21. "grpc_secure": False,
  22. },
  23. ):
  24. """
  25. Initialize the Weaviate vector store.
  26. :param collection_name: Name of the Weaviate collection
  27. :param embedding_generator: An instance of EmbeddingGenerator to generate embeddings
  28. :param weaviate_client_params: Dictionary of parameters to pass to Weaviate client
  29. """
  30. self.collection_name = collection_name
  31. self.embedding_generator = embedding_generator
  32. # Initialize Weaviate client with provided parameters
  33. self.client = weaviate.connect_to_custom(
  34. skip_init_checks=False,
  35. **client_params
  36. )
  37. # Create collection if it doesn't exist
  38. if not self._collection_exists():
  39. self._create_collection()
  40. def _collection_exists(self) -> bool:
  41. """Check if collection exists in Weaviate."""
  42. try:
  43. collections = self.client.collections.list_all()
  44. collection_names = [c.lower() for c in collections]
  45. return self.collection_name in collection_names
  46. except Exception as e:
  47. print(f"Error checking collection existence: {e}")
  48. return False
  49. def _create_collection(self):
  50. """Create a new collection in Weaviate."""
  51. try:
  52. self.client.collections.create(
  53. name=self.collection_name,
  54. # Define properties of metadata
  55. properties=[
  56. wc.Property(
  57. name="text",
  58. data_type=wc.DataType.TEXT
  59. ),
  60. wc.Property(
  61. name="title",
  62. data_type=wc.DataType.TEXT,
  63. skip_vectorization=True
  64. ),
  65. ]
  66. )
  67. except Exception as e:
  68. print(f"Error creating collection: {e}")
  69. raise
  70. def upload_vectors(
  71. self,
  72. vectors: Union[np.ndarray, List[List[float]]],
  73. payload: List[Dict[str, Any]],
  74. batch_size: int = 100
  75. ):
  76. """
  77. Upload vectors and payload to the Weaviate collection.
  78. :param vectors: A numpy array or list of vectors to upload
  79. :param payload: A list of dictionaries containing the payload for each vector
  80. :param batch_size: Number of vectors to upload in a single batch
  81. """
  82. if not isinstance(vectors, np.ndarray):
  83. vectors = np.array(vectors)
  84. if len(vectors) != len(payload):
  85. raise ValueError("Vectors and payload must have the same length.")
  86. collection = self.client.collections.get(self.collection_name)
  87. # Process in batches
  88. for i in range(0, len(vectors), batch_size):
  89. batch_vectors = vectors[i:i + batch_size]
  90. batch_payload = payload[i:i + batch_size]
  91. try:
  92. with collection.batch.dynamic() as batch:
  93. for idx, (properties, vector) in enumerate(zip(batch_payload, batch_vectors)):
  94. # Separate text content and other metadata
  95. text_content = properties.get('description',
  96. '') # Assuming 'description' is the main text field
  97. metadata = {k: v for k, v in properties.items() if k != 'description'}
  98. # Prepare the properties dictionary
  99. properties_dict = {
  100. "text": text_content,
  101. "title": metadata.get('title', f'Object {idx}') # Using title from metadata or default
  102. }
  103. # Add the object with properties and vector
  104. batch.add_object(
  105. properties=properties_dict,
  106. vector=vector
  107. )
  108. except Exception as e:
  109. print(f"Error uploading batch: {e}")
  110. raise
  111. def search(
  112. self,
  113. text: str,
  114. query_filter: Optional[Dict[str, Any]] = None,
  115. limit: int = 5
  116. ) -> List[Dict[str, Any]]:
  117. """
  118. Search for the closest vectors in the collection based on the input text.
  119. :param text: The text query to search for
  120. :param query_filter: Optional filter to apply to the search
  121. :param limit: Number of closest results to return
  122. :return: List of payloads from the closest vectors
  123. """
  124. # Generate embedding for the query text
  125. vector = self.embedding_generator.generate_embedding(text)
  126. print(vector.shape)
  127. collection = self.client.collections.get(self.collection_name)
  128. # Prepare query arguments
  129. query_args = {
  130. "near_vector": vector,
  131. "limit": limit,
  132. "return_metadata": MetadataQuery(distance=True)
  133. }
  134. # Add filter if provided
  135. if query_filter:
  136. query_args["filter"] = query_filter
  137. results = collection.query.near_vector(**query_args)
  138. # Convert results to the same format as QdrantEngine
  139. payloads = []
  140. for obj in results.objects:
  141. payload = obj.properties.get('metadata', {})
  142. payload['text'] = obj.properties.get('text', '')
  143. payload['_distance'] = obj.metadata.distance
  144. payloads.append(payload)
  145. return payloads
  146. def build_filter(self, conditions: List[Dict[str, Any]]) -> Dict[str, Any]:
  147. """
  148. Build a Weaviate filter from a list of conditions.
  149. :param conditions: A list of conditions, where each condition is a dictionary with:
  150. - key: The field name to filter on
  151. - match: The value to match
  152. :return: A Weaviate filter object
  153. """
  154. filter_dict = {
  155. "operator": "And",
  156. "operands": []
  157. }
  158. for condition in conditions:
  159. key = condition.get("key")
  160. match_value = condition.get("match")
  161. if key and match_value is not None:
  162. filter_dict["operands"].append({
  163. "path": [f"metadata.{key}"],
  164. "operator": "Equal",
  165. "valueString": str(match_value)
  166. })
  167. return filter_dict if filter_dict["operands"] else None