gmf_forge_ai_data.vector_stores
Vector Stores for Platform AI Data Layer.
This module provides vector storage and retrieval capabilities for RAG pipelines. Includes production-ready Azure AI Search, Cosmos DB, and MongoDB integrations as well as an in-memory implementation for testing and development.
Available Vector Stores:
- InMemoryVectorStore: Fast in-memory storage for testing and prototyping
- AzureAISearchVectorStore: Production-ready Azure AI Search integration
- AzureCosmosDBVectorStore: Azure Cosmos DB NoSQL API (vector distance / cosine)
- MongoDBVectorStore: MongoDB Atlas Vector Search ($vectorSearch)
Core Classes:
- BaseVectorStore: Abstract interface for all vector store implementations
- Document: Container for document content, embeddings, and metadata
- SearchResult: Container for search results with relevance scores
```python
"""
Vector Stores for Platform AI Data Layer.

This module provides vector storage and retrieval capabilities for RAG pipelines.
Includes production-ready Azure AI Search, Cosmos DB, and MongoDB integrations
as well as an in-memory implementation for testing and development.

Available Vector Stores:
- InMemoryVectorStore: Fast in-memory storage for testing and prototyping
- AzureAISearchVectorStore: Production-ready Azure AI Search integration
- AzureCosmosDBVectorStore: Azure Cosmos DB NoSQL API (vector distance / cosine)
- MongoDBVectorStore: MongoDB Atlas Vector Search ($vectorSearch)

Core Classes:
- BaseVectorStore: Abstract interface for all vector store implementations
- Document: Container for document content, embeddings, and metadata
- SearchResult: Container for search results with relevance scores
"""

from .base_vector_store import BaseVectorStore, Document, SearchResult
from .in_memory_vector_store import InMemoryVectorStore
from .azure_ai_search_vector_store import AzureAISearchVectorStore
from .cosmos_db_vector_store import AzureCosmosDBVectorStore
from .mongodb_vector_store import MongoDBVectorStore

__all__ = [
    "BaseVectorStore",
    "Document",
    "SearchResult",
    "InMemoryVectorStore",
    "AzureAISearchVectorStore",
    "AzureCosmosDBVectorStore",
    "MongoDBVectorStore",
]

__version__ = "1.0.0"
```
```python
class BaseVectorStore(ABC):
    """
    Abstract base class for vector store implementations.

    Vector stores provide storage and retrieval of document embeddings,
    supporting various search strategies (vector similarity, keyword, hybrid).

    Implementations must provide:
    - Document storage with embeddings
    - Vector similarity search
    - Optional keyword/hybrid search
    - Document lifecycle management (add, update, delete)
    """

    @abstractmethod
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to the vector store.

        Args:
            documents: List of Document objects to add
            generate_embeddings: If True and documents lack embeddings,
                generate them automatically

        Returns:
            List of document IDs that were added

        Raises:
            ValueError: If documents are invalid or embeddings cannot be generated
        """
        pass

    @abstractmethod
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search the vector store.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding of the query (required for vector search)
            top_k: Number of results to return
            filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
            search_type: Type of search - "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance

        Raises:
            ValueError: If required parameters are missing for the search type
        """
        pass

    @abstractmethod
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from the vector store.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        pass

    @abstractmethod
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)

        Returns:
            True if update was successful, False if document not found
        """
        pass

    @abstractmethod
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise
        """
        pass

    @abstractmethod
    def count(self) -> int:
        """
        Get the total number of documents in the vector store.

        Returns:
            Total document count
        """
        pass

    @abstractmethod
    def clear(self) -> None:
        """
        Remove all documents from the vector store.

        Warning: This operation is irreversible.
        """
        pass
```
Abstract base class for vector store implementations.
Vector stores provide storage and retrieval of document embeddings, supporting various search strategies (vector similarity, keyword, hybrid).
Implementations must provide:
- Document storage with embeddings
- Vector similarity search
- Optional keyword/hybrid search
- Document lifecycle management (add, update, delete)
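To make the contract concrete, here is a minimal, self-contained sketch of how Python's `abc` machinery enforces it. `MiniVectorStore`, `IncompleteStore` and `ListStore` are hypothetical names that declare only two of the abstract methods; a subclass that leaves any `@abstractmethod` unimplemented cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class MiniVectorStore(ABC):
    """Hypothetical, trimmed-down stand-in for BaseVectorStore (two methods only)."""

    @abstractmethod
    def add_documents(self, documents: List[Dict[str, Any]]) -> List[str]: ...

    @abstractmethod
    def count(self) -> int: ...


class IncompleteStore(MiniVectorStore):
    # Forgets to implement count(), so the class stays abstract.
    def add_documents(self, documents: List[Dict[str, Any]]) -> List[str]:
        return [d["id"] for d in documents]


class ListStore(MiniVectorStore):
    def __init__(self) -> None:
        self._docs: Dict[str, Dict[str, Any]] = {}

    def add_documents(self, documents: List[Dict[str, Any]]) -> List[str]:
        for d in documents:
            self._docs[d["id"]] = d
        return [d["id"] for d in documents]

    def count(self) -> int:
        return len(self._docs)


# Instantiating a class with unimplemented abstract methods raises TypeError.
instantiation_error = None
try:
    IncompleteStore()
except TypeError as exc:
    instantiation_error = exc

store = ListStore()
store.add_documents([{"id": "a"}, {"id": "b"}])
print(store.count())  # 2
```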
Add documents to the vector store.
Args:
- documents: List of Document objects to add
- generate_embeddings: If True and documents lack embeddings, generate them automatically
Returns: List of document IDs that were added
Raises: ValueError: If documents are invalid or embeddings cannot be generated
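Implementations raise `ValueError` for invalid input. A sketch of an all-or-nothing variant, assuming a hypothetical `Doc` dataclass and a plain dict as the store, validates the whole batch before storing anything, so a bad document late in the list does not leave earlier ones half-added:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Doc:
    """Hypothetical minimal document type for this sketch."""
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def add_batch(store: Dict[str, Doc], documents: List[Doc], dim: int) -> List[str]:
    """Validate every document first, then store them all (all-or-nothing)."""
    for doc in documents:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc!r}")
        if doc.embedding is None:
            raise ValueError(f"Document {doc.id} lacks an embedding")
        if len(doc.embedding) != dim:
            raise ValueError(
                f"Document {doc.id} has dimension {len(doc.embedding)}, expected {dim}"
            )
    # Only reached when the whole batch is valid.
    for doc in documents:
        store[doc.id] = doc
    return [doc.id for doc in documents]


store: Dict[str, Doc] = {}
ids = add_batch(store, [Doc("a", "alpha", [0.1, 0.2])], dim=2)

# The second document has empty content, so neither "b" nor "c" is stored.
batch_error = None
try:
    add_batch(store, [Doc("b", "beta", [0.3, 0.4]), Doc("c", "", [0.5, 0.6])], dim=2)
except ValueError as exc:
    batch_error = exc
```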
Search the vector store.
Args:
- query: Text query (required for keyword/hybrid search)
- query_embedding: Vector embedding of the query (required for vector search)
- top_k: Number of results to return
- filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
- search_type: Type of search - "vector", "keyword", or "hybrid"
Returns: List of SearchResult objects, ordered by relevance
Raises: ValueError: If required parameters are missing for the search type
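The parameter requirements can be checked up front, before any scoring. This sketch mirrors the checks at the top of `InMemoryVectorStore.search`; `check_search_args` and `is_valid` are hypothetical helper names, not part of the package.

```python
from typing import List, Optional

VALID_SEARCH_TYPES = ("vector", "keyword", "hybrid")


def check_search_args(
    query: Optional[str],
    query_embedding: Optional[List[float]],
    search_type: str,
) -> None:
    """Raise ValueError when the inputs required by search_type are missing."""
    if search_type not in VALID_SEARCH_TYPES:
        raise ValueError(f"Invalid search_type: {search_type!r}")
    if search_type in ("vector", "hybrid") and query_embedding is None:
        raise ValueError(f"{search_type} search requires query_embedding")
    if search_type in ("keyword", "hybrid") and query is None:
        raise ValueError(f"{search_type} search requires query text")


def is_valid(query: Optional[str], query_embedding: Optional[List[float]], search_type: str) -> bool:
    try:
        check_search_args(query, query_embedding, search_type)
        return True
    except ValueError:
        return False


print(is_valid(None, [0.1, 0.2], "vector"))  # True
print(is_valid("laptops", None, "hybrid"))   # False: hybrid needs both inputs
```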
Delete documents from the vector store.
Args: document_ids: List of document IDs to delete
Returns: Number of documents successfully deleted
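Because unknown IDs are skipped rather than treated as errors, the return value can be smaller than `len(document_ids)`. A dict-backed sketch of that counting rule, using a hypothetical `delete_ids` helper rather than the library method:

```python
from typing import Dict, List


def delete_ids(store: Dict[str, str], document_ids: List[str]) -> int:
    """Delete the given IDs, skipping unknown ones, and return how many were removed."""
    deleted = 0
    for doc_id in document_ids:
        if doc_id in store:
            del store[doc_id]
            deleted += 1
    return deleted


store = {"a": "alpha", "b": "beta"}
# "missing" is unknown and the repeated "a" is already gone, so only one deletion counts.
removed = delete_ids(store, ["a", "missing", "a"])
```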
Update an existing document.
Args:
- document_id: ID of the document to update
- content: New content (if provided)
- embedding: New embedding (if provided)
- metadata: New metadata (merged with existing)
Returns: True if update was successful, False if document not found
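"Merged with existing" means a dict-style update: existing metadata keys survive unless the new mapping overwrites them. A sketch modeled on `InMemoryVectorStore.update_document`, with a hypothetical `Doc` dataclass and a plain dict standing in for the store:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Doc:
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def update_document(
    store: Dict[str, Doc],
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    doc = store.get(document_id)
    if doc is None:
        return False  # unknown ID: report failure instead of raising
    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = list(embedding)  # copy so the caller's list is not aliased
    if metadata is not None:
        doc.metadata.update(metadata)  # merge: untouched keys are kept
    return True


store = {"d1": Doc("d1", "old text", metadata={"source": "a.pdf", "page": 1})}
ok = update_document(store, "d1", metadata={"page": 2, "author": "Jane"})
missing = update_document(store, "nope", content="x")
```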
Retrieve a document by ID.
Args: document_id: ID of the document to retrieve
Returns: Document if found, None otherwise
````python
@dataclass
class Document:
    """
    Represents a document or chunk with its content and metadata.

    This class is designed to be extensible - you can subclass it to add
    custom fields specific to your application domain.

    Base Attributes:
        id: Unique identifier for the document
        content: The text content of the document
        embedding: Optional vector embedding (list of floats)
        timestamp: Optional timestamp for time-weighted retrieval
        metadata: Flexible dictionary for custom properties

    Design Philosophy:
        The base Document keeps only truly universal fields. Domain-specific
        fields (source, page_number, category, etc.) should be added via
        inheritance to match your application's data model.

    Example - Using base Document with metadata:
        ```python
        doc = Document(
            id="doc_1",
            content="AI is transforming software development...",
            timestamp=datetime.now(),
            metadata={
                "source": "blog.pdf",
                "page_number": 5,
                "author": "John Doe",
                "category": "AI"
            }
        )
        ```

    Example - Custom Document schema via inheritance:
        ```python
        from dataclasses import dataclass
        from datetime import datetime
        from typing import Optional
        from gmf_forge_ai_data.vector_stores import Document

        # kw_only=True (Python 3.10+) lets required fields follow the
        # base class's defaulted fields (embedding, timestamp, metadata)
        @dataclass(kw_only=True)
        class LegalDocument(Document):
            case_number: str
            court: str
            decision_date: datetime
            jurisdiction: str
            source: str = ""  # Add source as typed field
            page_number: Optional[int] = None

        legal_doc = LegalDocument(
            id="case_123",
            content="In the matter of...",
            case_number="2024-CV-12345",
            court="Supreme Court",
            decision_date=datetime(2024, 1, 15),
            jurisdiction="Federal",
            source="court_records.pdf",
            page_number=42
        )
        ```

    Example - E-commerce Product Document:
        ```python
        @dataclass(kw_only=True)
        class ProductDocument(Document):
            sku: str
            category: str
            price: float
            in_stock: bool
            brand: str
            # No source/page_number - not applicable to products

        product = ProductDocument(
            id="prod_456",
            content="High-performance laptop with...",
            sku="LAP-001",
            category="Electronics",
            price=1299.99,
            in_stock=True,
            brand="TechCorp"
        )
        ```
    """
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert Document to dictionary for serialization.

        Returns:
            Dictionary representation of the document
        """
        doc_dict = asdict(self)
        # Convert datetime to ISO string for serialization
        if self.timestamp:
            doc_dict['timestamp'] = self.timestamp.isoformat()
        return doc_dict

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """
        Create Document from dictionary.

        Args:
            data: Dictionary containing document fields

        Returns:
            Document instance (or subclass instance)
        """
        # Work on a shallow copy so the caller's dict is not mutated
        data = dict(data)

        # Handle datetime deserialization
        if 'timestamp' in data and isinstance(data['timestamp'], str):
            data['timestamp'] = datetime.fromisoformat(data['timestamp'])

        # Filter to only fields that exist in the dataclass
        import inspect
        valid_fields = set(inspect.signature(cls).parameters.keys())
        filtered_data = {k: v for k, v in data.items() if k in valid_fields}

        return cls(**filtered_data)

    def update_metadata(self, **kwargs) -> None:
        """
        Update metadata fields.

        Args:
            **kwargs: Key-value pairs to add/update in metadata
        """
        self.metadata.update(kwargs)

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Get a metadata value.

        Args:
            key: Metadata key to retrieve
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)
````
Represents a document or chunk with its content and metadata.
This class is designed to be extensible - you can subclass it to add custom fields specific to your application domain.
Base Attributes:
- id: Unique identifier for the document
- content: The text content of the document
- embedding: Optional vector embedding (list of floats)
- timestamp: Optional timestamp for time-weighted retrieval
- metadata: Flexible dictionary for custom properties
Design Philosophy: The base Document keeps only truly universal fields. Domain-specific fields (source, page_number, category, etc.) should be added via inheritance to match your application's data model.
Example - Using base Document with metadata:
```python
from datetime import datetime
from gmf_forge_ai_data.vector_stores import Document

doc = Document(
    id="doc_1",
    content="AI is transforming software development...",
    timestamp=datetime.now(),
    metadata={
        "source": "blog.pdf",
        "page_number": 5,
        "author": "John Doe",
        "category": "AI"
    }
)
```
Example - Custom Document schema via inheritance:
```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from gmf_forge_ai_data.vector_stores import Document

# kw_only=True (Python 3.10+) lets required fields follow the
# base class's defaulted fields (embedding, timestamp, metadata)
@dataclass(kw_only=True)
class LegalDocument(Document):
    case_number: str
    court: str
    decision_date: datetime
    jurisdiction: str
    source: str = ""  # Add source as typed field
    page_number: Optional[int] = None

legal_doc = LegalDocument(
    id="case_123",
    content="In the matter of...",
    case_number="2024-CV-12345",
    court="Supreme Court",
    decision_date=datetime(2024, 1, 15),
    jurisdiction="Federal",
    source="court_records.pdf",
    page_number=42
)
```
Example - E-commerce Product Document:
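```python
from dataclasses import dataclass
from gmf_forge_ai_data.vector_stores import Document

# kw_only=True (Python 3.10+) is needed for the same reason as above
@dataclass(kw_only=True)
class ProductDocument(Document):
    sku: str
    category: str
    price: float
    in_stock: bool
    brand: str
    # No source/page_number - not applicable to products

product = ProductDocument(
    id="prod_456",
    content="High-performance laptop with...",
    sku="LAP-001",
    category="Electronics",
    price=1299.99,
    in_stock=True,
    brand="TechCorp"
)
```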
Convert Document to dictionary for serialization.
Returns: Dictionary representation of the document
Create Document from dictionary.
Args: data: Dictionary containing document fields
Returns: Document instance (or subclass instance)
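A round-trip sketch of `to_dict` and `from_dict`, using a local `Doc` copy of the base fields (hypothetical name). It follows the source's behavior (ISO-8601 timestamps, unknown keys dropped) and additionally copies the input dict so `from_dict` does not rewrite the caller's `timestamp` in place:

```python
import inspect
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TypeVar

T = TypeVar("T", bound="Doc")


@dataclass
class Doc:
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        if self.timestamp:
            data["timestamp"] = self.timestamp.isoformat()  # datetime -> ISO string
        return data

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        data = dict(data)  # shallow copy: leave the caller's dict untouched
        if isinstance(data.get("timestamp"), str):
            data["timestamp"] = datetime.fromisoformat(data["timestamp"])
        valid = set(inspect.signature(cls).parameters)  # constructor field names
        return cls(**{k: v for k, v in data.items() if k in valid})


original = Doc(
    "doc_1",
    "AI is transforming...",
    timestamp=datetime(2024, 1, 15, 9, 30),
    metadata={"source": "blog.pdf"},
)
payload = original.to_dict()
payload["unknown_field"] = "ignored"  # dropped by from_dict
restored = Doc.from_dict(payload)
```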
Update metadata fields.
Args: **kwargs: Key-value pairs to add/update in metadata
Get a metadata value.
Args:
- key: Metadata key to retrieve
- default: Default value if key not found
Returns: Metadata value or default
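`update_metadata` overwrites existing keys and `get_metadata` falls back to the default. The sketch below, using a local `Doc` class (hypothetical), also shows why `metadata` is declared with `field(default_factory=dict)`: each instance gets its own dictionary.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Doc:
    id: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance

    def update_metadata(self, **kwargs: Any) -> None:
        self.metadata.update(kwargs)

    def get_metadata(self, key: str, default: Any = None) -> Any:
        return self.metadata.get(key, default)


a = Doc("a", "alpha")
b = Doc("b", "beta")
a.update_metadata(source="blog.pdf", page_number=5)
a.update_metadata(page_number=6)  # later calls overwrite existing keys
```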
```python
@dataclass
class SearchResult:
    """
    Represents a search result from a vector store.

    Attributes:
        document: The retrieved document
        score: Similarity score (0.0 to 1.0, higher is better)
        rank: Position in the result list (0-indexed)
    """
    document: Document
    score: float
    rank: int
```
Represents a search result from a vector store.
Attributes:
- document: The retrieved document
- score: Similarity score (0.0 to 1.0, higher is better)
- rank: Position in the result list (0-indexed)
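Ranks follow directly from sorting by score. A sketch of how `InMemoryVectorStore.search` assigns them, using a hypothetical `Result` type that holds a document ID instead of a full `Document`:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Result:
    doc_id: str
    score: float
    rank: int


def rank_scores(scores: Dict[str, float], top_k: int) -> List[Result]:
    """Sort by descending score, keep top_k, and assign 0-indexed ranks."""
    ordered = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
    return [Result(doc_id, float(score), rank) for rank, (doc_id, score) in enumerate(ordered)]


results = rank_scores({"a": 0.42, "b": 0.91, "c": 0.77}, top_k=2)
```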
```python
class InMemoryVectorStore(BaseVectorStore):
    """
    In-memory vector store using numpy for vector operations.

    Features:
    - Fast cosine similarity search using numpy
    - Keyword search using term-overlap (Jaccard) scoring
    - Hybrid search combining vector and keyword scores
    - Metadata filtering
    - No external services (Azure, etc.); only numpy is required

    Ideal for:
    - Unit testing
    - Development and prototyping
    - Small datasets (< 10,000 documents)
    - CI/CD pipelines

    Note: All data is lost when the process terminates.
    """

    def __init__(self, embedding_dimension: int = 1536):
        """
        Initialize the in-memory vector store.

        Args:
            embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
        """
        self.embedding_dimension = embedding_dimension
        self._documents: Dict[str, Document] = {}
        self._embeddings: Optional[np.ndarray] = None
        self._doc_ids: List[str] = []

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to the in-memory store.

        Args:
            documents: List of Document objects to add
            generate_embeddings: If True, validates embeddings are present
                (actual generation should be done externally)

        Returns:
            List of document IDs that were added

        Raises:
            ValueError: If any document is invalid; no document in the batch is stored
        """
        if not documents:
            return []

        # Validate the whole batch first, so an invalid document cannot leave
        # earlier ones stored without a rebuilt embedding matrix
        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document {doc.id} lacks embedding. "
                    "Generate embeddings externally before adding to store."
                )

            if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
                raise ValueError(
                    f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
                    f"expected {self.embedding_dimension}"
                )

        # Store copies so later changes by the caller do not leak into the store
        added_ids = []
        for doc in documents:
            self._documents[doc.id] = deepcopy(doc)
            added_ids.append(doc.id)

        # Rebuild embedding matrix
        self._rebuild_embeddings()

        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search the in-memory vector store.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding of the query (required for vector/hybrid search)
            top_k: Number of results to return
            filters: Metadata filters (e.g., {"source": "doc.pdf"})
            search_type: "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")

        if search_type in ["vector", "hybrid"] and query_embedding is None:
            raise ValueError(f"{search_type} search requires query_embedding")

        if search_type in ["keyword", "hybrid"] and query is None:
            raise ValueError(f"{search_type} search requires query text")

        if not self._documents:
            return []

        # Apply metadata filters
        filtered_docs = self._apply_filters(filters)
        if not filtered_docs:
            return []

        # Calculate scores based on search type
        if search_type == "vector":
            scores = self._vector_search(query_embedding, filtered_docs)
        elif search_type == "keyword":
            scores = self._keyword_search(query, filtered_docs)
        else:  # hybrid
            vector_scores = self._vector_search(query_embedding, filtered_docs)
            keyword_scores = self._keyword_search(query, filtered_docs)
            # Combine scores (50/50 weighting)
            scores = {
                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
                for doc_id in filtered_docs
            }

        # Sort by score and take top_k
        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

        # Build SearchResult objects
        results = []
        for rank, (doc_id, score) in enumerate(sorted_results):
            results.append(SearchResult(
                document=deepcopy(self._documents[doc_id]),
                score=float(score),
                rank=rank
            ))

        return results

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from the store.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        deleted_count = 0
        for doc_id in document_ids:
            if doc_id in self._documents:
                del self._documents[doc_id]
                deleted_count += 1

        # Rebuild embedding matrix
        self._rebuild_embeddings()

        return deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)

        Returns:
            True if update was successful, False if document not found
        """
        if document_id not in self._documents:
            return False

        doc = self._documents[document_id]

        if content is not None:
            doc.content = content

        if embedding is not None:
            if len(embedding) != self.embedding_dimension:
                raise ValueError(
                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
                )
            # Copy so the caller's list is not aliased inside the store
            doc.embedding = list(embedding)
            self._rebuild_embeddings()

        if metadata is not None:
            doc.metadata.update(metadata)

        return True

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise
        """
        doc = self._documents.get(document_id)
        return deepcopy(doc) if doc else None

    def count(self) -> int:
        """Get the total number of documents in the store."""
        return len(self._documents)

    def clear(self) -> None:
        """Remove all documents from the store."""
        self._documents.clear()
        self._embeddings = None
        self._doc_ids = []

    # Private helper methods

    def _rebuild_embeddings(self) -> None:
        """Rebuild the embedding matrix from stored documents."""
        docs_with_embeddings = [
            (doc_id, doc) for doc_id, doc in self._documents.items()
            if doc.embedding is not None
        ]

        if not docs_with_embeddings:
            self._embeddings = None
            self._doc_ids = []
            return

        self._doc_ids = [doc_id for doc_id, _ in docs_with_embeddings]
        embeddings_list = [doc.embedding for _, doc in docs_with_embeddings]
        self._embeddings = np.array(embeddings_list, dtype=np.float32)

    def _apply_filters(self, filters: Optional[Dict[str, Any]]) -> List[str]:
        """
        Apply metadata filters and return list of matching document IDs.

        Each key is looked up first as a Document attribute (e.g. timestamp,
        or typed fields on a Document subclass), then in doc.metadata.
        Special operators: use tuples for range queries
        - (">=", value) for greater than or equal
        - ("<=", value) for less than or equal
        - (">", value) or ("gt", value) for strictly greater than
        - ("<", value) or ("lt", value) for strictly less than
        - ("range", min_val, max_val) for inclusive range queries
        Any other tuple operator raises ValueError.
        """
        if filters is None:
            return list(self._documents.keys())

        matching_ids = []
        for doc_id, doc in self._documents.items():
            matches = True

            for key, value in filters.items():
                # Check document attributes first (timestamp, subclass fields)
                doc_value = getattr(doc, key, None)

                # If not a document attribute, check metadata
                if doc_value is None and key in doc.metadata:
                    doc_value = doc.metadata[key]

                # Documents missing the field never match
                if doc_value is None:
                    matches = False
                    break

                # Handle tuple operators for range queries
                if isinstance(value, tuple):
                    operator = value[0]
                    if operator == "range" and len(value) == 3:
                        ok = value[1] <= doc_value <= value[2]
                    elif operator == ">=":
                        ok = doc_value >= value[1]
                    elif operator in (">", "gt"):
                        ok = doc_value > value[1]
                    elif operator == "<=":
                        ok = doc_value <= value[1]
                    elif operator in ("<", "lt"):
                        ok = doc_value < value[1]
                    else:
                        raise ValueError(f"Unsupported filter operator: {value!r}")
                    if not ok:
                        matches = False
                        break
                # Simple equality check
                elif doc_value != value:
                    matches = False
                    break

            if matches:
                matching_ids.append(doc_id)

        return matching_ids

    def _vector_search(
        self,
        query_embedding: List[float],
        candidate_ids: List[str]
    ) -> Dict[str, float]:
        """
        Perform vector similarity search using cosine similarity.

        Returns:
            Dict mapping document IDs to similarity scores (0.0 to 1.0)
        """
        if self._embeddings is None or len(self._doc_ids) == 0:
            return {}

        query_vector = np.array(query_embedding, dtype=np.float32)

        # Filter embeddings to only candidate documents (set lookup keeps this linear)
        candidate_set = set(candidate_ids)
        candidate_indices = [
            i for i, doc_id in enumerate(self._doc_ids)
            if doc_id in candidate_set
        ]

        if not candidate_indices:
            return {}

        candidate_embeddings = self._embeddings[candidate_indices]
        candidate_doc_ids = [self._doc_ids[i] for i in candidate_indices]

        # Compute cosine similarity
        query_norm = np.linalg.norm(query_vector)
        if query_norm == 0:
            return {doc_id: 0.0 for doc_id in candidate_doc_ids}

        doc_norms = np.linalg.norm(candidate_embeddings, axis=1)
        dot_products = candidate_embeddings.dot(query_vector)

        # Avoid division by zero
        similarities = np.zeros(len(candidate_doc_ids))
        valid_mask = doc_norms > 0
        similarities[valid_mask] = dot_products[valid_mask] / (doc_norms[valid_mask] * query_norm)

        # Convert from [-1, 1] to [0, 1]
        similarities = (similarities + 1) / 2

        return {doc_id: float(sim) for doc_id, sim in zip(candidate_doc_ids, similarities)}

    def _keyword_search(
        self,
        query: str,
        candidate_ids: List[str]
    ) -> Dict[str, float]:
        """
        Perform keyword search scored by Jaccard similarity of lower-cased,
        whitespace-split terms.

        Returns:
            Dict mapping document IDs to relevance scores (0.0 to 1.0)
        """
        query_lower = query.lower()
        query_terms = set(query_lower.split())

        scores = {}
        for doc_id in candidate_ids:
            doc = self._documents[doc_id]
            content_lower = doc.content.lower()
            content_terms = set(content_lower.split())

            # Calculate Jaccard similarity as relevance score
            if not query_terms:
                scores[doc_id] = 0.0
            else:
                intersection = len(query_terms & content_terms)
                union = len(query_terms | content_terms)
                scores[doc_id] = intersection / union if union > 0 else 0.0

        return scores
```
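The filter rules in `_apply_filters` can be tried out in isolation. This sketch assumes each document is a flat dict of fields (the store itself checks Document attributes first and then `metadata`), implements only the `>=`, `<=` and inclusive `range` operators, and raises on any other operator; `matches` and `apply_filters` are hypothetical helpers.

```python
from typing import Any, Dict, List, Optional


def matches(doc_fields: Dict[str, Any], filters: Dict[str, Any]) -> bool:
    """Return True if every filter matches; tuples encode comparison operators."""
    for key, expected in filters.items():
        value = doc_fields.get(key)
        if value is None:
            return False  # documents missing the field never match
        if isinstance(expected, tuple):
            op = expected[0]
            if op == "range" and len(expected) == 3:
                ok = expected[1] <= value <= expected[2]  # inclusive bounds
            elif op == ">=":
                ok = value >= expected[1]
            elif op == "<=":
                ok = value <= expected[1]
            else:
                raise ValueError(f"Unsupported filter operator: {op!r}")
            if not ok:
                return False
        elif value != expected:
            return False  # plain values use equality
    return True


def apply_filters(docs: Dict[str, Dict[str, Any]], filters: Optional[Dict[str, Any]]) -> List[str]:
    if not filters:
        return list(docs)
    return [doc_id for doc_id, fields in docs.items() if matches(fields, filters)]


docs = {
    "d1": {"source": "a.pdf", "page_number": 3},
    "d2": {"source": "a.pdf", "page_number": 12},
    "d3": {"source": "b.pdf", "page_number": 7},
}
in_range = apply_filters(docs, {"source": "a.pdf", "page_number": ("range", 1, 10)})
late_pages = apply_filters(docs, {"page_number": (">=", 7)})
```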
In-memory vector store using numpy for vector operations.
Features:
- Fast cosine similarity search using numpy
- Keyword search using term-overlap (Jaccard) scoring
- Hybrid search combining vector and keyword scores
- Metadata filtering
- No external services (Azure, etc.); only numpy is required
Ideal for:
- Unit testing
- Development and prototyping
- Small datasets (< 10,000 documents)
- CI/CD pipelines
Note: All data is lost when the process terminates.
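The cosine scoring used by `_vector_search` fits in a few numpy lines. This is a sketch (hypothetical `cosine_scores` function) with the same conventions as the source: a zero-norm query scores every row 0.0, while a zero-norm row gets a raw similarity of 0 that the final rescale from [-1, 1] to [0, 1] turns into 0.5.

```python
import numpy as np


def cosine_scores(query: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Cosine similarity of query against each row, mapped from [-1, 1] to [0, 1]."""
    query = np.asarray(query, dtype=np.float32)
    matrix = np.asarray(matrix, dtype=np.float32)
    q_norm = np.linalg.norm(query)
    sims = np.zeros(matrix.shape[0], dtype=np.float32)
    if q_norm == 0:
        return sims  # zero query: every row scores 0.0 (no rescaling)
    row_norms = np.linalg.norm(matrix, axis=1)
    valid = row_norms > 0  # skip zero rows to avoid division by zero
    sims[valid] = (matrix[valid] @ query) / (row_norms[valid] * q_norm)
    return (sims + 1) / 2


matrix = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, 0.0]])
scores = cosine_scores(np.array([1.0, 0.0]), matrix)
zero_query_scores = cosine_scores(np.zeros(2), matrix)
```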
Initialize the in-memory vector store.
Args: embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
Add documents to the in-memory store.
Args:
- documents: List of Document objects to add
- generate_embeddings: If True, validates embeddings are present (actual generation should be done externally)
Returns: List of document IDs that were added
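Each add, delete, or embedding update rebuilds the whole embedding matrix, which costs O(n·d) per call for n stored embeddings of dimension d; that is one reason the class targets small datasets. A sketch of the rebuild step (hypothetical `rebuild_matrix` function) that keeps row order aligned with a list of IDs:

```python
from typing import Dict, List, Optional, Tuple

import numpy as np


def rebuild_matrix(
    embeddings_by_id: Dict[str, Optional[List[float]]]
) -> Tuple[List[str], Optional[np.ndarray]]:
    """Stack non-missing embeddings into a float32 matrix; row i belongs to ids[i]."""
    pairs = [(doc_id, emb) for doc_id, emb in embeddings_by_id.items() if emb is not None]
    if not pairs:
        return [], None  # nothing to search: no matrix at all
    ids = [doc_id for doc_id, _ in pairs]
    matrix = np.array([emb for _, emb in pairs], dtype=np.float32)
    return ids, matrix


ids, matrix = rebuild_matrix({"a": [0.1, 0.2], "b": None, "c": [0.3, 0.4]})
```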
Search the in-memory vector store.
Args:
query: Text query (required for keyword/hybrid search)
query_embedding: Vector embedding of the query (required for vector/hybrid search)
top_k: Number of results to return
filters: Metadata filters (e.g., {"source": "doc.pdf"})
search_type: "vector", "keyword", or "hybrid"
Returns: List of SearchResult objects, ordered by relevance
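The hybrid branch above blends the vector and keyword score maps with equal 0.5 weights, then sorts and truncates to top_k. The helpers `_vector_search` and `_keyword_search` are not shown in this section, so this sketch substitutes two assumptions: cosine similarity stands in for the vector score, and a simple query-term overlap ratio stands in for the keyword score. `cosine`, `keyword_overlap` and `hybrid_rank` are hypothetical names and are not part of the package.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; 0.0 if either vector has zero norm (assumed vector score)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def keyword_overlap(query: str, text: str) -> float:
    """Assumed keyword score: fraction of query terms that appear in the text."""
    terms = set(query.lower().split())
    if not terms:
        return 0.0
    words = set(text.lower().split())
    return len(terms & words) / len(terms)


def hybrid_rank(
    query: str,
    query_embedding: List[float],
    docs: Dict[str, Tuple[str, List[float]]],  # doc_id -> (text, embedding)
    top_k: int = 5,
) -> List[Tuple[str, float]]:
    """Blend both scores 50/50 (as in the listing above) and keep the top_k."""
    scores = {
        doc_id: 0.5 * cosine(query_embedding, emb) + 0.5 * keyword_overlap(query, text)
        for doc_id, (text, emb) in docs.items()
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


docs = {
    "a": ("azure search supports vectors", [1.0, 0.0]),
    "b": ("a recipe for bread", [0.0, 1.0]),
}
print(hybrid_rank("azure search", [1.0, 0.0], docs))  # [('a', 1.0), ('b', 0.0)]
```

A fixed 50/50 blend assumes the two scores are on comparable scales. Cosine similarity can range over [-1, 1], while this overlap ratio stays in [0, 1], so a production scorer might normalize the scores before mixing them.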
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from the store.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        deleted_count = 0
        for doc_id in document_ids:
            if doc_id in self._documents:
                del self._documents[doc_id]
                deleted_count += 1

        # Rebuild embedding matrix
        self._rebuild_embeddings()

        return deleted_count
Delete documents from the store.
Args:
document_ids: List of document IDs to delete
Returns: Number of documents successfully deleted
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)

        Returns:
            True if update was successful, False if document not found
        """
        if document_id not in self._documents:
            return False

        doc = self._documents[document_id]

        if content is not None:
            doc.content = content

        if embedding is not None:
            if len(embedding) != self.embedding_dimension:
                raise ValueError(
                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
                )
            doc.embedding = embedding
            self._rebuild_embeddings()

        if metadata is not None:
            doc.metadata.update(metadata)

        return True
Update an existing document.
Args:
document_id: ID of the document to update
content: New content (if provided)
embedding: New embedding (if provided)
metadata: New metadata (merged with existing)
Returns: True if update was successful, False if document not found
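The update rules above replace content and embedding outright but merge metadata with `dict.update`. Existing metadata keys therefore survive unless the call overwrites them. This sketch reproduces that logic on a stand-in dataclass. `Doc` and `apply_update` are hypothetical names and are not the package's `Document` class or method.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Doc:
    """Hypothetical stand-in for Document with only the fields used here."""
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def apply_update(
    doc: Doc,
    expected_dim: int,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Doc:
    """Mirror the listed update order: content, then embedding, then metadata."""
    if content is not None:
        doc.content = content
    if embedding is not None:
        if len(embedding) != expected_dim:
            raise ValueError(f"Embedding dimension {len(embedding)} != expected {expected_dim}")
        doc.embedding = embedding
    if metadata is not None:
        doc.metadata.update(metadata)  # merge: new keys added, existing keys overwritten
    return doc


doc = Doc("1", "old text", [0.0, 0.0], {"source": "a.pdf", "page": 1})
apply_update(doc, expected_dim=2, metadata={"page": 2, "lang": "en"})
print(doc.metadata)  # {'source': 'a.pdf', 'page': 2, 'lang': 'en'}
```

Content is assigned before the embedding is validated, both here and in the listed `update_document`. A call that passes new content together with a wrong-sized embedding therefore raises `ValueError` after the content has already changed. Because metadata is merged, this method also cannot remove a metadata key.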
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise
        """
        doc = self._documents.get(document_id)
        return deepcopy(doc) if doc else None
Retrieve a document by ID.
Args:
document_id: ID of the document to retrieve
Returns: Document if found, None otherwise
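`get_document` (like `search`) returns a `deepcopy` of the stored document. Callers can therefore mutate what they get back without changing what the store holds. Below is a minimal sketch of that copy-on-read pattern. `TinyStore` and `Doc` are hypothetical names, and copying on `add` is an extra assumption of this sketch, not something the listing shows.

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class Doc:
    """Hypothetical stand-in for Document."""
    id: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class TinyStore:
    """Hypothetical store illustrating copy-on-read."""

    def __init__(self) -> None:
        self._documents: Dict[str, Doc] = {}

    def add(self, doc: Doc) -> None:
        # Keep a private copy so later edits to the caller's object do not leak in
        self._documents[doc.id] = deepcopy(doc)

    def get_document(self, document_id: str) -> Optional[Doc]:
        doc = self._documents.get(document_id)
        # Deep copy: nested containers such as metadata are copied too
        return deepcopy(doc) if doc is not None else None


store = TinyStore()
store.add(Doc("1", "text", {"tag": "a"}))
copy_ = store.get_document("1")
copy_.metadata["tag"] = "b"
print(store.get_document("1").metadata["tag"])  # a
```

A shallow `copy.copy` would share the `metadata` dict, so the mutation above would leak back into the store. `deepcopy` avoids that, at the cost of also copying the embedding list on every read.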
class AzureAISearchVectorStore(BaseVectorStore):
    """
    Azure AI Search vector store for production RAG pipelines.

    Features:
    - Scalable vector search with HNSW (Hierarchical Navigable Small World)
    - Hybrid search (vector + keyword BM25)
    - Semantic ranking powered by Microsoft Bing (a service feature; search() does not enable it)
    - Automatic upload of document fields (universal + custom) as top-level index
      fields; the pre-provisioned index must already define them
    - Efficient filtering on all indexed fields
    - High availability and disaster recovery
    - Managed service (no infrastructure management)

    Prerequisites:
    - Azure AI Search service (Basic tier or higher for vector search)
    - Service endpoint and API key **or** a managed-identity token provider
    - Index created with vector field configuration

    Usage:

    **Simple Mode** (base Document):
    ```python
    # API key auth
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        api_key="...",
        embedding_dimension=1536
    )

    # Managed identity / token provider auth
    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://search.azure.com/.default"  # Azure AI Search scope
    )
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        token_provider=token_provider,
        embedding_dimension=1536
    )
    # Uses base Document with universal fields only
    ```

    **Custom Schema Mode** (domain-specific Document):
    ```python
    from dataclasses import dataclass
    from datetime import datetime
    from typing import Optional

    from gmf_forge_ai_data.vector_stores import Document

    @dataclass
    class LegalDocument(Document):
        case_number: str = ""
        court: str = ""
        jurisdiction: str = ""
        decision_date: Optional[datetime] = None

    # Pass document_type - all fields automatically indexed.
    # If the index was provisioned externally (portal, Bicep, etc.) the
    # vector and content field names will likely differ from the defaults
    # ("embedding" / "content"). Always set vector_field_name and
    # content_field_name to match your actual index schema.
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="legal_documents",
        api_key="...",
        embedding_dimension=1536,
        document_type=LegalDocument,
        vector_field_name="contentVector",  # match your index schema
        content_field_name="chunkContent",  # match your index schema
    )

    # All fields indexed - filter on any field
    results = store.search(
        query_embedding=vector,
        filters={
            "jurisdiction": "Federal",
            "court": "Supreme Court",
            "decision_date": (">=", datetime(2024, 1, 1))
        }
    )
    ```
    """

    def __init__(
        self,
        endpoint: str,
        index_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        vector_field_name: str = "embedding",
        content_field_name: str = "content",
    ):
        """
        Initialize Azure AI Search vector store.

        The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
        before first use. The constructor only establishes client connections
        — it never creates or modifies indexes.

        Supply ``api_key`` or ``token_provider``; a ValueError is raised only if
        neither is given. If both are supplied, ``token_provider`` takes precedence.

        Args:
            endpoint: Azure AI Search service endpoint
            index_name: Name of the search index
            api_key: Azure AI Search API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Azure AI Search** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://search.azure.com/.default"
                    )

                Note: this scope is different from Azure OpenAI / Cognitive Services
                (``https://cognitiveservices.azure.com/.default``) — each service
                requires its own token_provider.
            embedding_dimension: Dimension of vector embeddings (default 1536)
            document_type: Document class (base or custom subclass). All fields will be indexed.
            vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
                Override this when the index was built externally with a different field name
                (e.g. "contentVector", "chunkVector").
            content_field_name: Name of the text content field in the Azure Search index (default "content").
                Override when the index uses a different name (e.g. "chunkContent", "text").
        """
        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
            )
        self.endpoint = endpoint
        self.index_name = index_name
        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_field_name = vector_field_name
        self.content_field_name = content_field_name

        if token_provider:
            credential = _TokenProviderCredential(token_provider)
        else:
            credential = AzureKeyCredential(api_key)

        # Initialize clients
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=credential
        )

    @staticmethod
    def _format_datetime_for_azure(dt: datetime) -> str:
        """
        Convert datetime to Azure Search DateTimeOffset format (ISO 8601 with timezone).

        Args:
            dt: datetime object (timezone-aware or naive)

        Returns:
            ISO 8601 string with timezone (e.g., '2024-01-15T10:30:00Z')
        """
        if dt.tzinfo is None:
            # Timezone-naive datetime: assume UTC and append 'Z'
            return dt.isoformat() + 'Z'
        else:
            # Timezone-aware datetime: use standard ISO format
            return dt.isoformat()

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to Azure AI Search.

        Args:
            documents: List of Document objects to add (can be base or custom subclasses)
            generate_embeddings: If True, validates embeddings are present

        Returns:
            List of document IDs that were added

        Note:
            Custom document fields are automatically serialized and stored.
            All fields remain accessible after retrieval.
        """
        if not documents:
            return []

        # Helper function to format datetime for Azure Search DateTimeOffset
        format_datetime_for_azure = self._format_datetime_for_azure

        # Helper function to serialize datetime objects in dictionaries
        def serialize_for_json(obj):
            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
            if isinstance(obj, datetime):
                return format_datetime_for_azure(obj)
            elif isinstance(obj, dict):
                return {k: serialize_for_json(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [serialize_for_json(item) for item in obj]
            else:
                return obj

        # Prepare documents for upload
        search_documents = []
        added_ids = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document {doc.id} lacks embedding. "
                    "Generate embeddings externally before adding."
                )

            # Serialize entire document (including custom fields) to JSON
            import json
            import dataclasses
            doc_dict = doc.to_dict()

            # Convert datetime objects for JSON serialization
            doc_dict_serializable = serialize_for_json(doc_dict)

            # Convert to Azure Search document format
            # Store base fields directly
            search_doc = {
                "id": doc.id,
                "content": doc.content,
                "embedding": doc.embedding if doc.embedding else [],
                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
            }

            # Add all custom fields as indexed fields
            if dataclasses.is_dataclass(doc):
                for field in dataclasses.fields(doc):
                    # Skip base Document fields (already added above)
                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                        continue

                    field_value = getattr(doc, field.name, None)

                    if field_value is not None:
                        # Serialize datetime fields with Azure DateTimeOffset format
                        if isinstance(field_value, datetime):
                            search_doc[field.name] = format_datetime_for_azure(field_value)
                        # Handle lists/dicts by converting to JSON string
                        elif isinstance(field_value, (list, dict)):
                            search_doc[field.name] = json.dumps(field_value)
                        else:
                            search_doc[field.name] = field_value

            search_documents.append(search_doc)
            added_ids.append(doc.id)

        # Upload to Azure Search
        result = self.search_client.upload_documents(documents=search_documents)

        # Check for failures
        failed_count = sum(1 for r in result if not r.succeeded)
        if failed_count > 0:
            logger.warning(f"{failed_count} documents failed to upload")

        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search Azure AI Search index.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding (required for vector/hybrid search)
            top_k: Number of results to return
            filters: Metadata filters (e.g., {"source": "doc.pdf"})
            search_type: "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        if search_type in ["vector", "hybrid"] and query_embedding is None:
            raise ValueError(f"{search_type} search requires query_embedding")

        if search_type in ["keyword", "hybrid"] and query is None:
            raise ValueError(f"{search_type} search requires query text")

        # Build filter expression
        filter_expr = self._build_filter_expression(filters)

        # Execute search based on type
        if search_type == "vector":
            # Pure vector search
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=None,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        elif search_type == "keyword":
            # Pure keyword search (BM25)
            results = self.search_client.search(
                search_text=query,
                filter=filter_expr,
                top=top_k,
                include_total_count=False
            )

        else:  # hybrid
            # Hybrid search (vector + keyword)
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        # Convert to SearchResult objects
        search_results = []
        for rank, result in enumerate(results):
            import json

            # Deserialize full document data
            doc_data = json.loads(result.get("document_data", "{}") or "{}")

            if not doc_data:
                # Index was built externally — fields are stored directly on the result.
                # Collect all non-Azure-internal fields, excluding the vector blob.
                doc_data = {
                    k: v for k, v in result.items()
                    if not k.startswith("@") and k != self.vector_field_name
                }
                # Remap content field when the index uses a non-standard name.
                if self.content_field_name != "content" and self.content_field_name in doc_data:
                    doc_data["content"] = doc_data.pop(self.content_field_name)

            # Reconstruct document using the configured document_type
            doc = self.document_type.from_dict(doc_data)

            search_results.append(SearchResult(
                document=doc,
                score=result.get("@search.score", 0.0),
                rank=rank
            ))

        return search_results

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from Azure AI Search.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        if not document_ids:
            return 0

        # Prepare delete operations
        delete_docs = [{"id": doc_id} for doc_id in document_ids]

        # Execute delete
        result = self.search_client.delete_documents(documents=delete_docs)

        # Count successful deletions
        deleted_count = sum(1 for r in result if r.succeeded)
        return deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        timestamp: Optional[datetime] = None,
        **custom_fields
    ) -> bool:
        """
        Update an existing document in Azure AI Search.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)
            timestamp: New timestamp (if provided)
            **custom_fields: Any custom fields to update

        Returns:
            True if update was successful, False if document not found
        """
        try:
            import json

            # Retrieve existing document
            existing_doc = self.search_client.get_document(key=document_id)
            doc_data = json.loads(existing_doc.get("document_data", "{}"))

            # Update fields in document data
            if content is not None:
                doc_data["content"] = content
                existing_doc["content"] = content

            if embedding is not None:
                doc_data["embedding"] = embedding
                existing_doc["embedding"] = embedding

            if timestamp is not None:
                doc_data["timestamp"] = self._format_datetime_for_azure(timestamp)
                existing_doc["timestamp"] = self._format_datetime_for_azure(timestamp)

            if metadata is not None:
                if "metadata" not in doc_data:
                    doc_data["metadata"] = {}
                doc_data["metadata"].update(metadata)

            # Update custom fields
            for key, value in custom_fields.items():
                doc_data[key] = value

            # Update document_data
            existing_doc["document_data"] = json.dumps(doc_data)

            # Merge/upload updated document
            result = self.search_client.merge_or_upload_documents(documents=[existing_doc])

            return result[0].succeeded

        except Exception as e:
            logger.error(f"Failed to update document {document_id}: {e}")
            return False

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID from Azure AI Search.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise (returns configured document_type)
        """
        try:
            import json
            result = self.search_client.get_document(key=document_id)

            # Deserialize full document data using configured document_type
            doc_data = json.loads(result.get("document_data", "{}"))
            return self.document_type.from_dict(doc_data)

        except Exception as e:
            logger.debug(f"Document {document_id} not found: {e}")
            return None

    def get_document_by_parent(self, document_id: str) -> List[Document]:
        """
        Retrieve all chunks belonging to a parent document.

        Args:
            document_id: The value of the ``documentId`` field shared across
                all chunks of the same parent document.

        Returns:
            List of Document objects sorted by ``pageNumber``, or an empty
            list if no matching chunks are found.
        """
        results = self.search(
            query="*",
            search_type="keyword",
            filters={"documentId": document_id},
            top_k=1000,
        )
        chunks = [r.document for r in results]
        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
        return chunks

    def count(self) -> int:
        """
        Get the total number of documents in the index.

        Returns:
            Total document count
        """
        # Execute a search to get total count
        results = self.search_client.search(
            search_text="*",
            include_total_count=True,
            top=0
        )

        # Get count from results
        count = getattr(results, 'get_count', lambda: 0)()
        return count if count is not None else 0

    def clear(self) -> None:
        """
        Remove all documents from the index.

        Warning: This operation is irreversible.
        """
        # Search for all document IDs
        results = self.search_client.search(
            search_text="*",
            select=["id"],
            top=1000  # Adjust batch size as needed
        )

        # Collect all IDs
        doc_ids = [r["id"] for r in results]

        # Only one page of IDs (top=1000) is fetched, so indexes holding more
        # than 1000 documents need repeated clear() calls
        if doc_ids:
            self.delete_documents(doc_ids)
            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")

    def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> Optional[str]:
        """
        Build OData filter expression from filter dictionary.

        Supports filtering on all indexed fields (universal + custom fields from document_type).

        Args:
            filters: Dictionary of field: value pairs
                Special operators: use tuples for range queries
                - (">=", value) for greater than or equal; ("gt", value) for strictly greater
                - ("<=", value) for less than or equal; ("lt", value) for strictly less
                - ("range", min_val, max_val) for range queries

        Returns:
            OData filter string or None

        Example:
            ```python
            # Filter on universal field
            filters = {"timestamp": (">=", datetime(2024, 1, 1))}

            # Filter on custom fields (if document_type has them)
            filters = {
                "court": "Supreme Court",
                "case_number": "2024-001",
                "decision_date": ("range", start_date, end_date)
            }
            ```
        """
        if not filters:
            return None

        # Get indexed field names from document_type
        import dataclasses
        indexed_fields = {'id', 'content', 'timestamp'}  # Universal fields

        if dataclasses.is_dataclass(self.document_type):
            for field in dataclasses.fields(self.document_type):
                if field.name not in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                    indexed_fields.add(field.name)

        filter_parts = []

        for key, value in filters.items():
            # Check if field is indexed
            if key not in indexed_fields:
                logger.warning(
                    f"Filtering on '{key}' - field not found in document schema. "
                    f"Indexed fields: {indexed_fields}"
                )
                continue

            # Handle tuple operators for range queries
            if isinstance(value, tuple):
                operator = value[0]
                if operator == "range" and len(value) == 3:
                    # Range query: field >= min and field <= max
                    min_val = value[1]
                    max_val = value[2]
                    if isinstance(min_val, datetime):
                        min_val = self._format_datetime_for_azure(min_val)
                    if isinstance(max_val, datetime):
                        max_val = self._format_datetime_for_azure(max_val)
                    filter_parts.append(f"{key} ge {min_val} and {key} le {max_val}")
                elif operator in (">=", "gt"):
                    # ">=" maps to OData ge (inclusive); "gt" maps to gt (strict)
                    op_str = "ge" if operator == ">=" else "gt"
                    val = value[1]
                    if isinstance(val, datetime):
                        val = self._format_datetime_for_azure(val)
                    filter_parts.append(f"{key} {op_str} {val}")
                elif operator in ("<=", "lt"):
                    # "<=" maps to OData le (inclusive); "lt" maps to lt (strict)
                    op_str = "le" if operator == "<=" else "lt"
                    val = value[1]
                    if isinstance(val, datetime):
                        val = self._format_datetime_for_azure(val)
                    filter_parts.append(f"{key} {op_str} {val}")
            # Handle datetime equality
            elif isinstance(value, datetime):
                filter_parts.append(f"{key} eq {self._format_datetime_for_azure(value)}")
            # Handle string equality (needs quotes in OData)
            elif isinstance(value, str):
                # Escape single quotes in the value
                escaped_value = value.replace("'", "''")
                filter_parts.append(f"{key} eq '{escaped_value}'")
            # Handle boolean
            elif isinstance(value, bool):
                filter_parts.append(f"{key} eq {str(value).lower()}")
            # Handle numeric types
            elif isinstance(value, (int, float)):
                filter_parts.append(f"{key} eq {value}")

        return " and ".join(filter_parts) if filter_parts else None
Azure AI Search vector store for production RAG pipelines.
Features:
- Scalable vector search with HNSW (Hierarchical Navigable Small World)
- Hybrid search (vector + keyword BM25)
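- Semantic ranking powered by Microsoft Bing (a service feature; search() does not enable it)
- Automatic upload of document fields (universal + custom) as top-level index fields; the pre-provisioned index must already define them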
- Efficient filtering on all indexed fields
- High availability and disaster recovery
- Managed service (no infrastructure management)
Prerequisites:
- Azure AI Search service (Basic tier or higher for vector search)
- Service endpoint and API key or a managed-identity token provider
- Index created with vector field configuration
Usage:
Simple Mode (base Document):
# API key auth
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    api_key="...",
    embedding_dimension=1536
)

# Managed identity / token provider auth
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://search.azure.com/.default"  # Azure AI Search scope
)
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    token_provider=token_provider,
    embedding_dimension=1536
)
# Uses base Document with universal fields only
Custom Schema Mode (domain-specific Document):
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from gmf_forge_ai_data.vector_stores import Document

@dataclass
class LegalDocument(Document):
    case_number: str = ""
    court: str = ""
    jurisdiction: str = ""
    decision_date: Optional[datetime] = None

# Pass document_type - all fields automatically indexed.
# If the index was provisioned externally (portal, Bicep, etc.) the
# vector and content field names will likely differ from the defaults
# ("embedding" / "content"). Always set vector_field_name and
# content_field_name to match your actual index schema.
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="legal_documents",
    api_key="...",
    embedding_dimension=1536,
    document_type=LegalDocument,
    vector_field_name="contentVector",  # match your index schema
    content_field_name="chunkContent",  # match your index schema
)

# All fields indexed - filter on any field
results = store.search(
    query_embedding=vector,  # embedding of the query text
    filters={
        "jurisdiction": "Federal",
        "court": "Supreme Court",
        "decision_date": (">=", datetime(2024, 1, 1))
    }
)
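The `filters` dictionary in the example above is translated into an OData `$filter` string by `_build_filter_expression`. This standalone sketch approximates the documented rules: plain equality, `(">=", v)`, `("<=", v)` and `("range", lo, hi)`. `to_odata` and `odata_literal` are hypothetical names. The sketch departs from the class in a few places: it skips the schema check, it quotes string bounds in comparisons (the listed method passes non-datetime bounds through unquoted, which suits numbers), and it raises on unknown operators instead of silently dropping them.

```python
from datetime import datetime
from typing import Any, Dict, Optional


def odata_literal(value: Any) -> str:
    """Render a Python value as an OData literal."""
    if isinstance(value, datetime):
        # Naive datetimes are treated as UTC, as in _format_datetime_for_azure
        return value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()
    if isinstance(value, bool):  # checked before numbers: bool is an int subclass
        return "true" if value else "false"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"  # OData escapes ' as ''
    if isinstance(value, (int, float)):
        return str(value)
    raise TypeError(f"Unsupported filter value: {value!r}")


def to_odata(filters: Optional[Dict[str, Any]]) -> Optional[str]:
    """Build one clause per field and join them with 'and'; None for no filters."""
    if not filters:
        return None
    parts = []
    for field, value in filters.items():
        if isinstance(value, tuple):
            op = value[0]
            if op == "range" and len(value) == 3:
                lo, hi = odata_literal(value[1]), odata_literal(value[2])
                parts.append(f"{field} ge {lo} and {field} le {hi}")
            elif op == ">=":
                parts.append(f"{field} ge {odata_literal(value[1])}")
            elif op == "<=":
                parts.append(f"{field} le {odata_literal(value[1])}")
            else:
                raise ValueError(f"Unsupported operator: {op!r}")
        else:
            parts.append(f"{field} eq {odata_literal(value)}")
    return " and ".join(parts)


print(to_odata({
    "court": "Supreme Court",
    "decision_date": (">=", datetime(2024, 1, 1)),
}))
# court eq 'Supreme Court' and decision_date ge 2024-01-01T00:00:00Z
```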
Initialize Azure AI Search vector store.
The index must be pre-provisioned using AzureAISearchIndexBuilder
before first use. The constructor only establishes client connections
— it never creates or modifies indexes.
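Supply api_key or token_provider; the constructor raises ValueError only if neither is given. If both are supplied, token_provider takes precedence and api_key is ignored.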
Args:
endpoint: Azure AI Search service endpoint
index_name: Name of the search index
api_key: Azure AI Search API key. Use for local development or when managed identity is not available.
token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Azure AI Search scope:
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://search.azure.com/.default"
)
Note: this scope differs from the Azure OpenAI / Cognitive Services scope (https://cognitiveservices.azure.com/.default); each service requires its own token_provider.
embedding_dimension: Dimension of vector embeddings (default 1536)
document_type: Document class (base or custom subclass). All fields will be indexed.
vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
Override this when the index was built externally with a different field name
(e.g. "contentVector", "chunkVector").
content_field_name: Name of the text content field in the Azure Search index (default "content").
Override when the index uses a different name (e.g. "chunkContent", "text").
Add documents to Azure AI Search.
Args:
documents: List of Document objects to add (can be base or custom subclasses)
generate_embeddings: If True, validates embeddings are present
Returns:
List of document IDs that were added
Note:
Custom document fields are automatically serialized and stored. All fields remain accessible after retrieval.
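The field promotion described in the note can be sketched in plain Python. This is a minimal sketch, assuming a hypothetical dataclass `ArticleDoc` standing in for a Document subclass and a hypothetical `to_search_doc` helper (neither is part of this module). It mirrors the loop in `add_documents`: base fields are copied directly, the whole record is kept as a JSON string in `document_data`, datetimes become ISO 8601 strings, list/dict values are JSON-encoded, and unset custom fields are skipped. The exact datetime format produced by `_format_datetime_for_azure` is not shown in the source, so `format_dt` is an assumption.

```python
import dataclasses
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

BASE_FIELDS = {"id", "content", "embedding", "timestamp", "metadata"}


@dataclass
class ArticleDoc:
    # Hypothetical stand-in for a Document subclass with two custom fields.
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    author: Optional[str] = None
    tags: List[str] = field(default_factory=list)


def format_dt(value: datetime) -> str:
    # Assumption: UTC ISO 8601 with a trailing "Z" (DateTimeOffset-compatible).
    # Naive datetimes are interpreted as local time by astimezone().
    return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def to_search_doc(doc: ArticleDoc) -> Dict[str, Any]:
    record = dataclasses.asdict(doc)
    search_doc: Dict[str, Any] = {
        "id": doc.id,
        "content": doc.content,
        "embedding": doc.embedding or [],
        "timestamp": format_dt(doc.timestamp) if doc.timestamp else None,
        # default=str is a shortcut for this sketch; the store converts datetimes explicitly.
        "document_data": json.dumps(record, default=str),
    }
    for f in dataclasses.fields(doc):
        if f.name in BASE_FIELDS:
            continue
        value = getattr(doc, f.name)
        if value is None:
            continue  # unset custom fields are not sent to the index
        if isinstance(value, datetime):
            search_doc[f.name] = format_dt(value)
        elif isinstance(value, (list, dict)):
            search_doc[f.name] = json.dumps(value)
        else:
            search_doc[f.name] = value
    return search_doc
```

Storing lists as JSON strings keeps the promoted fields scalar, at the cost of making them unusable in collection filters.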
```python
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector"
) -> List[SearchResult]:
    """
    Search Azure AI Search index.

    Args:
        query: Text query (required for keyword/hybrid search)
        query_embedding: Vector embedding (required for vector/hybrid search)
        top_k: Number of results to return
        filters: Metadata filters (e.g., {"source": "doc.pdf"})
        search_type: "vector", "keyword", or "hybrid"

    Returns:
        List of SearchResult objects, ordered by relevance
    """
    import json

    if search_type not in ["vector", "keyword", "hybrid"]:
        raise ValueError(f"Invalid search_type: {search_type}")

    if search_type in ["vector", "hybrid"] and query_embedding is None:
        raise ValueError(f"{search_type} search requires query_embedding")

    if search_type in ["keyword", "hybrid"] and query is None:
        raise ValueError(f"{search_type} search requires query text")

    # Build filter expression
    filter_expr = self._build_filter_expression(filters)

    # Execute search based on type
    if search_type == "vector":
        # Pure vector search
        vector_query = VectorizedQuery(
            vector=query_embedding,
            k_nearest_neighbors=top_k,
            fields=self.vector_field_name
        )
        results = self.search_client.search(
            search_text=None,
            vector_queries=[vector_query],
            filter=filter_expr,
            top=top_k
        )

    elif search_type == "keyword":
        # Pure keyword search (BM25)
        results = self.search_client.search(
            search_text=query,
            filter=filter_expr,
            top=top_k,
            include_total_count=False
        )

    else:  # hybrid
        # Hybrid search (vector + keyword)
        vector_query = VectorizedQuery(
            vector=query_embedding,
            k_nearest_neighbors=top_k,
            fields=self.vector_field_name
        )
        results = self.search_client.search(
            search_text=query,
            vector_queries=[vector_query],
            filter=filter_expr,
            top=top_k
        )

    # Convert to SearchResult objects
    search_results = []
    for rank, result in enumerate(results):
        # Deserialize full document data
        doc_data = json.loads(result.get("document_data", "{}") or "{}")

        if not doc_data:
            # Index was built externally: fields are stored directly on the result.
            # Collect all non-Azure-internal fields, excluding the vector blob.
            doc_data = {
                k: v for k, v in result.items()
                if not k.startswith("@") and k != self.vector_field_name
            }
            # Remap content field when the index uses a non-standard name.
            if self.content_field_name != "content" and self.content_field_name in doc_data:
                doc_data["content"] = doc_data.pop(self.content_field_name)

        # Reconstruct document using the configured document_type
        doc = self.document_type.from_dict(doc_data)

        search_results.append(SearchResult(
            document=doc,
            score=result.get("@search.score", 0.0),
            rank=rank
        ))

    return search_results
```
Search Azure AI Search index.
Args:
query: Text query (required for keyword/hybrid search)
query_embedding: Vector embedding (required for vector/hybrid search)
top_k: Number of results to return
filters: Metadata filters (e.g., {"source": "doc.pdf"})
search_type: "vector", "keyword", or "hybrid"
Returns:
List of SearchResult objects, ordered by relevance
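When a hit has no `document_data` (an index built outside this store), `search` rebuilds the record from the raw hit. A minimal sketch of that mapping, assuming the hit is a plain dict; `hit_to_record` is a hypothetical name, and `contentVector` / `chunk` in the usage below are illustrative field names for an external index:

```python
import json
from typing import Any, Dict


def hit_to_record(hit: Dict[str, Any], vector_field: str = "embedding",
                  content_field: str = "content") -> Dict[str, Any]:
    # Prefer the serialized copy written by add_documents(); None or "" counts as missing.
    record = json.loads(hit.get("document_data") or "{}")
    if record:
        return record
    # External index: keep ordinary fields, drop "@search.*" metadata and the vector.
    record = {k: v for k, v in hit.items() if not k.startswith("@") and k != vector_field}
    # Expose a non-standard content field under the name Document expects.
    if content_field != "content" and content_field in record:
        record["content"] = record.pop(content_field)
    return record
```

For example, `hit_to_record({"@search.score": 1.2, "id": "7", "chunk": "text", "contentVector": [0.1]}, vector_field="contentVector", content_field="chunk")` keeps `id`, renames `chunk` to `content`, and drops both the score and the vector.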
```python
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents from Azure AI Search.

    Args:
        document_ids: List of document IDs to delete

    Returns:
        Number of documents successfully deleted
    """
    if not document_ids:
        return 0

    # Prepare delete operations (only the key field is needed)
    delete_docs = [{"id": doc_id} for doc_id in document_ids]

    # Execute delete
    result = self.search_client.delete_documents(documents=delete_docs)

    # Count successful deletions
    deleted_count = sum(1 for r in result if r.succeeded)
    return deleted_count
```
Delete documents from Azure AI Search.
Args:
document_ids: List of document IDs to delete
Returns:
Number of documents successfully deleted
```python
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime] = None,
    **custom_fields
) -> bool:
    """
    Update an existing document in Azure AI Search.

    Args:
        document_id: ID of the document to update
        content: New content (if provided)
        embedding: New embedding (if provided)
        metadata: New metadata (merged with existing)
        timestamp: New timestamp (if provided)
        **custom_fields: Any custom fields to update

    Returns:
        True if the update succeeded, False if the document was not found
        or the update failed
    """
    import json

    try:
        # Retrieve existing document; document_data may be absent or null
        existing_doc = self.search_client.get_document(key=document_id)
        doc_data = json.loads(existing_doc.get("document_data") or "{}")

        # Update fields in both document_data and the indexed copy
        if content is not None:
            doc_data["content"] = content
            existing_doc["content"] = content

        if embedding is not None:
            doc_data["embedding"] = embedding
            existing_doc["embedding"] = embedding

        if timestamp is not None:
            formatted = self._format_datetime_for_azure(timestamp)
            doc_data["timestamp"] = formatted
            existing_doc["timestamp"] = formatted

        if metadata is not None:
            # Merge into existing metadata rather than replacing it
            doc_data["metadata"] = {**(doc_data.get("metadata") or {}), **metadata}

        # Update custom fields in document_data and, where the index already
        # holds them as top-level fields, in the indexed copy as well
        for key, value in custom_fields.items():
            if isinstance(value, datetime):
                value = self._format_datetime_for_azure(value)
            doc_data[key] = value
            if key in existing_doc:
                existing_doc[key] = json.dumps(value) if isinstance(value, (list, dict)) else value

        # Update document_data
        existing_doc["document_data"] = json.dumps(doc_data)

        # Merge/upload updated document
        result = self.search_client.merge_or_upload_documents(documents=[existing_doc])
        return result[0].succeeded

    except Exception as e:
        logger.error(f"Failed to update document {document_id}: {e}")
        return False
```
Update an existing document in Azure AI Search.
Args:
document_id: ID of the document to update
content: New content (if provided)
embedding: New embedding (if provided)
metadata: New metadata (merged with existing)
timestamp: New timestamp (if provided)
**custom_fields: Any custom fields to update
Returns:
True if the update succeeded, False if the document was not found or the update failed
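`update_document` replaces `content`, `embedding`, `timestamp` and custom fields outright but merges `metadata` key by key. A minimal sketch of that merge applied to a stored `document_data` JSON string, using a hypothetical `merge_document_data` helper (not part of this module):

```python
import json
from typing import Any, Dict, Optional


def merge_document_data(stored: Optional[str], content: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None,
                        **custom_fields: Any) -> str:
    data = json.loads(stored or "{}")
    if content is not None:
        data["content"] = content  # replaced outright
    if metadata is not None:
        # Merged: existing keys survive unless the update overrides them
        data["metadata"] = {**(data.get("metadata") or {}), **metadata}
    data.update(custom_fields)  # custom fields are replaced outright
    return json.dumps(data)
```

Updating `{"lang": "de"}` on a document whose metadata is `{"source": "a.pdf", "lang": "en"}` therefore keeps `source` and changes only `lang`.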
```python
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a document by ID from Azure AI Search.

    Args:
        document_id: ID of the document to retrieve

    Returns:
        Document if found, None otherwise (returns configured document_type)
    """
    import json

    try:
        result = self.search_client.get_document(key=document_id)

        # Deserialize full document data using configured document_type
        doc_data = json.loads(result.get("document_data") or "{}")
        if not doc_data:
            # Externally built index: apply the same field mapping as search()
            doc_data = {
                k: v for k, v in result.items()
                if not k.startswith("@") and k != self.vector_field_name
            }
            if self.content_field_name != "content" and self.content_field_name in doc_data:
                doc_data["content"] = doc_data.pop(self.content_field_name)
        return self.document_type.from_dict(doc_data)

    except Exception as e:
        logger.debug(f"Document {document_id} not found: {e}")
        return None
```
Retrieve a document by ID from Azure AI Search.
Args:
document_id: ID of the document to retrieve
Returns:
Document if found, None otherwise (returns configured document_type)
```python
def get_document_by_parent(self, document_id: str) -> List[Document]:
    """
    Retrieve all chunks belonging to a parent document.

    Args:
        document_id: The value of the ``documentId`` field shared across
            all chunks of the same parent document.

    Returns:
        List of Document objects sorted by ``pageNumber``, or an empty
        list if no matching chunks are found.
    """
    # Match-all keyword query filtered on documentId (returns at most 1000 chunks)
    results = self.search(
        query="*",
        search_type="keyword",
        filters={"documentId": document_id},
        top_k=1000,
    )
    chunks = [r.document for r in results]
    # A missing or None pageNumber sorts as 0, so None is never compared with an int
    chunks.sort(key=lambda c: getattr(c, "pageNumber", None) or 0)
    return chunks
```
Retrieve all chunks belonging to a parent document.
Args:
document_id: The value of the documentId field shared across
all chunks of the same parent document.
Returns:
List of Document objects sorted by pageNumber, or an empty
list if no matching chunks are found.
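The ordering step amounts to a stable sort on `pageNumber`. A minimal sketch, assuming a hypothetical `Chunk` type in place of the real Document subclass and treating a missing or `None` page number as 0:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Chunk:
    # Hypothetical chunk type; real results are Document subclasses with a pageNumber field.
    id: str
    pageNumber: Optional[int] = None


def order_chunks(chunks: List[Chunk]) -> List[Chunk]:
    # sorted() is stable, so chunks on the same page keep their search-result order.
    return sorted(chunks, key=lambda c: getattr(c, "pageNumber", None) or 0)
```

Because the sort is stable, two chunks from page 1 stay in the order the search returned them.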
```python
def count(self) -> int:
    """
    Get the total number of documents in the index.

    Returns:
        Total document count
    """
    # Match-all query that requests the total count but no result rows
    results = self.search_client.search(
        search_text="*",
        include_total_count=True,
        top=0
    )

    # get_count() may return None, which is reported as 0
    count = getattr(results, 'get_count', lambda: 0)()
    return count if count is not None else 0
```
Get the total number of documents in the index.
Returns:
Total document count
```python
def clear(self) -> None:
    """
    Remove all documents from the index.

    Warning: This operation is irreversible.
    """
    # Fetch document IDs. A single call retrieves at most 1000 IDs,
    # so indexes larger than that need repeated calls to be emptied.
    results = self.search_client.search(
        search_text="*",
        select=["id"],
        top=1000
    )

    doc_ids = [r["id"] for r in results]

    # Delete the retrieved IDs in one batch
    if doc_ids:
        self.delete_documents(doc_ids)
        logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")
```
Remove all documents from the index.
Warning: This operation is irreversible.
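Azure AI Search documents a limit of 1,000 documents per indexing request, so removing a larger set of IDs means issuing one delete per batch. A minimal sketch of the batching step, using a hypothetical `batched_ids` helper; each yielded batch would be passed to one `delete_documents` call:

```python
from typing import Iterator, List


def batched_ids(ids: List[str], size: int = 1000) -> Iterator[List[str]]:
    # Yield consecutive slices of at most `size` IDs; the last batch may be shorter.
    if size < 1:
        raise ValueError("size must be positive")
    for start in range(0, len(ids), size):
        yield ids[start:start + size]
```

2,500 IDs, for example, become three requests of 1,000, 1,000 and 500.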
```python
class AzureCosmosDBVectorStore(BaseVectorStore):
    """
    Azure Cosmos DB NoSQL API vector store.

    Stores documents in a Cosmos DB NoSQL container and queries them using
    Cosmos DB's vector distance functions and SQL-like query language.

    Features:
    - Vector search using Cosmos DB vector indexing (cosine distance)
    - Keyword search via SQL CONTAINS queries
    - Hybrid search (vector + keyword, equal weighting)
    - Metadata filtering via SQL WHERE clauses
    - Custom ``document_type`` support (any Document subclass)
    - Works against a pre-provisioned container (see ``CosmosDBIndexBuilder``)

    Usage::

        from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore

        store = AzureCosmosDBVectorStore(
            endpoint="https://your-account.documents.azure.com:443/",
            key="your-key",
            database_name="rag_db",
            container_name="documents",
            embedding_dimension=1536,
        )
        store.add_documents(docs_with_embeddings)
        results = store.search(query_embedding=embedding, top_k=5)
    """

    def __init__(
        self,
        endpoint: str,
        key: str,
        database_name: str,
        container_name: str,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the Cosmos DB NoSQL vector store.

        The database and container must be pre-provisioned using
        ``CosmosDBIndexBuilder`` before first use. The constructor only
        establishes client connections — it never creates or modifies
        databases or containers.

        Args:
            endpoint: Cosmos DB account endpoint
                (e.g. ``https://your-account.documents.azure.com:443/``).
            key: Cosmos DB account key or resource token.
            database_name: Name of the Cosmos DB database.
            container_name: Name of the container.
            embedding_dimension: Dimension of vector embeddings (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        try:
            from azure.cosmos import CosmosClient
        except ImportError as exc:
            raise ImportError(
                "azure-cosmos is required for AzureCosmosDBVectorStore. "
                "Install it with: pip install azure-cosmos"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self._container_name = container_name

        # Optional custom CA bundle: the azure-core transport option
        # connection_verify accepts a path to a CA bundle file
        connection_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            connection_kwargs["connection_verify"] = ssl_cert_path

        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)

        # Database and container are pre-provisioned by the developer
        self._database = self._client.get_database_client(database_name)
        self._container = self._database.get_container_client(container_name)

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_value(obj: Any) -> Any:
        """Recursively convert datetime objects to ISO strings."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, dict):
            return {k: AzureCosmosDBVectorStore._serialize_value(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [AzureCosmosDBVectorStore._serialize_value(i) for i in obj]
        return obj

    def _to_cosmos(self, doc: Document) -> Dict[str, Any]:
        """Convert a Document to a Cosmos DB item dict."""
        doc_dict_serialisable = self._serialize_value(doc.to_dict())
        item = {
            "id": doc.id,
            "content": doc.content,
            "embedding": doc.embedding or [],
            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
            # Serialised so datetime values in metadata do not break JSON encoding
            "metadata": self._serialize_value(doc.metadata),
            "document_data": json.dumps(doc_dict_serialisable),
        }
        # Promote custom fields (e.g. category, author, year) to top-level
        # so they are filterable via SQL WHERE clauses
        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
        for key, value in doc_dict_serialisable.items():
            if key not in base_keys:
                item[key] = value
        return item

    def _from_cosmos(self, cosmos_item: Dict[str, Any]) -> Document:
        """Reconstruct a Document (or subclass) from a Cosmos DB item."""
        doc_data = json.loads(cosmos_item.get("document_data") or "{}")
        return self.document_type.from_dict(doc_data)

    # ------------------------------------------------------------------
    # Filter helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_where_clause(
        filters: Optional[Dict[str, Any]],
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Translate a user-facing filter dict to a parameterised SQL WHERE clause.

        Values are passed as query parameters (``@f0``, ``@f1``, ...) instead of
        being interpolated into the SQL text, so quotes in values cannot break
        or alter the query. Returns ``(clause, parameters)``.
        """
        if not filters:
            return "", []
        op_map = {"=": "=", "==": "=", "!=": "!=", ">": ">", ">=": ">=", "<": "<", "<=": "<="}
        conditions: List[str] = []
        parameters: List[Dict[str, Any]] = []
        for i, (key, value) in enumerate(filters.items()):
            # Keys become property paths (c.<key>), so only dotted identifiers are allowed
            if not all(part.isidentifier() for part in key.split(".")):
                raise ValueError(f"Invalid filter key: {key!r}")
            sql_op = "="
            if isinstance(value, tuple) and len(value) == 2:
                op, value = value
                if op not in op_map:
                    raise ValueError(f"Unsupported filter operator: {op!r}")
                sql_op = op_map[op]
            if isinstance(value, datetime):
                value = value.isoformat()  # timestamps are stored as ISO strings
            name = f"@f{i}"
            conditions.append(f"c.{key} {sql_op} {name}")
            parameters.append({"name": name, "value": value})
        return " AND ".join(conditions), parameters

    # ------------------------------------------------------------------
    # BaseVectorStore interface
    # ------------------------------------------------------------------

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Upsert documents into the Cosmos DB container.

        Args:
            documents: Documents to add. Each must have ``id`` and ``content``
                set. If *generate_embeddings* is ``True`` (default) every
                document must also carry a pre-computed ``embedding``.
            generate_embeddings: When ``True`` the method validates that all
                documents already have embeddings rather than generating them
                itself (generation is the caller's responsibility).

        Returns:
            List of document IDs that were successfully upserted.
        """
        if not documents:
            return []

        added_ids: List[str] = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document '{doc.id}' has no embedding. "
                    "Generate embeddings before calling add_documents()."
                )

            cosmos_item = self._to_cosmos(doc)
            self._container.upsert_item(cosmos_item)
            added_ids.append(doc.id)

        logger.info(
            "Upserted %d document(s) into Cosmos DB container '%s'",
            len(added_ids),
            self._container_name,
        )
        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Search the Cosmos DB container.

        Args:
            query: Plain-text query string. Required for ``keyword`` and
                ``hybrid`` search types.
            query_embedding: Pre-computed query vector. Required for
                ``vector`` and ``hybrid`` search types.
            top_k: Maximum number of results to return.
            filters: Optional filter dict. Supports equality
                (``{"key": value}``) and range (``{"key": (">=", value)}``)
                operators.
            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

        Returns:
            Ranked list of :class:`SearchResult` objects.
        """
        if search_type not in ("vector", "keyword", "hybrid"):
            raise ValueError(
                f"Invalid search_type '{search_type}'. "
                "Must be 'vector', 'keyword', or 'hybrid'."
            )
        if search_type in ("vector", "hybrid") and query_embedding is None:
            raise ValueError(f"search_type='{search_type}' requires query_embedding")
        if search_type in ("keyword", "hybrid") and query is None:
            raise ValueError(f"search_type='{search_type}' requires query text")

        if search_type == "vector":
            return self._vector_search(query_embedding, top_k, filters)
        if search_type == "keyword":
            return self._keyword_search(query, top_k, filters)
        return self._hybrid_search(query, query_embedding, top_k, filters)

    # --- search helpers -----------------------------------------------

    def _vector_search(
        self,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        where_clause, filter_params = self._build_where_clause(filters)
        where_sql = f"WHERE {where_clause}" if where_clause else ""

        query_text = (
            "SELECT TOP @top_k c.id, c.document_data, "
            "VectorDistance(c.embedding, @embedding) AS score "
            f"FROM c {where_sql} "
            "ORDER BY VectorDistance(c.embedding, @embedding)"
        )
        parameters = [
            {"name": "@top_k", "value": top_k},
            {"name": "@embedding", "value": query_embedding},
            *filter_params,
        ]

        items = list(self._container.query_items(
            query=query_text,
            parameters=parameters,
            enable_cross_partition_query=True,
        ))

        results = []
        for rank, item in enumerate(items):
            doc = self._from_cosmos(item)
            # With the cosine metric, VectorDistance returns a similarity score
            # (1 = identical, -1 = opposite) and ORDER BY VectorDistance lists
            # the most similar items first, so the score is used as-is.
            similarity = float(item.get("score", 0.0))
            results.append(SearchResult(document=doc, score=similarity, rank=rank))
        return results

    def _keyword_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        # One parameterised CONTAINS condition per lower-cased query term
        terms = [t.lower() for t in query.split()]
        if not terms:
            return []
        term_params = [{"name": f"@t{i}", "value": t} for i, t in enumerate(terms)]
        keyword_conditions = " OR ".join(
            f"CONTAINS(LOWER(c.content), {p['name']})" for p in term_params
        )
        where_clause, filter_params = self._build_where_clause(filters)
        filter_sql = f" AND {where_clause}" if where_clause else ""

        # Fetch ALL matching documents; scoring and top_k slicing happen in Python.
        # Applying SELECT TOP before Python scoring could drop high-scoring documents,
        # because result order without ORDER BY is undefined.
        query_text = (
            "SELECT c.id, c.document_data, c.content "
            f"FROM c WHERE ({keyword_conditions}){filter_sql}"
        )

        items = list(self._container.query_items(
            query=query_text,
            parameters=term_params + filter_params,
            enable_cross_partition_query=True,
        ))

        # Score = fraction of query terms present in the content
        # (a simple term-overlap measure, not BM25)
        results = []
        for item in items:
            content_lower = (item.get("content") or "").lower()
            match_count = sum(1 for t in terms if t in content_lower)
            results.append(SearchResult(
                document=self._from_cosmos(item),
                score=match_count / len(terms),
                rank=0,
            ))

        results.sort(key=lambda r: r.score, reverse=True)
        for rank, r in enumerate(results):
            r.rank = rank
        return results[:top_k]

    def _hybrid_search(
        self,
        query: str,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Combine vector and keyword results with equal (50/50) weighting."""
        vector_results = self._vector_search(query_embedding, top_k, filters)
        keyword_results = self._keyword_search(query, top_k, filters)

        def _max(results: List[SearchResult]) -> float:
            return max((r.score for r in results), default=1.0) or 1.0

        v_max = _max(vector_results)
        k_max = _max(keyword_results)

        combined: Dict[str, float] = {}
        doc_map: Dict[str, Document] = {}

        for r in vector_results:
            combined[r.document.id] = 0.5 * (r.score / v_max)
            doc_map[r.document.id] = r.document

        for r in keyword_results:
            combined[r.document.id] = (
                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
            )
            doc_map.setdefault(r.document.id, r.document)

        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
        return [
            SearchResult(
                document=doc_map[doc_id],
                score=combined[doc_id],
                rank=rank,
            )
            for rank, doc_id in enumerate(sorted_ids)
        ]

    # ------------------------------------------------------------------
    # Document lifecycle
    # ------------------------------------------------------------------

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by ID.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            Number of documents actually deleted.
        """
        deleted = 0
        for doc_id in document_ids:
            try:
                # Passing the ID as partition key assumes the container is partitioned on /id
                self._container.delete_item(item=doc_id, partition_key=doc_id)
                deleted += 1
            except Exception:
                logger.debug("Document '%s' not found for deletion", doc_id)
        return deleted

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update.
            content: New text content (optional).
            embedding: New embedding vector (optional).
            metadata: Metadata key-value pairs to *merge* into existing metadata
                (optional).

        Returns:
            ``True`` if the document was found and updated, ``False`` otherwise.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
        except Exception:
            return False

        doc = self._from_cosmos(item)

        if content is not None:
            doc.content = content
        if embedding is not None:
            doc.embedding = embedding
        if metadata is not None:
            doc.metadata.update(metadata)

        self._container.upsert_item(self._to_cosmos(doc))
        return True

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a single document by ID.

        Args:
            document_id: The document's unique identifier.

        Returns:
            The :class:`Document` (or subclass) if found, otherwise ``None``.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
            return self._from_cosmos(item)
        except Exception:
            return None

    def count(self) -> int:
        """Return the total number of documents in the container."""
        query_text = "SELECT VALUE COUNT(1) FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        return items[0] if items else 0

    def clear(self) -> None:
        """
        Remove **all** documents from the container.

        Warning:
            This operation is irreversible.
        """
        query_text = "SELECT c.id FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        for item in items:
            try:
                self._container.delete_item(item=item["id"], partition_key=item["id"])
            except Exception:
                logger.debug("Failed to delete document '%s' during clear", item["id"])
        logger.warning(
            "Cleared all documents from Cosmos DB container '%s'",
            self._container_name,
        )

    def close(self) -> None:
        """Close the underlying Cosmos DB client."""
        # azure-cosmos CosmosClient doesn't require explicit close,
        # but we provide the method for interface consistency
        pass
```
Azure Cosmos DB NoSQL API vector store.
Stores documents in a Cosmos DB NoSQL container and queries them using Cosmos DB's vector distance functions and SQL-like query language.
Features:
- Vector search using Cosmos DB vector indexing (cosine distance)
- Keyword search via SQL CONTAINS queries
- Hybrid search (vector + keyword, equal weighting)
- Metadata filtering via SQL WHERE clauses
- Custom document_type support (any Document subclass)
- Works against a pre-provisioned container (see CosmosDBIndexBuilder); containers are not created automatically
Usage::
from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore
store = AzureCosmosDBVectorStore(
endpoint="https://your-account.documents.azure.com:443/",
key="your-key",
database_name="rag_db",
container_name="documents",
embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
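Metadata filtering turns the filter dict (equality `{"key": value}` or range `{"key": (">=", value)}`) into a SQL WHERE clause. A standalone sketch of that translation that emits Cosmos DB query parameters, so a value such as `O'Brien` needs no quoting. `build_where` is a hypothetical name and the `@p0`, `@p1` parameter names are illustrative; the `{"name": ..., "value": ...}` shape is the one `query_items(parameters=...)` takes:

```python
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

_OPS = {"=": "=", "==": "=", "!=": "!=", ">": ">", ">=": ">=", "<": "<", "<=": "<="}
# Property paths such as "year" or "metadata.source"; anything else is rejected.
_PATH = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")


def build_where(filters: Optional[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]:
    if not filters:
        return "", []
    conditions: List[str] = []
    params: List[Dict[str, Any]] = []
    for i, (key, value) in enumerate(filters.items()):
        if not _PATH.match(key):
            raise ValueError(f"Unsupported filter key: {key!r}")
        op = "="
        if isinstance(value, tuple) and len(value) == 2:
            op_token, value = value
            if op_token not in _OPS:
                raise ValueError(f"Unsupported operator: {op_token!r}")
            op = _OPS[op_token]
        if isinstance(value, datetime):
            value = value.isoformat()  # stored timestamps are ISO strings
        name = f"@p{i}"
        conditions.append(f"c.{key} {op} {name}")
        params.append({"name": name, "value": value})
    return " AND ".join(conditions), params
```

`build_where({"author": "O'Brien", "year": (">=", 2020)})` yields `c.author = @p0 AND c.year >= @p1`, with the raw values carried separately in the parameter list.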
def __init__(self, endpoint: str, key: str, database_name: str, container_name: str, embedding_dimension: int = 1536, document_type: type = Document, ssl_cert_path: Optional[str] = None) -> None
Initialise the Cosmos DB NoSQL vector store.
The database and container must be pre-provisioned using
CosmosDBIndexBuilder before first use. The constructor only
establishes client connections — it never creates or modifies
databases or containers.
Args:
endpoint: Cosmos DB account endpoint
(e.g. https://your-account.documents.azure.com:443/).
key: Cosmos DB account key or resource token.
database_name: Name of the Cosmos DB database.
container_name: Name of the container.
embedding_dimension: Dimension of vector embeddings (default 1536).
document_type: Document class to use when deserialising results.
Pass a custom subclass to support additional typed fields.
ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
verification. Useful in corporate environments with custom
certificate authorities.
def add_documents(self, documents: List[Document], generate_embeddings: bool = True) -> List[str]
Upsert documents into the Cosmos DB container.
Args:
documents: Documents to add. Each must have id and content
set. If generate_embeddings is True (default) every
document must also carry a pre-computed embedding.
generate_embeddings: When True, the method validates that all
documents already have embeddings rather than generating them
itself (generation is the caller's responsibility).
Returns:
List of document IDs that were successfully upserted.
def search(self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = "vector") -> List[SearchResult]
Search the Cosmos DB container.
Args:
query: Plain-text query string. Required for keyword and
hybrid search types.
query_embedding: Pre-computed query vector. Required for
vector and hybrid search types.
top_k: Maximum number of results to return.
filters: Optional filter dict. Supports equality
({"key": value}) and range ({"key": (">=", value)})
operators.
search_type: One of "vector", "keyword", or "hybrid".
Returns:
Ranked list of SearchResult objects.
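The filter format combines exact matches with `(operator, value)` tuples. To make those semantics concrete, the sketch below evaluates such a dict in memory against plain Python dicts, treating entries as combined with AND (as a multi-key MongoDB query document is). It is an illustration, not the query the Cosmos DB store actually issues, since its filter translation is not shown in this section; the operator set is assumed to match the MongoDB store's `_build_filter` (`>=`, `<=`, `>`, `<`, `!=`), and the field names are hypothetical.

```python
import operator
from typing import Any, Dict

# Operator set assumed to match the MongoDB store's _build_filter.
OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "!=": operator.ne,
}


def matches(item: Dict[str, Any], filters: Dict[str, Any]) -> bool:
    """Return True if the item satisfies every filter entry (logical AND)."""
    for key, cond in filters.items():
        value = item.get(key)
        if isinstance(cond, tuple) and len(cond) == 2 and cond[0] in OPS:
            op, target = cond
            # A missing field never satisfies an operator comparison.
            if value is None or not OPS[op](value, target):
                return False
        elif value != cond:
            # Plain value: exact equality.
            return False
    return True


items = [
    {"id": "1", "category": "physics", "published_year": 2019},
    {"id": "2", "category": "physics", "published_year": 2023},
    {"id": "3", "category": "biology", "published_year": 2024},
]
flt = {"category": "physics", "published_year": (">=", 2020)}
print([i["id"] for i in items if matches(i, flt)])  # ['2']
```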
```python
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        Number of documents actually deleted.
    """
    deleted = 0
    for doc_id in document_ids:
        try:
            self._container.delete_item(item=doc_id, partition_key=doc_id)
            deleted += 1
        except Exception:
            logger.debug("Document '%s' not found for deletion", doc_id)
    return deleted
```
Delete documents by ID.
Args:
document_ids: IDs of the documents to remove.
Returns:
Number of documents actually deleted.
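Each ID is also passed as the partition key, and only deletions that succeed are counted. A sketch of that loop against a hypothetical in-memory `FakeContainer` (not the azure-cosmos client), which raises `KeyError` for a missing item to loosely imitate the real client's not-found error:

```python
from typing import Dict, List, Tuple


class FakeContainer:
    """Hypothetical stand-in for a Cosmos container keyed by (partition_key, id)."""

    def __init__(self, ids: List[str]) -> None:
        # The store uses the document ID as the partition key, so both parts match.
        self._items: Dict[Tuple[str, str], dict] = {(i, i): {"id": i} for i in ids}

    def delete_item(self, item: str, partition_key: str) -> None:
        del self._items[(partition_key, item)]  # KeyError if absent


def delete_documents(container: FakeContainer, document_ids: List[str]) -> int:
    """Delete IDs one at a time; missing IDs are skipped and not counted."""
    deleted = 0
    for doc_id in document_ids:
        try:
            container.delete_item(item=doc_id, partition_key=doc_id)
            deleted += 1
        except KeyError:
            pass
    return deleted


container = FakeContainer(["a", "b"])
print(delete_documents(container, ["a", "missing"]))  # 1
```

Two consequences of the real method's design: it costs one request per ID, whereas the MongoDB store deletes in a single `delete_many` call; and because it catches every `Exception`, a throttling or network error is logged as "not found" and silently left out of the count. Passing the ID as the partition key also presumes a container partitioned on `/id`.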
```python
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Update an existing document.

    Args:
        document_id: ID of the document to update.
        content: New text content (optional).
        embedding: New embedding vector (optional).
        metadata: Metadata key-value pairs to *merge* into existing metadata
            (optional).

    Returns:
        ``True`` if the document was found and updated, ``False`` otherwise.
    """
    try:
        item = self._container.read_item(item=document_id, partition_key=document_id)
    except Exception:
        return False

    doc = self._from_cosmos(item)

    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = embedding
    if metadata is not None:
        doc.metadata.update(metadata)

    self._container.upsert_item(self._to_cosmos(doc))
    return True
```
Update an existing document.
Args:
document_id: ID of the document to update.
content: New text content (optional).
embedding: New embedding vector (optional).
metadata: Metadata key-value pairs to merge into existing metadata (optional).
Returns:
True if the document was found and updated, False otherwise.
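`update_document()` is a read-modify-write: fields passed as `None` are left alone, and `metadata` is merged with `dict.update` rather than replaced. A sketch of just those rules on a simplified stand-in dataclass (`Doc` and `apply_update()` are hypothetical names, not part of the package):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Doc:
    """Simplified stand-in for the library's Document (hypothetical)."""
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def apply_update(
    doc: Doc,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Doc:
    """Apply update_document()'s rules: None means 'leave unchanged'."""
    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = embedding
    if metadata is not None:
        # Merge: existing keys survive unless the update overwrites them.
        doc.metadata.update(metadata)
    return doc


doc = Doc("a", "old text", [0.1, 0.2], {"source": "wiki", "lang": "en"})
apply_update(doc, content="new text", metadata={"lang": "fr"})
print(doc.metadata)  # {'source': 'wiki', 'lang': 'fr'}
```

Two consequences follow from the code: a metadata key cannot be removed through this method, and new `content` does not refresh the stored `embedding`, so pass a fresh embedding alongside changed text. The MongoDB store's `update_document()` follows the same rules.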
```python
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a single document by ID.

    Args:
        document_id: The document's unique identifier.

    Returns:
        The :class:`Document` (or subclass) if found, otherwise ``None``.
    """
    try:
        item = self._container.read_item(item=document_id, partition_key=document_id)
        return self._from_cosmos(item)
    except Exception:
        return None
```
Retrieve a single document by ID.
Args:
document_id: The document's unique identifier.
Returns:
The Document (or subclass) if found, otherwise None.
```python
def count(self) -> int:
    """Return the total number of documents in the container."""
    query_text = "SELECT VALUE COUNT(1) FROM c"
    items = list(self._container.query_items(
        query=query_text,
        enable_cross_partition_query=True,
    ))
    return items[0] if items else 0
```
Return the total number of documents in the container.
```python
def clear(self) -> None:
    """
    Remove **all** documents from the container.

    Warning:
        This operation is irreversible.
    """
    query_text = "SELECT c.id FROM c"
    items = list(self._container.query_items(
        query=query_text,
        enable_cross_partition_query=True,
    ))
    for item in items:
        try:
            self._container.delete_item(item=item["id"], partition_key=item["id"])
        except Exception:
            pass
    logger.warning(
        "Cleared all documents from Cosmos DB container '%s'",
        self._container_name,
    )
```
Remove all documents from the container.
Warning: This operation is irreversible.
```python
class MongoDBVectorStore(BaseVectorStore):
    """
    MongoDB Atlas Vector Search vector store.

    Stores documents in a MongoDB collection and queries them using the
    ``$vectorSearch`` aggregation stage. Hybrid search is implemented by
    running vector and text searches separately and combining their normalised
    scores with a 50 / 50 weight.

    Features:
        - Atlas Vector Search (approximate KNN, configurable candidates)
        - Full-text keyword search via ``$text`` index
        - Hybrid search (vector + keyword, equal weighting)
        - Metadata pre-filtering on vector search
        - Custom ``document_type`` support (any Document subclass)

    Note:
        The Atlas Vector Search index (``$vectorSearch``) must be created
        outside of this class — through the MongoDB Atlas UI, Atlas CLI, or
        Atlas Admin API. The ``vector_index_name`` parameter must match the
        name you gave that index.

    Usage::

        from gmf_forge_ai_data.vector_stores import MongoDBVectorStore

        store = MongoDBVectorStore(
            connection_string="mongodb+srv://...",
            database_name="rag_db",
            collection_name="documents",
            vector_index_name="vector_index",
            embedding_dimension=1536,
        )
        store.add_documents(docs_with_embeddings)
        results = store.search(query_embedding=embedding, top_k=5)
    """

    def __init__(
        self,
        connection_string: str,
        database_name: str,
        collection_name: str,
        vector_index_name: str = "vector_index",
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the MongoDB Atlas vector store.

        Args:
            connection_string: MongoDB Atlas connection string
                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
            database_name: Name of the MongoDB database.
            collection_name: Name of the collection.
            vector_index_name: Name of the Atlas Vector Search index defined
                on this collection. Defaults to ``"vector_index"``.
            embedding_dimension: Dimension of vector embeddings. Must match
                the dimension configured in the Atlas Vector Search index
                (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        try:
            import pymongo  # noqa: F401 — checked at construction time
        except ImportError as exc:
            raise ImportError(
                "pymongo is required for MongoDBVectorStore. "
                "Install it with: pip install pymongo"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_index_name = vector_index_name
        self._collection_name = collection_name

        client_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            client_kwargs["tlsCAFile"] = ssl_cert_path

        import pymongo

        self._client: pymongo.MongoClient = pymongo.MongoClient(
            connection_string, **client_kwargs
        )
        self._db = self._client[database_name]
        self._collection = self._db[collection_name]

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_value(obj: Any) -> Any:
        """Recursively convert datetime objects to ISO strings."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, dict):
            return {k: MongoDBVectorStore._serialize_value(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [MongoDBVectorStore._serialize_value(i) for i in obj]
        return obj

    def _to_mongo(self, doc: Document) -> Dict[str, Any]:
        """Convert a Document to a MongoDB document dict."""
        doc_dict = doc.to_dict()
        doc_dict_serialisable = self._serialize_value(doc_dict)
        mongo_doc = {
            "_id": doc.id,
            "id": doc.id,
            "content": doc.content,
            "embedding": doc.embedding or [],
            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
            "metadata": doc.metadata,
            "document_data": json.dumps(doc_dict_serialisable),
        }
        # Promote custom fields (e.g. category, published_year, field, institution)
        # to top level so they are filterable via $vectorSearch filter and $match.
        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
        for key, value in doc_dict_serialisable.items():
            if key not in base_keys:
                mongo_doc[key] = value
        return mongo_doc

    def _from_mongo(self, mongo_doc: Dict[str, Any]) -> Document:
        """Reconstruct a Document (or subclass) from a MongoDB document."""
        doc_data = json.loads(mongo_doc.get("document_data", "{}"))
        return self.document_type.from_dict(doc_data)

    # ------------------------------------------------------------------
    # Filter helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Translate a user-facing filter dict to a MongoDB query expression."""
        if not filters:
            return None
        op_map = {">=": "$gte", "<=": "$lte", ">": "$gt", "<": "$lt", "!=": "$ne"}
        match: Dict[str, Any] = {}
        for key, value in filters.items():
            if isinstance(value, tuple) and len(value) == 2:
                op, val = value
                match[key] = {op_map[op]: val} if op in op_map else val
            else:
                match[key] = value
        return match

    # ------------------------------------------------------------------
    # BaseVectorStore interface
    # ------------------------------------------------------------------

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Upsert documents into the MongoDB collection.

        Args:
            documents: Documents to add. Each must have ``id`` and ``content``
                set. If *generate_embeddings* is ``True`` (default) every
                document must also carry a pre-computed ``embedding``.
            generate_embeddings: When ``True`` the method validates that all
                documents already have embeddings rather than generating them
                itself (generation is the caller's responsibility).

        Returns:
            List of document IDs that were successfully upserted.
        """
        if not documents:
            return []

        added_ids: List[str] = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document '{doc.id}' has no embedding. "
                    "Generate embeddings before calling add_documents()."
                )

            mongo_doc = self._to_mongo(doc)
            self._collection.replace_one(
                {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
            )
            added_ids.append(doc.id)

        logger.info(
            "Upserted %d document(s) into MongoDB collection '%s'",
            len(added_ids),
            self._collection_name,
        )
        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Search the MongoDB collection.

        Args:
            query: Plain-text query string. Required for ``keyword`` and
                ``hybrid`` search types.
            query_embedding: Pre-computed query vector. Required for
                ``vector`` and ``hybrid`` search types.
            top_k: Maximum number of results to return.
            filters: Optional metadata filter dict. Supports equality
                (``{"key": value}``) and range (``{"key": (">=", value)}``)
                operators.
            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

        Returns:
            Ranked list of :class:`SearchResult` objects.
        """
        if search_type not in ("vector", "keyword", "hybrid"):
            raise ValueError(
                f"Invalid search_type '{search_type}'. "
                "Must be 'vector', 'keyword', or 'hybrid'."
            )
        if search_type in ("vector", "hybrid") and query_embedding is None:
            raise ValueError(f"search_type='{search_type}' requires query_embedding")
        if search_type in ("keyword", "hybrid") and query is None:
            raise ValueError(f"search_type='{search_type}' requires query text")

        if search_type == "vector":
            return self._vector_search(query_embedding, top_k, filters)
        if search_type == "keyword":
            return self._keyword_search(query, top_k, filters)
        return self._hybrid_search(query, query_embedding, top_k, filters)

    # --- search helpers -----------------------------------------------

    def _vector_search(
        self,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Execute an Atlas ``$vectorSearch`` aggregation pipeline."""
        vector_stage: Dict[str, Any] = {
            "index": self.vector_index_name,
            "path": "embedding",
            "queryVector": query_embedding,
            # numCandidates controls recall vs. latency trade-off.
            # A value of top_k * 10 (min 100) is a sensible default.
            "numCandidates": max(top_k * 10, 100),
            "limit": top_k,
        }

        pre_filter = self._build_filter(filters)
        if pre_filter:
            vector_stage["filter"] = pre_filter

        pipeline = [
            {"$vectorSearch": vector_stage},
            {
                "$project": {
                    "_id": 1,
                    "document_data": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        try:
            return [
                SearchResult(
                    document=self._from_mongo(raw),
                    score=float(raw.get("score", 0.0)),
                    rank=rank,
                )
                for rank, raw in enumerate(self._collection.aggregate(pipeline))
            ]
        except Exception as exc:
            msg = str(exc)
            if "localhost:28000" in msg or "HostUnreachable" in msg or "PlanExecutor" in msg:
                raise RuntimeError(
                    f"Atlas Vector Search index '{self.vector_index_name}' is not available. "
                    "Create it in the Atlas UI (collection → Search Indexes → Create Search Index → "
                    "Atlas Vector Search) with field='embedding', "
                    f"numDimensions={self.embedding_dimension}, similarity=cosine, "
                    f"indexName='{self.vector_index_name}'."
                ) from exc
            raise

    def _keyword_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Execute a full-text search via the MongoDB ``$text`` operator.

        Requires a text index on the ``content`` field to be created by the
        developer before use::

            db.collection.createIndex({ "content": "text" })

        or via the Atlas UI / CLI / Admin API.
        """
        mongo_filter: Dict[str, Any] = {"$text": {"$search": query}}
        pre_filter = self._build_filter(filters)
        if pre_filter:
            mongo_filter.update(pre_filter)

        cursor = (
            self._collection.find(
                mongo_filter,
                {"document_data": 1, "textScore": {"$meta": "textScore"}},
            )
            .sort([("textScore", {"$meta": "textScore"})])
            .limit(top_k)
        )

        try:
            return [
                SearchResult(
                    document=self._from_mongo(raw),
                    score=float(raw.get("textScore", 0.0)),
                    rank=rank,
                )
                for rank, raw in enumerate(cursor)
            ]
        except Exception as exc:
            msg = str(exc)
            if "text index required" in msg.lower() or "$text" in msg:
                raise RuntimeError(
                    "No text index found on the collection. "
                    "Create one before using keyword or hybrid search: "
                    f"db.{self._collection.name}.createIndex({{ 'content': 'text' }})"
                ) from exc
            raise

    def _hybrid_search(
        self,
        query: str,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Combine vector and keyword results with equal (50/50) weighting."""
        vector_results = self._vector_search(query_embedding, top_k, filters)
        keyword_results = self._keyword_search(query, top_k, filters)

        def _max(results: List[SearchResult]) -> float:
            return max((r.score for r in results), default=1.0) or 1.0

        v_max = _max(vector_results)
        k_max = _max(keyword_results)

        combined: Dict[str, float] = {}
        doc_map: Dict[str, Document] = {}

        for r in vector_results:
            combined[r.document.id] = 0.5 * (r.score / v_max)
            doc_map[r.document.id] = r.document

        for r in keyword_results:
            combined[r.document.id] = (
                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
            )
            doc_map.setdefault(r.document.id, r.document)

        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
        return [
            SearchResult(
                document=doc_map[doc_id],
                score=combined[doc_id],
                rank=rank,
            )
            for rank, doc_id in enumerate(sorted_ids)
        ]

    # ------------------------------------------------------------------
    # Document lifecycle
    # ------------------------------------------------------------------

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by ID.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            Number of documents actually deleted.
        """
        result = self._collection.delete_many({"_id": {"$in": document_ids}})
        return result.deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update.
            content: New text content (optional).
            embedding: New embedding vector (optional).
            metadata: Metadata key-value pairs to *merge* into existing metadata
                (optional).

        Returns:
            ``True`` if the document was found and updated, ``False`` otherwise.
        """
        raw = self._collection.find_one({"_id": document_id})
        if raw is None:
            return False

        doc = self._from_mongo(raw)

        if content is not None:
            doc.content = content
        if embedding is not None:
            doc.embedding = embedding
        if metadata is not None:
            doc.metadata.update(metadata)

        self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
        return True

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a single document by ID.

        Args:
            document_id: The document's unique identifier.

        Returns:
            The :class:`Document` (or subclass) if found, otherwise ``None``.
        """
        raw = self._collection.find_one({"_id": document_id})
        return None if raw is None else self._from_mongo(raw)

    def count(self) -> int:
        """Return the total number of documents in the collection."""
        return self._collection.count_documents({})

    def clear(self) -> None:
        """
        Remove **all** documents from the collection.

        Warning:
            This operation is irreversible.
        """
        self._collection.delete_many({})
        logger.warning(
            "Cleared all documents from MongoDB collection '%s'",
            self._collection_name,
        )

    def close(self) -> None:
        """Close the underlying MongoDB client connection."""
        self._client.close()
```
MongoDB Atlas Vector Search vector store.
Stores documents in a MongoDB collection and queries them using the
$vectorSearch aggregation stage. Hybrid search is implemented by
running vector and text searches separately and combining their normalised
scores with a 50 / 50 weight.
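The 50 / 50 combination normalises each result list by its own top score before weighting. A standalone sketch of that arithmetic, mirroring `_hybrid_search()` but operating on hypothetical `(doc_id, score)` pairs instead of `SearchResult` objects:

```python
from typing import Dict, List, Tuple


def fuse_scores(
    vector_hits: List[Tuple[str, float]],
    keyword_hits: List[Tuple[str, float]],
    top_k: int,
) -> List[Tuple[str, float]]:
    """Max-normalise each list, weight each by 0.5, and sum per document ID."""

    def _max(hits: List[Tuple[str, float]]) -> float:
        # Empty lists and all-zero scores fall back to 1.0 (no division by zero).
        return max((s for _, s in hits), default=1.0) or 1.0

    v_max, k_max = _max(vector_hits), _max(keyword_hits)
    combined: Dict[str, float] = {}
    for doc_id, score in vector_hits:
        combined[doc_id] = 0.5 * (score / v_max)
    for doc_id, score in keyword_hits:
        combined[doc_id] = combined.get(doc_id, 0.0) + 0.5 * (score / k_max)
    ranked = sorted(combined.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:top_k]


vector = [("a", 0.90), ("b", 0.60)]   # e.g. cosine similarities
keyword = [("b", 4.0), ("c", 2.0)]    # e.g. $text relevance scores
# "b" ranks first because it appears in both lists.
print(fuse_scores(vector, keyword, top_k=3))
```

Because each list is divided by its own maximum, the best hit of each retriever contributes exactly 0.5: a document found by only one retriever scores at most 0.5, while one found by both can reach 1.0. The raw score scales (vector similarity versus `$text` relevance) never need to be comparable.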
Features:
- Atlas Vector Search (approximate KNN, configurable candidates)
- Full-text keyword search via $text index
- Hybrid search (vector + keyword, equal weighting)
- Metadata pre-filtering on vector search
- Custom document_type support (any Document subclass)
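The candidate pool for approximate KNN is not a constructor option; `_vector_search()` derives it from `top_k`. The sketch below rebuilds the same aggregation pipeline as plain data so it can be inspected without a cluster. `build_vector_pipeline()` is a hypothetical helper; the stage and field names (`$vectorSearch`, `numCandidates`, `$meta: "vectorSearchScore"`) are the ones the class uses.

```python
from typing import Any, Dict, List, Optional


def build_vector_pipeline(
    query_embedding: List[float],
    top_k: int,
    index_name: str = "vector_index",
    pre_filter: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Build the aggregation pipeline that _vector_search() sends to Atlas."""
    stage: Dict[str, Any] = {
        "index": index_name,
        "path": "embedding",
        "queryVector": query_embedding,
        # Candidate pool: 10x the requested results, never fewer than 100.
        "numCandidates": max(top_k * 10, 100),
        "limit": top_k,
    }
    if pre_filter:
        stage["filter"] = pre_filter
    return [
        {"$vectorSearch": stage},
        {"$project": {"_id": 1, "document_data": 1,
                      "score": {"$meta": "vectorSearchScore"}}},
    ]


pipeline = build_vector_pipeline([0.1, 0.2, 0.3], top_k=25,
                                 pre_filter={"category": "physics"})
print(pipeline[0]["$vectorSearch"]["numCandidates"])  # 250
```

With pymongo the list would be passed to `collection.aggregate(pipeline)`. Any field named in `filter` (such as a promoted `category` field) must also be declared with type `"filter"` in the Atlas Vector Search index; Atlas reports an error otherwise.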
Note:
The Atlas Vector Search index ($vectorSearch) must be created
outside of this class — through the MongoDB Atlas UI, Atlas CLI, or
Atlas Admin API. The vector_index_name parameter must match the
name you gave that index.
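For reference, an Atlas Vector Search index matching the store's defaults (field `embedding`, 1536 dimensions, cosine similarity) can be described as JSON. A minimal sketch that builds that definition; `vector_index_definition()` is a hypothetical helper, and the filter paths are examples to adapt to the fields you actually filter on:

```python
import json
from typing import Any, Dict, Sequence


def vector_index_definition(
    num_dimensions: int = 1536,
    similarity: str = "cosine",
    filter_paths: Sequence[str] = (),
) -> Dict[str, Any]:
    """Atlas Vector Search index definition for the store's 'embedding' field."""
    fields = [
        {
            "type": "vector",
            "path": "embedding",              # field written by _to_mongo()
            "numDimensions": num_dimensions,  # must equal embedding_dimension
            "similarity": similarity,         # "cosine", "euclidean" or "dotProduct"
        }
    ]
    # Fields used in a $vectorSearch "filter" must be indexed with type "filter".
    fields += [{"type": "filter", "path": p} for p in filter_paths]
    return {"fields": fields}


definition = vector_index_definition(1536, filter_paths=["category", "published_year"])
print(json.dumps(definition, indent=2))
```

The printed JSON can be pasted into the Atlas UI's JSON editor when creating an Atlas Vector Search index whose name matches `vector_index_name`. Recent pymongo releases can also submit it programmatically through `SearchIndexModel(definition=..., name=..., type="vectorSearch")` and `collection.create_search_index(...)`; the `type` argument is version-dependent, so check your installed pymongo. Keyword and hybrid modes additionally need a classic text index, for example `collection.create_index([("content", pymongo.TEXT)])`.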
Usage::
from gmf_forge_ai_data.vector_stores import MongoDBVectorStore
store = MongoDBVectorStore(
connection_string="mongodb+srv://...",
database_name="rag_db",
collection_name="documents",
vector_index_name="vector_index",
embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
```python
def __init__(
    self,
    connection_string: str,
    database_name: str,
    collection_name: str,
    vector_index_name: str = "vector_index",
    embedding_dimension: int = 1536,
    document_type: type = Document,
    ssl_cert_path: Optional[str] = None,
) -> None:
    """
    Initialise the MongoDB Atlas vector store.

    Args:
        connection_string: MongoDB Atlas connection string
            (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
        database_name: Name of the MongoDB database.
        collection_name: Name of the collection.
        vector_index_name: Name of the Atlas Vector Search index defined
            on this collection. Defaults to ``"vector_index"``.
        embedding_dimension: Dimension of vector embeddings. Must match
            the dimension configured in the Atlas Vector Search index
            (default 1536).
        document_type: Document class to use when deserialising results.
            Pass a custom subclass to support additional typed fields.
        ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
            verification. Useful in corporate environments with custom
            certificate authorities.
    """
    try:
        import pymongo  # noqa: F401 — checked at construction time
    except ImportError as exc:
        raise ImportError(
            "pymongo is required for MongoDBVectorStore. "
            "Install it with: pip install pymongo"
        ) from exc

    self.embedding_dimension = embedding_dimension
    self.document_type = document_type
    self.vector_index_name = vector_index_name
    self._collection_name = collection_name

    client_kwargs: Dict[str, Any] = {}
    if ssl_cert_path:
        client_kwargs["tlsCAFile"] = ssl_cert_path

    import pymongo

    self._client: pymongo.MongoClient = pymongo.MongoClient(
        connection_string, **client_kwargs
    )
    self._db = self._client[database_name]
    self._collection = self._db[collection_name]
```
Initialise the MongoDB Atlas vector store.
Args:
connection_string: MongoDB Atlas connection string
(e.g. mongodb+srv://user:pass@cluster.mongodb.net/).
database_name: Name of the MongoDB database.
collection_name: Name of the collection.
vector_index_name: Name of the Atlas Vector Search index defined
on this collection. Defaults to "vector_index".
embedding_dimension: Dimension of vector embeddings. Must match
the dimension configured in the Atlas Vector Search index
(default 1536).
document_type: Document class to use when deserialising results.
Pass a custom subclass to support additional typed fields.
ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
verification. Useful in corporate environments with custom
certificate authorities.
```python
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True,
) -> List[str]:
    """
    Upsert documents into the MongoDB collection.

    Args:
        documents: Documents to add. Each must have ``id`` and ``content``
            set. If *generate_embeddings* is ``True`` (default) every
            document must also carry a pre-computed ``embedding``.
        generate_embeddings: When ``True`` the method validates that all
            documents already have embeddings rather than generating them
            itself (generation is the caller's responsibility).

    Returns:
        List of document IDs that were successfully upserted.
    """
    if not documents:
        return []

    added_ids: List[str] = []

    for doc in documents:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")

        if generate_embeddings and doc.embedding is None:
            raise ValueError(
                f"Document '{doc.id}' has no embedding. "
                "Generate embeddings before calling add_documents()."
            )

        mongo_doc = self._to_mongo(doc)
        self._collection.replace_one(
            {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
        )
        added_ids.append(doc.id)

    logger.info(
        "Upserted %d document(s) into MongoDB collection '%s'",
        len(added_ids),
        self._collection_name,
    )
    return added_ids
```
Upsert documents into the MongoDB collection.
Args:
documents: Documents to add. Each must have id and content
set. If generate_embeddings is True (default) every
document must also carry a pre-computed embedding.
generate_embeddings: When True the method validates that all
documents already have embeddings rather than generating them
itself (generation is the caller's responsibility).
Returns:
List of document IDs that were successfully upserted.
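Each upsert writes one MongoDB document per `Document`: the ID doubles as `_id`, a JSON copy of the whole document is kept in `document_data` for round-tripping, and subclass-specific fields are copied to the top level so filters can reach them. A sketch of that mapping on a plain dict, assuming `Document.to_dict()` returns `id`, `content`, `embedding`, `timestamp` and `metadata` plus any subclass fields (that exact shape, and the `category` / `published_year` fields, are assumptions):

```python
import json
from datetime import datetime, timezone
from typing import Any, Dict


def serialize(obj: Any) -> Any:
    """Recursively convert datetimes to ISO strings (mirrors _serialize_value)."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {k: serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [serialize(v) for v in obj]
    return obj


def to_mongo(doc_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch of _to_mongo() applied to an already-dumped document dict."""
    data = serialize(doc_dict)
    mongo_doc = {
        "_id": data["id"],                       # ID doubles as the primary key
        "id": data["id"],
        "content": data["content"],
        "embedding": data.get("embedding") or [],
        "timestamp": data.get("timestamp"),      # ISO string after serialize()
        "metadata": doc_dict.get("metadata", {}),
        "document_data": json.dumps(data),       # full copy for _from_mongo()
    }
    # Promote subclass-specific fields to top level so filters can reach them.
    base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
    mongo_doc.update({k: v for k, v in data.items() if k not in base_keys})
    return mongo_doc


doc_dict = {
    "id": "paper-1",
    "content": "Quantum error correction survey",
    "embedding": [0.1, 0.2],
    "timestamp": datetime(2024, 1, 2, tzinfo=timezone.utc),
    "metadata": {"source": "arxiv"},
    "category": "physics",      # hypothetical subclass field
    "published_year": 2023,     # hypothetical subclass field
}
stored = to_mongo(doc_dict)
print(stored["timestamp"])  # 2024-01-02T00:00:00+00:00
```

`add_documents()` then calls `replace_one({"_id": ...}, mongo_doc, upsert=True)`, so re-adding an existing ID overwrites the stored document wholesale instead of merging fields.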
```python
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Search the MongoDB collection.

    Args:
        query: Plain-text query string. Required for ``keyword`` and
            ``hybrid`` search types.
        query_embedding: Pre-computed query vector. Required for
            ``vector`` and ``hybrid`` search types.
        top_k: Maximum number of results to return.
        filters: Optional metadata filter dict. Supports equality
            (``{"key": value}``) and range (``{"key": (">=", value)}``)
            operators.
        search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

    Returns:
        Ranked list of :class:`SearchResult` objects.
    """
    if search_type not in ("vector", "keyword", "hybrid"):
        raise ValueError(
            f"Invalid search_type '{search_type}'. "
            "Must be 'vector', 'keyword', or 'hybrid'."
        )
    if search_type in ("vector", "hybrid") and query_embedding is None:
        raise ValueError(f"search_type='{search_type}' requires query_embedding")
    if search_type in ("keyword", "hybrid") and query is None:
        raise ValueError(f"search_type='{search_type}' requires query text")

    if search_type == "vector":
        return self._vector_search(query_embedding, top_k, filters)
    if search_type == "keyword":
        return self._keyword_search(query, top_k, filters)
    return self._hybrid_search(query, query_embedding, top_k, filters)
```
Search the MongoDB collection.
Args:
query: Plain-text query string. Required for keyword and
hybrid search types.
query_embedding: Pre-computed query vector. Required for
vector and hybrid search types.
top_k: Maximum number of results to return.
filters: Optional metadata filter dict. Supports equality
({"key": value}) and range ({"key": (">=", value)})
operators.
search_type: One of "vector", "keyword", or "hybrid".
Returns:
Ranked list of SearchResult objects.
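Filters are translated into MongoDB query operators before use. The sketch below reproduces that translation as a standalone function mirroring the class's `_build_filter()`; the field names in the example are hypothetical.

```python
from typing import Any, Dict, Optional

OP_MAP = {">=": "$gte", "<=": "$lte", ">": "$gt", "<": "$lt", "!=": "$ne"}


def build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Translate the user-facing filter dict into a MongoDB query document."""
    if not filters:
        return None
    match: Dict[str, Any] = {}
    for key, value in filters.items():
        if isinstance(value, tuple) and len(value) == 2:
            op, val = value
            # Unknown operators fall back to plain equality on the value.
            match[key] = {OP_MAP[op]: val} if op in OP_MAP else val
        else:
            match[key] = value
    return match


print(build_filter({"category": "physics", "published_year": (">=", 2020)}))
# {'category': 'physics', 'published_year': {'$gte': 2020}}
```

The same dict is used as the `$vectorSearch` `filter` in vector mode and is merged into the `find()` query next to `$text` in keyword mode. Because unrecognised operators fall back to equality, a typo such as `("=>", 2020)` silently becomes an exact match on 2020 instead of raising.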
```python
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        Number of documents actually deleted.
    """
    result = self._collection.delete_many({"_id": {"$in": document_ids}})
    return result.deleted_count
```
Delete documents by ID.
Args:
document_ids: IDs of the documents to remove.
Returns:
Number of documents actually deleted.
```python
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Update an existing document.

    Args:
        document_id: ID of the document to update.
        content: New text content (optional).
        embedding: New embedding vector (optional).
        metadata: Metadata key-value pairs to *merge* into existing metadata
            (optional).

    Returns:
        ``True`` if the document was found and updated, ``False`` otherwise.
    """
    raw = self._collection.find_one({"_id": document_id})
    if raw is None:
        return False

    doc = self._from_mongo(raw)

    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = embedding
    if metadata is not None:
        doc.metadata.update(metadata)

    self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
    return True
```
Update an existing document.
Args:
document_id: ID of the document to update.
content: New text content (optional).
embedding: New embedding vector (optional).
metadata: Metadata key-value pairs to merge into existing metadata (optional).
Returns:
True if the document was found and updated, False otherwise.
```python
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a single document by ID.

    Args:
        document_id: The document's unique identifier.

    Returns:
        The :class:`Document` (or subclass) if found, otherwise ``None``.
    """
    raw = self._collection.find_one({"_id": document_id})
    return None if raw is None else self._from_mongo(raw)
```
Retrieve a single document by ID.
Args:
document_id: The document's unique identifier.
Returns:
The Document (or subclass) if found, otherwise None.
```python
def count(self) -> int:
    """Return the total number of documents in the collection."""
    return self._collection.count_documents({})
```
Return the total number of documents in the collection.
```python
def clear(self) -> None:
    """
    Remove **all** documents from the collection.

    Warning:
        This operation is irreversible.
    """
    self._collection.delete_many({})
    logger.warning(
        "Cleared all documents from MongoDB collection '%s'",
        self._collection_name,
    )
```
Remove all documents from the collection.
Warning: This operation is irreversible.