123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
import logging
from typing import List, Dict, Any, Optional, Union

import numpy as np
import weaviate
import weaviate.classes.config as wc
from weaviate import WeaviateClient
from weaviate.classes.config import Property, DataType
from weaviate.classes.query import MetadataQuery
from weaviate.collections import Collection

from trustrag.modules.retrieval.embedding import EmbeddingGenerator
class WeaviateEngine:
    """
    Thin vector-store wrapper around a single Weaviate collection.

    The engine connects to Weaviate, makes sure the target collection exists,
    uploads pre-computed vectors together with a ``text``/``title`` payload and
    runs nearest-vector searches for free-text queries.

    The instance can be used as a context manager so that the underlying
    client connection is always closed.
    """

    # Connection settings used for every key the caller does not override.
    DEFAULT_CLIENT_PARAMS: Dict[str, Any] = {
        "http_host": "localhost",
        "http_port": 8080,
        "http_secure": False,
        "grpc_host": "localhost",
        "grpc_port": 50051,
        "grpc_secure": False,
    }

    _log = logging.getLogger(__name__)

    def __init__(
        self,
        collection_name: str,
        embedding_generator: "EmbeddingGenerator",
        client_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Weaviate vector store.

        :param collection_name: Name of the Weaviate collection
        :param embedding_generator: An instance of EmbeddingGenerator to generate embeddings
        :param client_params: Dictionary of parameters to pass to
            ``weaviate.connect_to_custom``. Missing keys fall back to
            ``DEFAULT_CLIENT_PARAMS``; ``None`` means "use the defaults".
        """
        self.collection_name = collection_name
        self.embedding_generator = embedding_generator

        # Build a fresh dict per instance instead of sharing a mutable default.
        params = {**self.DEFAULT_CLIENT_PARAMS, **(client_params or {})}

        # Initialize Weaviate client with provided parameters
        self.client = weaviate.connect_to_custom(
            skip_init_checks=False,
            **params
        )

        # Create collection if it doesn't exist; do not leak the connection
        # when the setup step fails.
        try:
            if not self._collection_exists():
                self._create_collection()
        except Exception:
            self.client.close()
            raise

    def __enter__(self) -> "WeaviateEngine":
        """Return the engine itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the client connection when leaving a ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the underlying Weaviate client connection."""
        self.client.close()

    def _collection_exists(self) -> bool:
        """
        Check if collection exists in Weaviate.

        The comparison is case-insensitive on BOTH sides; previously only the
        server-side names were lower-cased, so a configured name containing
        upper-case letters was never found.

        :return: True if a collection with the configured name is listed,
            False otherwise (including when the listing call itself fails)
        """
        try:
            existing = self.client.collections.list_all()
        except Exception:
            # Best effort: treat a failed listing as "not found" and let the
            # subsequent create call surface the real problem.
            self._log.exception("Error checking collection existence")
            return False

        wanted = self.collection_name.lower()
        return any(name.lower() == wanted for name in existing)

    def _create_collection(self):
        """
        Create a new collection in Weaviate.

        The collection has two text properties: ``text`` (the main content)
        and ``title`` (created with ``skip_vectorization=True``).

        :raises Exception: re-raises whatever the client raises on failure
        """
        try:
            self.client.collections.create(
                name=self.collection_name,
                # Define properties of metadata
                properties=[
                    wc.Property(
                        name="text",
                        data_type=wc.DataType.TEXT,
                    ),
                    wc.Property(
                        name="title",
                        data_type=wc.DataType.TEXT,
                        skip_vectorization=True,
                    ),
                ],
            )
        except Exception:
            self._log.exception("Error creating collection")
            raise

    def upload_vectors(
        self,
        vectors: Union[np.ndarray, List[List[float]]],
        payload: List[Dict[str, Any]],
        batch_size: int = 100
    ):
        """
        Upload vectors and payload to the Weaviate collection.

        Only two fields of each payload entry are stored: ``description``
        (saved as the ``text`` property) and ``title``. Entries without a
        title get ``"Object <n>"`` where ``n`` is the entry's position in the
        whole upload, so generated titles stay unique across batches.

        :param vectors: A numpy array or list of vectors to upload
        :param payload: A list of dictionaries containing the payload for each vector
        :param batch_size: Number of vectors to upload in a single batch
        :raises ValueError: if ``batch_size`` is not positive or if ``vectors``
            and ``payload`` differ in length
        """
        if batch_size <= 0:
            # A negative step would make the range below empty and silently
            # upload nothing.
            raise ValueError("batch_size must be a positive integer.")
        if not isinstance(vectors, np.ndarray):
            vectors = np.array(vectors)
        if len(vectors) != len(payload):
            raise ValueError("Vectors and payload must have the same length.")

        collection = self.client.collections.get(self.collection_name)

        # Process in batches
        for start in range(0, len(vectors), batch_size):
            stop = start + batch_size
            chunk = zip(payload[start:stop], vectors[start:stop])
            try:
                with collection.batch.dynamic() as batch:
                    # Enumerate from ``start`` so the index is global, not
                    # relative to the current batch.
                    for position, (item, vector) in enumerate(chunk, start=start):
                        batch.add_object(
                            properties={
                                # 'description' is assumed to be the main text field
                                "text": item.get("description", ""),
                                "title": item.get("title", f"Object {position}"),
                            },
                            vector=vector,
                        )
            except Exception:
                self._log.exception(
                    "Error uploading batch starting at index %d", start
                )
                raise

    def search(
        self,
        text: str,
        query_filter: Optional[Dict[str, Any]] = None,
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the closest vectors in the collection based on the input text.

        :param text: The text query to search for
        :param query_filter: Optional filter to apply to the search; it is
            forwarded unchanged to the client's ``near_vector`` call
        :param limit: Number of closest results to return
        :return: List of payloads from the closest vectors. Each payload holds
            the stored ``text``, the stored ``title`` (when present) and the
            vector distance under ``_distance``.
        """
        # Generate embedding for the query text
        vector = self.embedding_generator.generate_embedding(text)
        collection = self.client.collections.get(self.collection_name)

        # Prepare query arguments
        query_args = {
            "near_vector": vector,
            "limit": limit,
            "return_metadata": MetadataQuery(distance=True),
        }

        # Add filter if provided. The v4 query API names this keyword
        # ``filters``; ``filter`` is rejected as an unexpected argument.
        # NOTE(review): the v4 client appears to expect a Filter object here,
        # while build_filter() produces a legacy dict - TODO confirm.
        if query_filter:
            query_args["filters"] = query_filter

        results = collection.query.near_vector(**query_args)

        # Convert results to the same format as QdrantEngine
        payloads = []
        for obj in results.objects:
            properties = obj.properties
            # Copy so the client's result object is never mutated, and
            # tolerate a missing or None 'metadata' property.
            entry = dict(properties.get("metadata") or {})
            entry["text"] = properties.get("text", "")
            if "title" in properties:
                entry.setdefault("title", properties["title"])
            entry["_distance"] = obj.metadata.distance
            payloads.append(entry)
        return payloads

    def build_filter(self, conditions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Build a Weaviate filter from a list of conditions.

        Conditions with a missing/empty ``key`` or a ``match`` of ``None`` are
        skipped. All remaining conditions are combined with ``And`` and each
        one is an equality test on ``metadata.<key>`` against ``str(match)``.

        NOTE(review): upload_vectors() stores only ``text`` and ``title``, so
        a ``metadata.<key>`` path presumably matches nothing - TODO confirm.

        :param conditions: A list of conditions, where each condition is a dictionary with:
            - key: The field name to filter on
            - match: The value to match
        :return: A filter dictionary, or None when no usable condition was given
        """
        operands = []
        for condition in conditions:
            key = condition.get("key")
            match_value = condition.get("match")
            if key and match_value is not None:
                operands.append({
                    "path": [f"metadata.{key}"],
                    "operator": "Equal",
                    "valueString": str(match_value),
                })

        if not operands:
            return None
        return {"operator": "And", "operands": operands}
|