gmf_forge_ai_data.vector_stores

Vector Stores for Platform AI Data Layer.

This module provides vector storage and retrieval capabilities for RAG pipelines. Includes production-ready Azure AI Search, Cosmos DB, and MongoDB integrations as well as an in-memory implementation for testing and development.

Available Vector Stores:

  • InMemoryVectorStore: Fast in-memory storage for testing and prototyping
  • AzureAISearchVectorStore: Production-ready Azure AI Search integration
  • AzureCosmosDBVectorStore: Azure Cosmos DB NoSQL API (vector distance / cosine)
  • MongoDBVectorStore: MongoDB Atlas Vector Search ($vectorSearch)

Core Classes:

  • BaseVectorStore: Abstract interface for all vector store implementations
  • Document: Container for document content, embeddings, and metadata
  • SearchResult: Container for search results with relevance scores
 1"""
 2Vector Stores for Platform AI Data Layer.
 3
 4This module provides vector storage and retrieval capabilities for RAG pipelines.
 5Includes production-ready Azure AI Search, Cosmos DB, and MongoDB integrations
 6as well as an in-memory implementation for testing and development.
 7
 8Available Vector Stores:
 9- InMemoryVectorStore: Fast in-memory storage for testing and prototyping
10- AzureAISearchVectorStore: Production-ready Azure AI Search integration
11- AzureCosmosDBVectorStore: Azure Cosmos DB NoSQL API (vector distance / cosine)
12- MongoDBVectorStore: MongoDB Atlas Vector Search ($vectorSearch)
13
14Core Classes:
15- BaseVectorStore: Abstract interface for all vector store implementations
16- Document: Container for document content, embeddings, and metadata
17- SearchResult: Container for search results with relevance scores
18"""
19
20from .base_vector_store import BaseVectorStore, Document, SearchResult
21from .in_memory_vector_store import InMemoryVectorStore
22from .azure_ai_search_vector_store import AzureAISearchVectorStore
23from .cosmos_db_vector_store import AzureCosmosDBVectorStore
24from .mongodb_vector_store import MongoDBVectorStore
25
26__all__ = [
27    "BaseVectorStore",
28    "Document",
29    "SearchResult",
30    "InMemoryVectorStore",
31    "AzureAISearchVectorStore",
32    "AzureCosmosDBVectorStore",
33    "MongoDBVectorStore",
34]
35
36__version__ = "1.0.0"
class BaseVectorStore(ABC):
    """
    Common interface implemented by every vector store backend.

    A vector store keeps documents alongside their embeddings and answers
    retrieval queries by vector similarity, keyword matching, or a hybrid of
    the two.

    Concrete stores must supply:
    - persistence of documents together with their embeddings
    - vector similarity search
    - keyword/hybrid search where the backend supports it
    - lifecycle operations (add, update, delete)
    """

    @abstractmethod
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Insert documents into the store.

        Args:
            documents: Document objects to insert.
            generate_embeddings: When True, documents that arrive without an
                embedding should have one generated for them.

        Returns:
            IDs of the documents that were inserted.

        Raises:
            ValueError: A document is invalid or its embedding cannot be
                generated.
        """

    @abstractmethod
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Query the store.

        Args:
            query: Query text; needed for "keyword" and "hybrid" searches.
            query_embedding: Query vector; needed for "vector" searches.
            top_k: Number of results to return.
            filters: Metadata constraints, e.g. {"source": "docs.pdf"}.
            search_type: One of "vector", "keyword", or "hybrid".

        Returns:
            SearchResult objects ordered from most to least relevant.

        Raises:
            ValueError: An input required by ``search_type`` is missing.
        """

    @abstractmethod
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Remove documents from the store.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            How many documents were actually removed.
        """

    @abstractmethod
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Modify a stored document in place.

        Args:
            document_id: ID of the document to modify.
            content: Replacement content, when given.
            embedding: Replacement embedding, when given.
            metadata: Metadata entries to merge into the existing metadata.

        Returns:
            True on success, False when no document has ``document_id``.
        """

    @abstractmethod
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Look up a single document.

        Args:
            document_id: ID of the document to fetch.

        Returns:
            The matching Document, or None when it does not exist.
        """

    @abstractmethod
    def count(self) -> int:
        """
        Report how many documents the store currently holds.

        Returns:
            The total number of stored documents.
        """

    @abstractmethod
    def clear(self) -> None:
        """
        Delete every document in the store.

        Warning: this cannot be undone.
        """

Abstract base class for vector store implementations.

Vector stores provide storage and retrieval of document embeddings, supporting various search strategies (vector similarity, keyword, hybrid).

Implementations must provide:

  • Document storage with embeddings
  • Vector similarity search
  • Optional keyword/hybrid search
  • Document lifecycle management (add, update, delete)
@abstractmethod
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True,
) -> List[str]:
    """
    Insert documents into the store.

    Args:
        documents: Document objects to insert.
        generate_embeddings: When True, documents that arrive without an
            embedding should have one generated for them.

    Returns:
        IDs of the documents that were inserted.

    Raises:
        ValueError: A document is invalid or its embedding cannot be generated.
    """

Add documents to the vector store.

Args:

  • documents: List of Document objects to add
  • generate_embeddings: If True and documents lack embeddings, generate them automatically

Returns: List of document IDs that were added

Raises: ValueError: If documents are invalid or embeddings cannot be generated

@abstractmethod
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Query the store.

    Args:
        query: Query text; needed for "keyword" and "hybrid" searches.
        query_embedding: Query vector; needed for "vector" searches.
        top_k: Number of results to return.
        filters: Metadata constraints, e.g. {"source": "docs.pdf"}.
        search_type: One of "vector", "keyword", or "hybrid".

    Returns:
        SearchResult objects ordered from most to least relevant.

    Raises:
        ValueError: An input required by ``search_type`` is missing.
    """

Search the vector store.

Args:

  • query: Text query (required for keyword/hybrid search)
  • query_embedding: Vector embedding of the query (required for vector search)
  • top_k: Number of results to return
  • filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
  • search_type: Type of search - "vector", "keyword", or "hybrid"

Returns: List of SearchResult objects, ordered by relevance

Raises: ValueError: If required parameters are missing for the search type

@abstractmethod
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Remove documents from the store.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        How many documents were actually removed.
    """

Delete documents from the vector store.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

@abstractmethod
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Modify a stored document in place.

    Args:
        document_id: ID of the document to modify.
        content: Replacement content, when given.
        embedding: Replacement embedding, when given.
        metadata: Metadata entries to merge into the existing metadata.

    Returns:
        True on success, False when no document has ``document_id``.
    """

Update an existing document.

Args:

  • document_id: ID of the document to update
  • content: New content (if provided)
  • embedding: New embedding (if provided)
  • metadata: New metadata (merged with existing)

Returns: True if update was successful, False if document not found

@abstractmethod
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Look up a single document.

    Args:
        document_id: ID of the document to fetch.

    Returns:
        The matching Document, or None when it does not exist.
    """

Retrieve a document by ID.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise

@abstractmethod
def count(self) -> int:
    """
    Report how many documents the store currently holds.

    Returns:
        The total number of stored documents.
    """

Get the total number of documents in the vector store.

Returns: Total document count

@abstractmethod
def clear(self) -> None:
    """
    Delete every document in the store.

    Warning: this cannot be undone.
    """

Remove all documents from the vector store.

Warning: This operation is irreversible.

@dataclass
class Document:
    """
    Represents a document or chunk with its content and metadata.

    This class is designed to be extensible - you can subclass it to add
    custom fields specific to your application domain.

    Base Attributes:
        id: Unique identifier for the document
        content: The text content of the document
        embedding: Optional vector embedding (list of floats)
        timestamp: Optional timestamp for time-weighted retrieval
        metadata: Flexible dictionary for custom properties

    Design Philosophy:
        The base Document keeps only truly universal fields. Domain-specific
        fields (source, page_number, category, etc.) should be added via
        inheritance to match your application's data model.

    Example - Using base Document with metadata:
        ```python
        doc = Document(
            id="doc_1",
            content="AI is transforming software development...",
            timestamp=datetime.now(),
            metadata={
                "source": "blog.pdf",
                "page_number": 5,
                "author": "John Doe",
                "category": "AI"
            }
        )
        ```

    Example - Custom Document schema via inheritance:
        The base fields ``embedding``, ``timestamp`` and ``metadata`` have
        defaults. A subclass that adds required (non-default) fields must
        therefore declare them keyword-only with ``@dataclass(kw_only=True)``
        (Python 3.10+). Otherwise class creation fails with a
        "non-default argument ... follows default argument" TypeError. On
        older Pythons, give every new subclass field a default instead.

        ```python
        from dataclasses import dataclass
        from datetime import datetime
        from typing import Optional
        from gmf_forge_ai_data.vector_stores import Document

        @dataclass(kw_only=True)
        class LegalDocument(Document):
            case_number: str
            court: str
            decision_date: datetime
            jurisdiction: str
            source: str = ""  # Add source as typed field
            page_number: Optional[int] = None

        legal_doc = LegalDocument(
            id="case_123",
            content="In the matter of...",
            case_number="2024-CV-12345",
            court="Supreme Court",
            decision_date=datetime(2024, 1, 15),
            jurisdiction="Federal",
            source="court_records.pdf",
            page_number=42
        )
        ```

    Example - E-commerce Product Document:
        ```python
        @dataclass(kw_only=True)
        class ProductDocument(Document):
            sku: str
            category: str
            price: float
            in_stock: bool
            brand: str
            # No source/page_number - not applicable to products

        product = ProductDocument(
            id="prod_456",
            content="High-performance laptop with...",
            sku="LAP-001",
            category="Electronics",
            price=1299.99,
            in_stock=True,
            brand="TechCorp"
        )
        ```
    """
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert Document to dictionary for serialization.

        ``timestamp`` (when set) is rendered as an ISO-8601 string; every
        other field is produced by ``dataclasses.asdict``.

        Returns:
            Dictionary representation of the document
        """
        doc_dict = asdict(self)
        # Convert datetime to ISO string for serialization
        if self.timestamp:
            doc_dict['timestamp'] = self.timestamp.isoformat()
        return doc_dict

    @classmethod
    def from_dict(cls: "Type[T]", data: Dict[str, Any]) -> "T":
        """
        Create Document from dictionary.

        A string ``timestamp`` is parsed with ``datetime.fromisoformat``, and
        keys that are not constructor parameters of ``cls`` are ignored. The
        caller's dictionary is not modified.

        Args:
            data: Dictionary containing document fields (e.g. from ``to_dict``)

        Returns:
            Document instance (or subclass instance)
        """
        import inspect

        # Copy first so parsing the timestamp never rewrites the caller's dict.
        data = dict(data)

        # Handle datetime deserialization
        if isinstance(data.get('timestamp'), str):
            data['timestamp'] = datetime.fromisoformat(data['timestamp'])

        # Keep only keys that the dataclass constructor accepts.
        valid_fields = set(inspect.signature(cls).parameters)
        filtered_data = {k: v for k, v in data.items() if k in valid_fields}

        return cls(**filtered_data)

    def update_metadata(self, **kwargs) -> None:
        """
        Update metadata fields.

        Args:
            **kwargs: Key-value pairs to add/update in metadata
        """
        self.metadata.update(kwargs)

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Get a metadata value.

        Args:
            key: Metadata key to retrieve
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)

Represents a document or chunk with its content and metadata.

This class is designed to be extensible - you can subclass it to add custom fields specific to your application domain.

Base Attributes:

  • id: Unique identifier for the document
  • content: The text content of the document
  • embedding: Optional vector embedding (list of floats)
  • timestamp: Optional timestamp for time-weighted retrieval
  • metadata: Flexible dictionary for custom properties

Design Philosophy: The base Document keeps only truly universal fields. Domain-specific fields (source, page_number, category, etc.) should be added via inheritance to match your application's data model.

Example - Using base Document with metadata:

# Free-form, domain-specific properties live in ``metadata``.
doc_metadata = {
    "source": "blog.pdf",
    "page_number": 5,
    "author": "John Doe",
    "category": "AI",
}
doc = Document(
    id="doc_1",
    content="AI is transforming software development...",
    timestamp=datetime.now(),
    metadata=doc_metadata,
)

Example - Custom Document schema via inheritance:

from dataclasses import dataclass
from gmf_forge_ai_data.vector_stores import Document

from datetime import datetime
from typing import Optional


# kw_only=True (Python 3.10+) is required here. Document's trailing fields
# (embedding, timestamp, metadata) have defaults, so a plain @dataclass would
# reject these required fields with a "non-default argument ... follows
# default argument" TypeError. On older Pythons, give each new field a default.
@dataclass(kw_only=True)
class LegalDocument(Document):
    case_number: str
    court: str
    decision_date: datetime
    jurisdiction: str
    source: str = ""  # Add source as typed field
    page_number: Optional[int] = None


legal_doc = LegalDocument(
    id="case_123",
    content="In the matter of...",
    case_number="2024-CV-12345",
    court="Supreme Court",
    decision_date=datetime(2024, 1, 15),
    jurisdiction="Federal",
    source="court_records.pdf",
    page_number=42
)

Example - E-commerce Product Document:

# kw_only=True (Python 3.10+) lets these required fields follow Document's
# defaulted fields; a plain @dataclass raises "non-default argument ...
# follows default argument" at class creation.
@dataclass(kw_only=True)
class ProductDocument(Document):
    sku: str
    category: str
    price: float
    in_stock: bool
    brand: str
    # No source/page_number - not applicable to products


product = ProductDocument(
    id="prod_456",
    content="High-performance laptop with...",
    sku="LAP-001",
    category="Electronics",
    price=1299.99,
    in_stock=True,
    brand="TechCorp"
)
Document( id: str, content: str, embedding: Optional[List[float]] = None, timestamp: Optional[datetime.datetime] = None, metadata: Dict[str, Any] = <factory>)
id: str
content: str
embedding: Optional[List[float]] = None
timestamp: Optional[datetime.datetime] = None
metadata: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize the document into a plain dictionary.

    Fields are produced by ``dataclasses.asdict``; a set ``timestamp`` is
    replaced by its ISO-8601 string form.

    Returns:
        Dictionary form of the document.
    """
    serialized = asdict(self)
    stamp = self.timestamp
    if stamp:
        serialized["timestamp"] = stamp.isoformat()
    return serialized

Convert Document to dictionary for serialization.

Returns: Dictionary representation of the document

@classmethod
def from_dict(cls: "Type[T]", data: Dict[str, Any]) -> "T":
    """
    Create Document from dictionary.

    A string ``timestamp`` is parsed with ``datetime.fromisoformat``, and
    keys that are not constructor parameters of ``cls`` are ignored. The
    caller's dictionary is not modified.

    Args:
        data: Dictionary containing document fields (e.g. from ``to_dict``)

    Returns:
        Document instance (or subclass instance)
    """
    import inspect

    # Copy first so parsing the timestamp never rewrites the caller's dict.
    data = dict(data)

    # Handle datetime deserialization
    if isinstance(data.get("timestamp"), str):
        data["timestamp"] = datetime.fromisoformat(data["timestamp"])

    # Keep only keys that the dataclass constructor accepts.
    valid_fields = set(inspect.signature(cls).parameters)
    filtered_data = {k: v for k, v in data.items() if k in valid_fields}

    return cls(**filtered_data)

Create Document from dictionary.

Args: data: Dictionary containing document fields

Returns: Document instance (or subclass instance)

def update_metadata(self, **kwargs) -> None:
    """
    Merge keyword arguments into the document's metadata, in place.

    Existing keys are overwritten; new keys are added.

    Args:
        **kwargs: Entries to write into ``self.metadata``.
    """
    for name, value in kwargs.items():
        self.metadata[name] = value

Update metadata fields.

Args: **kwargs: Key-value pairs to add/update in metadata

def get_metadata(self, key: str, default: Any = None) -> Any:
    """
    Read one entry from the document's metadata.

    Args:
        key: Metadata key to look up.
        default: Value returned when ``key`` is absent.

    Returns:
        The stored value, or ``default`` if the key is missing.
    """
    entries = self.metadata
    return entries[key] if key in entries else default

Get a metadata value.

Args:

  • key: Metadata key to retrieve
  • default: Default value if key not found

Returns: Metadata value or default

@dataclass
class SearchResult:
    """
    A single hit returned by a vector store query.

    Attributes:
        document: The document that matched the query.
        score: Similarity score from 0.0 to 1.0; larger means more relevant.
        rank: Zero-based position of this hit in the result list.
    """

    document: Document  # the retrieved document
    score: float  # relevance score, higher is better
    rank: int  # 0 for the best hit, then 1, 2, ...

Represents a search result from a vector store.

Attributes:

  • document: The retrieved document
  • score: Similarity score (0.0 to 1.0, higher is better)
  • rank: Position in the result list (0-indexed)

SearchResult( document: Document, score: float, rank: int)
document: Document
score: float
rank: int
class InMemoryVectorStore(gmf_forge_ai_data.vector_stores.BaseVectorStore):
 17class InMemoryVectorStore(BaseVectorStore):
 18    """
 19    In-memory vector store using numpy for vector operations.
 20    
 21    Features:
 22    - Fast cosine similarity search using numpy
 23    - Keyword search using simple text matching
 24    - Hybrid search combining vector and keyword scores
 25    - Metadata filtering
 26    - No external dependencies (Azure, etc.)
 27    
 28    Ideal for:
 29    - Unit testing
 30    - Development and prototyping
 31    - Small datasets (< 10,000 documents)
 32    - CI/CD pipelines
 33    
 34    Note: All data is lost when the process terminates.
 35    """
 36    
 37    def __init__(self, embedding_dimension: int = 1536):
 38        """
 39        Initialize the in-memory vector store.
 40        
 41        Args:
 42            embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
 43        """
 44        self.embedding_dimension = embedding_dimension
 45        self._documents: Dict[str, Document] = {}
 46        self._embeddings: Optional[np.ndarray] = None
 47        self._doc_ids: List[str] = []
 48    
 49    def add_documents(
 50        self,
 51        documents: List[Document],
 52        generate_embeddings: bool = True
 53    ) -> List[str]:
 54        """
 55        Add documents to the in-memory store.
 56        
 57        Args:
 58            documents: List of Document objects to add
 59            generate_embeddings: If True, validates embeddings are present
 60                                (actual generation should be done externally)
 61        
 62        Returns:
 63            List of document IDs that were added
 64        """
 65        if not documents:
 66            return []
 67        
 68        added_ids = []
 69        
 70        for doc in documents:
 71            # Validate document
 72            if not doc.id or not doc.content:
 73                raise ValueError(f"Document must have id and content: {doc}")
 74            
 75            if generate_embeddings and doc.embedding is None:
 76                raise ValueError(
 77                    f"Document {doc.id} lacks embedding. "
 78                    "Generate embeddings externally before adding to store."
 79                )
 80            
 81            if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
 82                raise ValueError(
 83                    f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
 84                    f"expected {self.embedding_dimension}"
 85                )
 86            
 87            # Store document
 88            self._documents[doc.id] = deepcopy(doc)
 89            added_ids.append(doc.id)
 90        
 91        # Rebuild embedding matrix
 92        self._rebuild_embeddings()
 93        
 94        return added_ids
 95    
 96    def search(
 97        self,
 98        query: Optional[str] = None,
 99        query_embedding: Optional[List[float]] = None,
100        top_k: int = 5,
101        filters: Optional[Dict[str, Any]] = None,
102        search_type: str = "vector"
103    ) -> List[SearchResult]:
104        """
105        Search the in-memory vector store.
106        
107        Args:
108            query: Text query (required for keyword/hybrid search)
109            query_embedding: Vector embedding of the query (required for vector/hybrid search)
110            top_k: Number of results to return
111            filters: Metadata filters (e.g., {"source": "doc.pdf"})
112            search_type: "vector", "keyword", or "hybrid"
113        
114        Returns:
115            List of SearchResult objects, ordered by relevance
116        """
117        if search_type not in ["vector", "keyword", "hybrid"]:
118            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")
119        
120        if search_type in ["vector", "hybrid"] and query_embedding is None:
121            raise ValueError(f"{search_type} search requires query_embedding")
122        
123        if search_type in ["keyword", "hybrid"] and query is None:
124            raise ValueError(f"{search_type} search requires query text")
125        
126        if not self._documents:
127            return []
128        
129        # Apply metadata filters
130        filtered_docs = self._apply_filters(filters)
131        if not filtered_docs:
132            return []
133        
134        # Calculate scores based on search type
135        if search_type == "vector":
136            scores = self._vector_search(query_embedding, filtered_docs)
137        elif search_type == "keyword":
138            scores = self._keyword_search(query, filtered_docs)
139        else:  # hybrid
140            vector_scores = self._vector_search(query_embedding, filtered_docs)
141            keyword_scores = self._keyword_search(query, filtered_docs)
142            # Combine scores (50/50 weighting)
143            scores = {
144                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
145                for doc_id in filtered_docs
146            }
147        
148        # Sort by score and take top_k
149        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
150        
151        # Build SearchResult objects
152        results = []
153        for rank, (doc_id, score) in enumerate(sorted_results):
154            results.append(SearchResult(
155                document=deepcopy(self._documents[doc_id]),
156                score=float(score),
157                rank=rank
158            ))
159        
160        return results
161    
162    def delete_documents(self, document_ids: List[str]) -> int:
163        """
164        Delete documents from the store.
165        
166        Args:
167            document_ids: List of document IDs to delete
168        
169        Returns:
170            Number of documents successfully deleted
171        """
172        deleted_count = 0
173        for doc_id in document_ids:
174            if doc_id in self._documents:
175                del self._documents[doc_id]
176                deleted_count += 1
177        
178        # Rebuild embedding matrix
179        self._rebuild_embeddings()
180        
181        return deleted_count
182    
183    def update_document(
184        self,
185        document_id: str,
186        content: Optional[str] = None,
187        embedding: Optional[List[float]] = None,
188        metadata: Optional[Dict[str, Any]] = None
189    ) -> bool:
190        """
191        Update an existing document.
192        
193        Args:
194            document_id: ID of the document to update
195            content: New content (if provided)
196            embedding: New embedding (if provided)
197            metadata: New metadata (merged with existing)
198        
199        Returns:
200            True if update was successful, False if document not found
201        """
202        if document_id not in self._documents:
203            return False
204        
205        doc = self._documents[document_id]
206        
207        if content is not None:
208            doc.content = content
209        
210        if embedding is not None:
211            if len(embedding) != self.embedding_dimension:
212                raise ValueError(
213                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
214                )
215            doc.embedding = embedding
216            self._rebuild_embeddings()
217        
218        if metadata is not None:
219            doc.metadata.update(metadata)
220        
221        return True
222    
223    def get_document(self, document_id: str) -> Optional[Document]:
224        """
225        Retrieve a document by ID.
226        
227        Args:
228            document_id: ID of the document to retrieve
229        
230        Returns:
231            Document if found, None otherwise
232        """
233        doc = self._documents.get(document_id)
234        return deepcopy(doc) if doc else None
235    
236    def count(self) -> int:
237        """Get the total number of documents in the store."""
238        return len(self._documents)
239    
240    def clear(self) -> None:
241        """Remove all documents from the store."""
242        self._documents.clear()
243        self._embeddings = None
244        self._doc_ids = []
245    
246    # Private helper methods
247    
248    def _rebuild_embeddings(self) -> None:
249        """Rebuild the embedding matrix from stored documents."""
250        docs_with_embeddings = [
251            (doc_id, doc) for doc_id, doc in self._documents.items()
252            if doc.embedding is not None
253        ]
254        
255        if not docs_with_embeddings:
256            self._embeddings = None
257            self._doc_ids = []
258            return
259        
260        self._doc_ids = [doc_id for doc_id, _ in docs_with_embeddings]
261        embeddings_list = [doc.embedding for _, doc in docs_with_embeddings]
262        self._embeddings = np.array(embeddings_list, dtype=np.float32)
263    
264    def _apply_filters(self, filters: Optional[Dict[str, Any]]) -> List[str]:
265        """
266        Apply metadata filters and return list of matching document IDs.
267        
268        Supports filtering on: source, page_number, timestamp, and metadata fields
269        Special operators: use tuples for range queries
270        - (">=", value) for greater than or equal
271        - ("<=", value) for less than or equal  
272        - ("range", min_val, max_val) for range queries
273        """
274        if filters is None:
275            return list(self._documents.keys())
276        
277        matching_ids = []
278        for doc_id, doc in self._documents.items():
279            matches = True
280            
281            for key, value in filters.items():
282                # Check document attributes (source, page_number, timestamp)
283                doc_value = getattr(doc, key, None)
284                
285                # If not a document attribute, check metadata
286                if doc_value is None and key in doc.metadata:
287                    doc_value = doc.metadata[key]
288                
289                # Skip if field doesn't exist
290                if doc_value is None:
291                    matches = False
292                    break
293                
294                # Handle tuple operators for range queries
295                if isinstance(value, tuple):
296                    operator = value[0]
297                    if operator == "range" and len(value) == 3:
298                        min_val, max_val = value[1], value[2]
299                        if not (min_val <= doc_value <= max_val):
300                            matches = False
301                            break
302                    elif operator in [">=", "gt"]:
303                        if not (doc_value >= value[1]):
304                            matches = False
305                            break
306                    elif operator in ["<=", "lt"]:
307                        if not (doc_value <= value[1]):
308                            matches = False
309                            break
310                # Simple equality check
311                elif doc_value != value:
312                    matches = False
313                    break
314            
315            if matches:
316                matching_ids.append(doc_id)
317        
318        return matching_ids
319    
320    def _vector_search(
321        self,
322        query_embedding: List[float],
323        candidate_ids: List[str]
324    ) -> Dict[str, float]:
325        """
326        Perform vector similarity search using cosine similarity.
327        
328        Returns:
329            Dict mapping document IDs to similarity scores (0.0 to 1.0)
330        """
331        if self._embeddings is None or len(self._doc_ids) == 0:
332            return {}
333        
334        query_vector = np.array(query_embedding, dtype=np.float32)
335        
336        # Filter embeddings to only candidate documents
337        candidate_indices = [
338            i for i, doc_id in enumerate(self._doc_ids)
339            if doc_id in candidate_ids
340        ]
341        
342        if not candidate_indices:
343            return {}
344        
345        candidate_embeddings = self._embeddings[candidate_indices]
346        candidate_doc_ids = [self._doc_ids[i] for i in candidate_indices]
347        
348        # Compute cosine similarity
349        query_norm = np.linalg.norm(query_vector)
350        if query_norm == 0:
351            return {doc_id: 0.0 for doc_id in candidate_doc_ids}
352        
353        doc_norms = np.linalg.norm(candidate_embeddings, axis=1)
354        dot_products = candidate_embeddings.dot(query_vector)
355        
356        # Avoid division by zero
357        similarities = np.zeros(len(candidate_doc_ids))
358        valid_mask = doc_norms > 0
359        similarities[valid_mask] = dot_products[valid_mask] / (doc_norms[valid_mask] * query_norm)
360        
361        # Convert from [-1, 1] to [0, 1]
362        similarities = (similarities + 1) / 2
363        
364        return {doc_id: float(sim) for doc_id, sim in zip(candidate_doc_ids, similarities)}
365    
366    def _keyword_search(
367        self,
368        query: str,
369        candidate_ids: List[str]
370    ) -> Dict[str, float]:
371        """
372        Perform simple keyword search using term matching.
373        
374        Returns:
375            Dict mapping document IDs to relevance scores (0.0 to 1.0)
376        """
377        query_lower = query.lower()
378        query_terms = set(query_lower.split())
379        
380        scores = {}
381        for doc_id in candidate_ids:
382            doc = self._documents[doc_id]
383            content_lower = doc.content.lower()
384            content_terms = set(content_lower.split())
385            
386            # Calculate Jaccard similarity as relevance score
387            if not query_terms:
388                scores[doc_id] = 0.0
389            else:
390                intersection = len(query_terms & content_terms)
391                union = len(query_terms | content_terms)
392                scores[doc_id] = intersection / union if union > 0 else 0.0
393        
394        return scores

In-memory vector store using numpy for vector operations.

Features:

  • Fast cosine similarity search using numpy
  • Keyword search using simple text matching
  • Hybrid search combining vector and keyword scores
  • Metadata filtering
  • No external service dependencies (Azure, etc.); only numpy is required

Ideal for:

  • Unit testing
  • Development and prototyping
  • Small datasets (< 10,000 documents)
  • CI/CD pipelines

Note: All data is lost when the process terminates.

InMemoryVectorStore(embedding_dimension: int = 1536)
37    def __init__(self, embedding_dimension: int = 1536):
38        """
39        Initialize the in-memory vector store.
40        
41        Args:
42            embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
43        """
44        self.embedding_dimension = embedding_dimension
45        self._documents: Dict[str, Document] = {}
46        self._embeddings: Optional[np.ndarray] = None
47        self._doc_ids: List[str] = []

Initialize the in-memory vector store.

Args: embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)

embedding_dimension
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
49    def add_documents(
50        self,
51        documents: List[Document],
52        generate_embeddings: bool = True
53    ) -> List[str]:
54        """
55        Add documents to the in-memory store.
56        
57        Args:
58            documents: List of Document objects to add
59            generate_embeddings: If True, validates embeddings are present
60                                (actual generation should be done externally)
61        
62        Returns:
63            List of document IDs that were added
64        """
65        if not documents:
66            return []
67        
68        added_ids = []
69        
70        for doc in documents:
71            # Validate document
72            if not doc.id or not doc.content:
73                raise ValueError(f"Document must have id and content: {doc}")
74            
75            if generate_embeddings and doc.embedding is None:
76                raise ValueError(
77                    f"Document {doc.id} lacks embedding. "
78                    "Generate embeddings externally before adding to store."
79                )
80            
81            if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
82                raise ValueError(
83                    f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
84                    f"expected {self.embedding_dimension}"
85                )
86            
87            # Store document
88            self._documents[doc.id] = deepcopy(doc)
89            added_ids.append(doc.id)
90        
91        # Rebuild embedding matrix
92        self._rebuild_embeddings()
93        
94        return added_ids

Add documents to the in-memory store.

Args: documents: List of Document objects to add. generate_embeddings: If True, require every document to already carry an embedding (this store never generates embeddings; generate them externally before adding).

Returns: List of document IDs that were added

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
 96    def search(
 97        self,
 98        query: Optional[str] = None,
 99        query_embedding: Optional[List[float]] = None,
100        top_k: int = 5,
101        filters: Optional[Dict[str, Any]] = None,
102        search_type: str = "vector"
103    ) -> List[SearchResult]:
104        """
105        Search the in-memory vector store.
106        
107        Args:
108            query: Text query (required for keyword/hybrid search)
109            query_embedding: Vector embedding of the query (required for vector/hybrid search)
110            top_k: Number of results to return
111            filters: Metadata filters (e.g., {"source": "doc.pdf"})
112            search_type: "vector", "keyword", or "hybrid"
113        
114        Returns:
115            List of SearchResult objects, ordered by relevance
116        """
117        if search_type not in ["vector", "keyword", "hybrid"]:
118            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")
119        
120        if search_type in ["vector", "hybrid"] and query_embedding is None:
121            raise ValueError(f"{search_type} search requires query_embedding")
122        
123        if search_type in ["keyword", "hybrid"] and query is None:
124            raise ValueError(f"{search_type} search requires query text")
125        
126        if not self._documents:
127            return []
128        
129        # Apply metadata filters
130        filtered_docs = self._apply_filters(filters)
131        if not filtered_docs:
132            return []
133        
134        # Calculate scores based on search type
135        if search_type == "vector":
136            scores = self._vector_search(query_embedding, filtered_docs)
137        elif search_type == "keyword":
138            scores = self._keyword_search(query, filtered_docs)
139        else:  # hybrid
140            vector_scores = self._vector_search(query_embedding, filtered_docs)
141            keyword_scores = self._keyword_search(query, filtered_docs)
142            # Combine scores (50/50 weighting)
143            scores = {
144                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
145                for doc_id in filtered_docs
146            }
147        
148        # Sort by score and take top_k
149        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
150        
151        # Build SearchResult objects
152        results = []
153        for rank, (doc_id, score) in enumerate(sorted_results):
154            results.append(SearchResult(
155                document=deepcopy(self._documents[doc_id]),
156                score=float(score),
157                rank=rank
158            ))
159        
160        return results

Search the in-memory vector store.

Args: query: Text query (required for keyword/hybrid search) query_embedding: Vector embedding of the query (required for vector/hybrid search) top_k: Number of results to return filters: Metadata filters (e.g., {"source": "doc.pdf"}) search_type: "vector", "keyword", or "hybrid"

Returns: List of SearchResult objects, ordered by relevance

def delete_documents(self, document_ids: List[str]) -> int:
162    def delete_documents(self, document_ids: List[str]) -> int:
163        """
164        Delete documents from the store.
165        
166        Args:
167            document_ids: List of document IDs to delete
168        
169        Returns:
170            Number of documents successfully deleted
171        """
172        deleted_count = 0
173        for doc_id in document_ids:
174            if doc_id in self._documents:
175                del self._documents[doc_id]
176                deleted_count += 1
177        
178        # Rebuild embedding matrix
179        self._rebuild_embeddings()
180        
181        return deleted_count

Delete documents from the store.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
183    def update_document(
184        self,
185        document_id: str,
186        content: Optional[str] = None,
187        embedding: Optional[List[float]] = None,
188        metadata: Optional[Dict[str, Any]] = None
189    ) -> bool:
190        """
191        Update an existing document.
192        
193        Args:
194            document_id: ID of the document to update
195            content: New content (if provided)
196            embedding: New embedding (if provided)
197            metadata: New metadata (merged with existing)
198        
199        Returns:
200            True if update was successful, False if document not found
201        """
202        if document_id not in self._documents:
203            return False
204        
205        doc = self._documents[document_id]
206        
207        if content is not None:
208            doc.content = content
209        
210        if embedding is not None:
211            if len(embedding) != self.embedding_dimension:
212                raise ValueError(
213                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
214                )
215            doc.embedding = embedding
216            self._rebuild_embeddings()
217        
218        if metadata is not None:
219            doc.metadata.update(metadata)
220        
221        return True

Update an existing document.

Args: document_id: ID of the document to update content: New content (if provided) embedding: New embedding (if provided) metadata: New metadata (merged with existing)

Returns: True if update was successful, False if document not found

def get_document( self, document_id: str) -> Optional[Document]:
223    def get_document(self, document_id: str) -> Optional[Document]:
224        """
225        Retrieve a document by ID.
226        
227        Args:
228            document_id: ID of the document to retrieve
229        
230        Returns:
231            Document if found, None otherwise
232        """
233        doc = self._documents.get(document_id)
234        return deepcopy(doc) if doc else None

Retrieve a document by ID.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise

def count(self) -> int:
236    def count(self) -> int:
237        """Get the total number of documents in the store."""
238        return len(self._documents)

Get the total number of documents in the store.

def clear(self) -> None:
240    def clear(self) -> None:
241        """Remove all documents from the store."""
242        self._documents.clear()
243        self._embeddings = None
244        self._doc_ids = []

Remove all documents from the store.

class AzureAISearchVectorStore(gmf_forge_ai_data.vector_stores.BaseVectorStore):
 43class AzureAISearchVectorStore(BaseVectorStore):
 44    """
 45    Azure AI Search vector store for production RAG pipelines.
 46    
 47    Features:
 48    - Scalable vector search with HNSW (Hierarchical Navigable Small World)
 49    - Hybrid search (vector + keyword BM25)
 50    - Semantic ranking powered by Microsoft Bing
 51    - Automatic indexing of all document fields (universal + custom)
 52    - Efficient filtering on all indexed fields
 53    - High availability and disaster recovery
 54    - Managed service (no infrastructure management)
 55    
 56    Prerequisites:
 57    - Azure AI Search service (Basic tier or higher for vector search)
 58    - Service endpoint and API key **or** a managed-identity token provider
 59    - Index created with vector field configuration
 60    
 61    Usage:
 62    
 63    **Simple Mode** (base Document):
 64        ```python
 65        # API key auth
 66        store = AzureAISearchVectorStore(
 67            endpoint="https://my-search.search.windows.net",
 68            index_name="documents",
 69            api_key="...",
 70            embedding_dimension=1536
 71        )
 72
 73        # Managed identity / token provider auth
 74        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 75        token_provider = get_bearer_token_provider(
 76            DefaultAzureCredential(),
 77            "https://search.azure.com/.default"  # Azure AI Search scope
 78        )
 79        store = AzureAISearchVectorStore(
 80            endpoint="https://my-search.search.windows.net",
 81            index_name="documents",
 82            token_provider=token_provider,
 83            embedding_dimension=1536
 84        )
 85        # Uses base Document with universal fields only
 86        ```
 87    
 88    **Custom Schema Mode** (domain-specific Document):
 89        ```python
 90        from dataclasses import dataclass
 91        from gmf_forge_ai_data.vector_stores import Document
 92        
 93        @dataclass
 94        class LegalDocument(Document):
 95            case_number: str = ""
 96            court: str = ""
 97            jurisdiction: str = ""
 98            decision_date: Optional[datetime] = None
 99        
100        # Pass document_type - all fields automatically indexed.
101        # If the index was provisioned externally (portal, Bicep, etc.) the
102        # vector and content field names will likely differ from the defaults
103        # ("embedding" / "content"). Always set vector_field_name and
104        # content_field_name to match your actual index schema.
105        store = AzureAISearchVectorStore(
106            endpoint="https://my-search.search.windows.net",
107            index_name="legal_documents",
108            api_key="...",
109            embedding_dimension=1536,
110            document_type=LegalDocument,
111            vector_field_name="contentVector",   # match your index schema
112            content_field_name="chunkContent",   # match your index schema
113        )
114        
115        # All fields indexed - filter on any field
116        results = store.search(
117            query_embedding=vector,
118            filters={
119                "jurisdiction": "Federal",
120                "court": "Supreme Court",
121                "decision_date": (">=", datetime(2024, 1, 1))
122            }
123        )
124        ```
125    """
126    
127    def __init__(
128        self,
129        endpoint: str,
130        index_name: str,
131        api_key: Optional[str] = None,
132        token_provider: Optional[Callable[[], str]] = None,
133        embedding_dimension: int = 1536,
134        document_type: type = Document,
135        vector_field_name: str = "embedding",
136        content_field_name: str = "content",
137    ):
138        """
139        Initialize Azure AI Search vector store.
140
141        The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
142        before first use.  The constructor only establishes client connections
143        — it never creates or modifies indexes.
144
145        Exactly one of ``api_key`` or ``token_provider`` must be supplied.
146
147        Args:
148            endpoint: Azure AI Search service endpoint
149            index_name: Name of the search index
150            api_key: Azure AI Search API key. Use for local development or
151                when managed identity is not available.
152            token_provider: Zero-argument callable that returns a bearer token
153                string. Use for managed identity / workload identity scenarios.
154                The callable must request the **Azure AI Search** scope::
155
156                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
157                    token_provider = get_bearer_token_provider(
158                        DefaultAzureCredential(),
159                        "https://search.azure.com/.default"
160                    )
161
162                Note: this scope is different from Azure OpenAI / Cognitive Services
163                (``https://cognitiveservices.azure.com/.default``) — each service
164                requires its own token_provider.
165            embedding_dimension: Dimension of vector embeddings (default 1536)
166            document_type: Document class (base or custom subclass). All fields will be indexed.
167            vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
168                Override this when the index was built externally with a different field name
169                (e.g. "contentVector", "chunkVector").
170            content_field_name: Name of the text content field in the Azure Search index (default "content").
171                Override when the index uses a different name (e.g. "chunkContent", "text").
172        """
173        if not api_key and not token_provider:
174            raise ValueError(
175                "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
176            )
177        self.endpoint = endpoint
178        self.index_name = index_name
179        self.embedding_dimension = embedding_dimension
180        self.document_type = document_type
181        self.vector_field_name = vector_field_name
182        self.content_field_name = content_field_name
183
184        if token_provider:
185            credential = _TokenProviderCredential(token_provider)
186        else:
187            credential = AzureKeyCredential(api_key)
188
189        # Initialize clients
190        self.search_client = SearchClient(
191            endpoint=endpoint,
192            index_name=index_name,
193            credential=credential
194        )
195    
196    @staticmethod
197    def _format_datetime_for_azure(dt: datetime) -> str:
198        """
199        Convert datetime to Azure Search DateTimeOffset format (ISO 8601 with timezone).
200        
201        Args:
202            dt: datetime object (timezone-aware or naive)
203            
204        Returns:
205            ISO 8601 string with timezone (e.g., '2024-01-15T10:30:00Z')
206        """
207        if dt.tzinfo is None:
208            # Timezone-naive datetime: assume UTC and append 'Z'
209            return dt.isoformat() + 'Z'
210        else:
211            # Timezone-aware datetime: use standard ISO format
212            return dt.isoformat()
213    
214    def add_documents(
215        self,
216        documents: List[Document],
217        generate_embeddings: bool = True
218    ) -> List[str]:
219        """
220        Add documents to Azure AI Search.
221        
222        Args:
223            documents: List of Document objects to add (can be base or custom subclasses)
224            generate_embeddings: If True, validates embeddings are present
225        
226        Returns:
227            List of document IDs that were added
228            
229        Note:
230            Custom document fields are automatically serialized and stored.
231            All fields remain accessible after retrieval.
232        """
233        if not documents:
234            return []
235        
236        # Helper function to format datetime for Azure Search DateTimeOffset
237        format_datetime_for_azure = self._format_datetime_for_azure
238        
239        # Helper function to serialize datetime objects in dictionaries
240        def serialize_for_json(obj):
241            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
242            if isinstance(obj, datetime):
243                return format_datetime_for_azure(obj)
244            elif isinstance(obj, dict):
245                return {k: serialize_for_json(v) for k, v in obj.items()}
246            elif isinstance(obj, (list, tuple)):
247                return [serialize_for_json(item) for item in obj]
248            else:
249                return obj
250        
251        # Prepare documents for upload
252        search_documents = []
253        added_ids = []
254        
255        for doc in documents:
256            if not doc.id or not doc.content:
257                raise ValueError(f"Document must have id and content: {doc}")
258            
259            if generate_embeddings and doc.embedding is None:
260                raise ValueError(
261                    f"Document {doc.id} lacks embedding. "
262                    "Generate embeddings externally before adding."
263                )
264            
265            # Serialize entire document (including custom fields) to JSON
266            import json
267            import dataclasses
268            doc_dict = doc.to_dict()
269            
270            # Convert datetime objects for JSON serialization
271            doc_dict_serializable = serialize_for_json(doc_dict)
272            
273            # Convert to Azure Search document format
274            # Store base fields directly
275            search_doc = {
276                "id": doc.id,
277                "content": doc.content,
278                "embedding": doc.embedding if doc.embedding else [],
279                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
280                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
281            }
282            
283            # Add all custom fields as indexed fields
284            if dataclasses.is_dataclass(doc):
285                for field in dataclasses.fields(doc):
286                    # Skip base Document fields (already added above)
287                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
288                        continue
289                    
290                    field_value = getattr(doc, field.name, None)
291                    
292                    if field_value is not None:
293                        # Serialize datetime fields with Azure DateTimeOffset format
294                        if isinstance(field_value, datetime):
295                            search_doc[field.name] = format_datetime_for_azure(field_value)
296                        # Handle lists/dicts by converting to JSON string
297                        elif isinstance(field_value, (list, dict)):
298                            search_doc[field.name] = json.dumps(field_value)
299                        else:
300                            search_doc[field.name] = field_value
301            
302            search_documents.append(search_doc)
303            added_ids.append(doc.id)
304        
305        # Upload to Azure Search
306        result = self.search_client.upload_documents(documents=search_documents)
307        
308        # Check for failures
309        failed_count = sum(1 for r in result if not r.succeeded)
310        if failed_count > 0:
311            logger.warning(f"{failed_count} documents failed to upload")
312        
313        return added_ids
314    
315    def search(
316        self,
317        query: Optional[str] = None,
318        query_embedding: Optional[List[float]] = None,
319        top_k: int = 5,
320        filters: Optional[Dict[str, Any]] = None,
321        search_type: str = "vector"
322    ) -> List[SearchResult]:
323        """
324        Search Azure AI Search index.
325        
326        Args:
327            query: Text query (required for keyword/hybrid search)
328            query_embedding: Vector embedding (required for vector/hybrid search)
329            top_k: Number of results to return
330            filters: Metadata filters (e.g., {"source": "doc.pdf"})
331            search_type: "vector", "keyword", or "hybrid"
332        
333        Returns:
334            List of SearchResult objects, ordered by relevance
335        """
336        if search_type not in ["vector", "keyword", "hybrid"]:
337            raise ValueError(f"Invalid search_type: {search_type}")
338        
339        if search_type in ["vector", "hybrid"] and query_embedding is None:
340            raise ValueError(f"{search_type} search requires query_embedding")
341        
342        if search_type in ["keyword", "hybrid"] and query is None:
343            raise ValueError(f"{search_type} search requires query text")
344        
345        # Build filter expression
346        filter_expr = self._build_filter_expression(filters)
347        
348        # Execute search based on type
349        if search_type == "vector":
350            # Pure vector search
351            vector_query = VectorizedQuery(
352                vector=query_embedding,
353                k_nearest_neighbors=top_k,
354                fields=self.vector_field_name
355            )
356            
357            results = self.search_client.search(
358                search_text=None,
359                vector_queries=[vector_query],
360                filter=filter_expr,
361                top=top_k
362            )
363        
364        elif search_type == "keyword":
365            # Pure keyword search (BM25)
366            results = self.search_client.search(
367                search_text=query,
368                filter=filter_expr,
369                top=top_k,
370                include_total_count=False
371            )
372        
373        else:  # hybrid
374            # Hybrid search (vector + keyword)
375            vector_query = VectorizedQuery(
376                vector=query_embedding,
377                k_nearest_neighbors=top_k,
378                fields=self.vector_field_name
379            )
380            
381            results = self.search_client.search(
382                search_text=query,
383                vector_queries=[vector_query],
384                filter=filter_expr,
385                top=top_k
386            )
387        
388        # Convert to SearchResult objects
389        search_results = []
390        for rank, result in enumerate(results):
391            import json
392            
393            # Deserialize full document data
394            doc_data = json.loads(result.get("document_data", "{}") or "{}")
395            
396            if not doc_data:
397                # Index was built externally — fields are stored directly on the result.
398                # Collect all non-Azure-internal fields, excluding the vector blob.
399                doc_data = {
400                    k: v for k, v in result.items()
401                    if not k.startswith("@") and k != self.vector_field_name
402                }
403                # Remap content field when the index uses a non-standard name.
404                if self.content_field_name != "content" and self.content_field_name in doc_data:
405                    doc_data["content"] = doc_data.pop(self.content_field_name)
406            
407            # Reconstruct document using the configured document_type
408            doc = self.document_type.from_dict(doc_data)
409            
410            search_results.append(SearchResult(
411                document=doc,
412                score=result.get("@search.score", 0.0),
413                rank=rank
414            ))
415        
416        return search_results
417    
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Remove the given documents from the Azure AI Search index.

    Args:
        document_ids: IDs of the documents to remove; an empty list is a no-op
            and makes no service call.

    Returns:
        How many delete actions the service reported as succeeded.
    """
    # Nothing to send: skip the round-trip entirely.
    if not document_ids:
        return 0

    # A delete action only needs the document key.
    actions = []
    for key in document_ids:
        actions.append({"id": key})

    outcomes = self.search_client.delete_documents(documents=actions)

    succeeded = 0
    for outcome in outcomes:
        if outcome.succeeded:
            succeeded += 1
    return succeeded
440    
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime] = None,
    **custom_fields
) -> bool:
    """
    Update an existing document in Azure AI Search.

    The stored document is fetched with ``get_document``. The requested
    changes are applied both to the top-level index fields and to the
    serialized ``document_data`` JSON blob. The result is then written back
    with ``merge_or_upload_documents``.

    Args:
        document_id: ID (key) of the document to update
        content: New content (if provided). Written to ``document_data`` and
            to the configured ``content_field_name`` index field.
        embedding: New embedding (if provided). Written to ``document_data``
            and to the configured ``vector_field_name`` index field.
        metadata: New metadata, merged into the existing metadata dict
            (metadata is kept only inside ``document_data``).
        timestamp: New timestamp (if provided)
        **custom_fields: Custom fields to update. Every key is written to
            ``document_data``. Keys that are non-base fields of a dataclass
            ``document_type`` are also written to the matching top-level index
            field, serialized the same way ``add_documents`` serializes them.

    Returns:
        True if the service reported the update as successful. False if the
        document was not found or any step of the update failed (the error
        is logged).
    """
    import dataclasses
    import json

    base_fields = {"id", "content", "embedding", "timestamp", "metadata"}

    def to_json_safe(value):
        # json.dumps cannot encode datetimes; store them in the Azure format.
        if isinstance(value, datetime):
            return self._format_datetime_for_azure(value)
        if isinstance(value, dict):
            return {k: to_json_safe(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_json_safe(v) for v in value]
        return value

    # Custom fields that add_documents stores as top-level (filterable) fields.
    indexed_custom_fields = set()
    if dataclasses.is_dataclass(self.document_type):
        indexed_custom_fields = {
            f.name for f in dataclasses.fields(self.document_type)
        } - base_fields

    try:
        existing_doc = self.search_client.get_document(key=document_id)
        # `or "{}"` also covers a null document_data value.
        doc_data = json.loads(existing_doc.get("document_data") or "{}")

        if content is not None:
            doc_data["content"] = content
            existing_doc[self.content_field_name] = content

        if embedding is not None:
            doc_data["embedding"] = embedding
            existing_doc[self.vector_field_name] = embedding

        if timestamp is not None:
            formatted_timestamp = self._format_datetime_for_azure(timestamp)
            doc_data["timestamp"] = formatted_timestamp
            existing_doc["timestamp"] = formatted_timestamp

        if metadata is not None:
            # A stored metadata value of None is treated as an empty dict.
            merged_metadata = doc_data.get("metadata") or {}
            merged_metadata.update(metadata)
            doc_data["metadata"] = merged_metadata

        for key, value in custom_fields.items():
            doc_data[key] = value
            if key not in indexed_custom_fields:
                continue
            # Keep the top-level index field in sync so filters see the new value.
            if isinstance(value, datetime):
                existing_doc[key] = self._format_datetime_for_azure(value)
            elif isinstance(value, (list, dict)):
                existing_doc[key] = json.dumps(to_json_safe(value))
            else:
                existing_doc[key] = value

        # Skip the blob when the fetched document has no document_data field.
        if "document_data" in existing_doc:
            existing_doc["document_data"] = json.dumps(to_json_safe(doc_data))

        result = self.search_client.merge_or_upload_documents(documents=[existing_doc])
        return result[0].succeeded

    except Exception as e:
        logger.error(f"Failed to update document {document_id}: {e}")
        return False
504    
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a document by ID from Azure AI Search.

    The document is normally rebuilt from its serialized ``document_data``
    JSON blob. When that blob is missing, null, or empty (e.g. an index built
    outside this class), the document is rebuilt from the top-level fields
    instead, exactly as ``search`` does:

    - ``@``-prefixed service fields and the vector field are dropped.
    - ``content_field_name`` is renamed to ``content``.

    Args:
        document_id: ID (key) of the document to retrieve

    Returns:
        An instance of the configured ``document_type`` if found. None if the
        lookup or reconstruction raised any exception (logged at debug level).
    """
    import json

    try:
        result = self.search_client.get_document(key=document_id)

        # `or "{}"` also covers a null document_data value.
        doc_data = json.loads(result.get("document_data") or "{}")

        if not doc_data:
            # Fields stored directly on the document; drop service fields and the vector.
            doc_data = {
                k: v for k, v in result.items()
                if not k.startswith("@") and k != self.vector_field_name
            }
            # Remap content field when the index uses a non-standard name.
            if self.content_field_name != "content" and self.content_field_name in doc_data:
                doc_data["content"] = doc_data.pop(self.content_field_name)

        return self.document_type.from_dict(doc_data)

    except Exception as e:
        # NOTE: every failure, not only "not found", is mapped to None here.
        logger.debug(f"Document {document_id} not found: {e}")
        return None
526
def get_document_by_parent(self, document_id: str) -> List[Document]:
    """
    Retrieve all chunks belonging to a parent document.

    Runs a match-all keyword query (``"*"``) filtered on ``documentId`` and
    sorts the chunks by ``pageNumber``.

    The ``documentId`` filter can only be applied when ``document_type`` is a
    dataclass that declares a ``documentId`` field, because
    ``_build_filter_expression`` skips filters on fields outside the schema.
    Without that field the query would run unfiltered and return unrelated
    chunks, so a warning is logged and an empty list is returned instead.

    Args:
        document_id: The value of the ``documentId`` field shared across
                    all chunks of the same parent document.

    Returns:
        List of Document objects sorted by ``pageNumber``, or an empty
        list if no matching chunks are found. Chunks whose ``pageNumber`` is
        missing or None sort as page 0. At most 1000 chunks are requested.
    """
    import dataclasses

    schema_fields = (
        {f.name for f in dataclasses.fields(self.document_type)}
        if dataclasses.is_dataclass(self.document_type)
        else set()
    )
    if "documentId" not in schema_fields:
        logger.warning(
            f"get_document_by_parent requires a 'documentId' field on document_type "
            f"{getattr(self.document_type, '__name__', self.document_type)}; "
            f"the filter cannot be applied, returning no chunks."
        )
        return []

    results = self.search(
        query="*",
        search_type="keyword",
        filters={"documentId": document_id},
        top_k=1000,
    )
    chunks = [r.document for r in results]

    def page_key(chunk):
        # None would not compare with int pages; treat it like a missing page.
        page = getattr(chunk, "pageNumber", None)
        return page if page is not None else 0

    chunks.sort(key=page_key)
    return chunks
548
def count(self) -> int:
    """
    Return how many documents the index currently holds.

    A match-all query is sent with ``include_total_count=True`` and
    ``top=0``, so only the total is requested and no documents come back.

    Returns:
        The service-reported total, or 0 when the result pager exposes no
        ``get_count`` method or reports None.
    """
    pager = self.search_client.search(
        search_text="*",
        include_total_count=True,
        top=0,
    )

    fetch_total = getattr(pager, "get_count", lambda: 0)
    total = fetch_total()
    if total is None:
        return 0
    return total
566    
def clear(self) -> None:
    """
    Remove all documents from the index.

    Document IDs are fetched in pages of up to 1000 (match-all query
    selecting only ``id``). Each page is deleted before the next one is
    fetched. The loop stops in either of two cases:

    - a page is smaller than the page size, meaning no documents are left;
    - a page contains only IDs already deleted during this call, for example
      because the deletions are not yet visible to search. In that case some
      documents may remain and ``clear()`` can be called again.

    Warning: This operation is irreversible.
    """
    page_size = 1000
    deleted_ids = set()

    while True:
        results = self.search_client.search(
            search_text="*",
            select=["id"],
            top=page_size,
        )
        page_ids = [r["id"] for r in results]
        new_ids = [doc_id for doc_id in page_ids if doc_id not in deleted_ids]

        # Guard against re-fetching the same (already deleted) page forever.
        if not new_ids:
            break

        self.delete_documents(new_ids)
        deleted_ids.update(new_ids)

        # A short page means there was nothing beyond it to delete.
        if len(page_ids) < page_size:
            break

    if deleted_ids:
        logger.info(f"Cleared {len(deleted_ids)} documents from index '{self.index_name}'")
587    
def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> Optional[str]:
    """
    Build an OData filter expression from a filter dictionary.

    Only schema fields can be filtered:

    - the universal fields ``id``, ``content`` and ``timestamp``;
    - when ``document_type`` is a dataclass, its non-base fields.

    Entries for any other key (including metadata keys) are skipped with a
    warning. Entries whose operator or value type is unsupported are also
    skipped with a warning.

    Args:
        filters: Mapping of field name to filter value. A plain value means
            equality. Tuples express comparisons:

            - ``(">=", v)`` / ``("ge", v)``: greater than or equal
            - ``(">", v)`` / ``("gt", v)``: strictly greater than
            - ``("<=", v)`` / ``("le", v)``: less than or equal
            - ``("<", v)`` / ``("lt", v)``: strictly less than
            - ``("range", min_val, max_val)``: inclusive range

            Supported values:

            - str: quoted, with single quotes doubled;
            - bool: rendered as ``true``/``false``;
            - int and float;
            - datetime: formatted by ``_format_datetime_for_azure``.

    Returns:
        OData filter string (clauses joined with ``and``), or None when no
        usable filter remains.

    Example:
        ```python
        # Filter on universal field
        filters = {"timestamp": (">=", datetime(2024, 1, 1))}

        # Filter on custom fields (if document_type has them)
        filters = {
            "court": "Supreme Court",
            "case_number": "2024-001",
            "decision_date": ("range", start_date, end_date),
            "page_count": ("gt", 10),
        }
        ```
    """
    if not filters:
        return None

    import dataclasses

    # Universal fields, plus any custom fields declared on document_type.
    indexed_fields = {'id', 'content', 'timestamp'}
    if dataclasses.is_dataclass(self.document_type):
        for field in dataclasses.fields(self.document_type):
            if field.name not in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                indexed_fields.add(field.name)

    comparison_ops = {
        ">=": "ge", "ge": "ge",
        ">": "gt", "gt": "gt",
        "<=": "le", "le": "le",
        "<": "lt", "lt": "lt",
    }

    def to_literal(val):
        """Render a Python value as an OData literal, or None if unsupported."""
        if isinstance(val, datetime):
            # DateTimeOffset literals are written unquoted.
            return self._format_datetime_for_azure(val)
        if isinstance(val, bool):  # check before int: bool is an int subclass
            return "true" if val else "false"
        if isinstance(val, (int, float)):
            return f"{val}"
        if isinstance(val, str):
            escaped = val.replace("'", "''")
            return f"'{escaped}'"
        return None

    def tuple_clause(key, value):
        """Translate an operator tuple into an OData clause, or None if malformed."""
        if not value:
            return None
        operator = value[0]
        if operator == "range" and len(value) == 3:
            low, high = to_literal(value[1]), to_literal(value[2])
            if low is None or high is None:
                return None
            return f"{key} ge {low} and {key} le {high}"
        odata_op = comparison_ops.get(operator) if isinstance(operator, str) else None
        if odata_op is None or len(value) < 2:
            return None
        operand = to_literal(value[1])
        if operand is None:
            return None
        return f"{key} {odata_op} {operand}"

    filter_parts = []
    for key, value in filters.items():
        if key not in indexed_fields:
            logger.warning(
                f"Filtering on '{key}' - field not found in document schema. "
                f"Indexed fields: {sorted(indexed_fields)}"
            )
            continue

        if isinstance(value, tuple):
            clause = tuple_clause(key, value)
        else:
            literal = to_literal(value)
            clause = f"{key} eq {literal}" if literal is not None else None

        if clause is None:
            logger.warning(f"Unsupported filter for '{key}': {value!r} - skipped.")
            continue
        filter_parts.append(clause)

    return " and ".join(filter_parts) if filter_parts else None

Azure AI Search vector store for production RAG pipelines.

Features:

  • Scalable vector search with HNSW (Hierarchical Navigable Small World)
  • Hybrid search (vector + keyword BM25)
  • No semantic ranking: search_type supports only "vector", "keyword", and "hybrid" queries
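  • Upload of all document fields (universal + custom) as top-level index fields. The index must already define them; this class never creates or modifies indexes.
  • Filtering on the universal fields (id, content, timestamp) and on custom document_type fields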
  • High availability and disaster recovery
  • Managed service (no infrastructure management)

Prerequisites:

  • Azure AI Search service (Basic tier or higher for vector search)
  • Service endpoint and API key or a managed-identity token provider
  • Index created with vector field configuration

Usage:

Simple Mode (base Document):

# API key auth
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    api_key="...",
    embedding_dimension=1536
)

# Managed identity / token provider auth
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://search.azure.com/.default"  # Azure AI Search scope
)
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    token_provider=token_provider,
    embedding_dimension=1536
)
# Uses base Document with universal fields only

Custom Schema Mode (domain-specific Document):

from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from gmf_forge_ai_data.vector_stores import Document

@dataclass
class LegalDocument(Document):
    case_number: str = ""
    court: str = ""
    jurisdiction: str = ""
    decision_date: Optional[datetime] = None

# Pass document_type - its custom fields are uploaded as top-level index fields and become filterable.
# If the index was provisioned externally (portal, Bicep, etc.) the
# vector and content field names will likely differ from the defaults
# ("embedding" / "content"). Always set vector_field_name and
# content_field_name to match your actual index schema.
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="legal_documents",
    api_key="...",
    embedding_dimension=1536,
    document_type=LegalDocument,
    vector_field_name="contentVector",   # match your index schema
    content_field_name="chunkContent",   # match your index schema
)

# Filter on universal or custom document_type fields (metadata keys are not filterable)
results = store.search(
    query_embedding=vector,
    filters={
        "jurisdiction": "Federal",
        "court": "Supreme Court",
        "decision_date": (">=", datetime(2024, 1, 1))
    }
)
AzureAISearchVectorStore( endpoint: str, index_name: str, api_key: Optional[str] = None, token_provider: Optional[Callable[[], str]] = None, embedding_dimension: int = 1536, document_type: type = <class 'Document'>, vector_field_name: str = 'embedding', content_field_name: str = 'content')
def __init__(
    self,
    endpoint: str,
    index_name: str,
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
    embedding_dimension: int = 1536,
    document_type: type = Document,
    vector_field_name: str = "embedding",
    content_field_name: str = "content",
):
    """
    Initialize Azure AI Search vector store.

    The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
    before first use.  The constructor only establishes client connections
    — it never creates or modifies indexes.

    At least one of ``api_key`` or ``token_provider`` must be supplied. If
    both are supplied, ``token_provider`` is used, ``api_key`` is ignored,
    and a warning is logged.

    Args:
        endpoint: Azure AI Search service endpoint
        index_name: Name of the search index
        api_key: Azure AI Search API key. Use for local development or
            when managed identity is not available.
        token_provider: Zero-argument callable that returns a bearer token
            string. Use for managed identity / workload identity scenarios.
            The callable must request the **Azure AI Search** scope::

                from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                token_provider = get_bearer_token_provider(
                    DefaultAzureCredential(),
                    "https://search.azure.com/.default"
                )

            Note: this scope is different from Azure OpenAI / Cognitive Services
            (``https://cognitiveservices.azure.com/.default``) — each service
            requires its own token_provider.
        embedding_dimension: Dimension of vector embeddings (default 1536)
        document_type: Document class (base or custom subclass). Its custom
            fields are uploaded as top-level index fields.
        vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
            Override this when the index was built externally with a different field name
            (e.g. "contentVector", "chunkVector").
        content_field_name: Name of the text content field in the Azure Search index (default "content").
            Override when the index uses a different name (e.g. "chunkContent", "text").

    Raises:
        ValueError: If neither ``api_key`` nor ``token_provider`` is supplied.
    """
    if not api_key and not token_provider:
        raise ValueError(
            "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
        )
    if api_key and token_provider:
        # Make the precedence explicit instead of silently dropping the key.
        logger.warning(
            "Both api_key and token_provider were supplied to AzureAISearchVectorStore; "
            "using token_provider and ignoring api_key."
        )

    self.endpoint = endpoint
    self.index_name = index_name
    self.embedding_dimension = embedding_dimension
    self.document_type = document_type
    self.vector_field_name = vector_field_name
    self.content_field_name = content_field_name

    # token_provider takes precedence over api_key.
    if token_provider:
        credential = _TokenProviderCredential(token_provider)
    else:
        credential = AzureKeyCredential(api_key)

    # Initialize clients
    self.search_client = SearchClient(
        endpoint=endpoint,
        index_name=index_name,
        credential=credential
    )

Initialize Azure AI Search vector store.

The index must be pre-provisioned using AzureAISearchIndexBuilder before first use. The constructor only establishes client connections — it never creates or modifies indexes.

At least one of api_key or token_provider must be supplied; if both are given, token_provider is used.

Args:
endpoint: Azure AI Search service endpoint
index_name: Name of the search index
api_key: Azure AI Search API key. Use for local development or when managed identity is not available.
token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Azure AI Search scope::

        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://search.azure.com/.default"
        )

    Note: this scope is different from Azure OpenAI / Cognitive Services
    (``https://cognitiveservices.azure.com/.default``) — each service
    requires its own token_provider.
embedding_dimension: Dimension of vector embeddings (default 1536)
document_type: Document class (base or custom subclass). All fields will be indexed.
vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
    Override this when the index was built externally with a different field name
    (e.g. "contentVector", "chunkVector").
content_field_name: Name of the text content field in the Azure Search index (default "content").
    Override when the index uses a different name (e.g. "chunkContent", "text").
endpoint
index_name
embedding_dimension
document_type
vector_field_name
content_field_name
search_client
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True
) -> List[str]:
    """
    Add documents to Azure AI Search.

    Every document is uploaded with:

    - ``id`` and ``timestamp`` as top-level index fields;
    - its content and embedding under the configured ``content_field_name``
      and ``vector_field_name`` index fields;
    - each non-base dataclass field (custom schema) as a top-level index field;
    - the full serialized document as a JSON string in ``document_data``.

    Args:
        documents: List of Document objects to add (can be base or custom subclasses)
        generate_embeddings: If True, every document must already carry an
            embedding. Embeddings are not generated here; produce them
            externally before calling this method.

    Returns:
        IDs of the documents the service reported as uploaded successfully,
        in input order. Documents that failed are logged and left out.

    Raises:
        ValueError: If a document lacks ``id`` or ``content``, or lacks an
            embedding while ``generate_embeddings`` is True. Validation runs
            before anything is uploaded.

    Note:
        Custom document fields are automatically serialized and stored.
        All fields remain accessible after retrieval.
    """
    import dataclasses
    import json

    if not documents:
        return []

    format_datetime_for_azure = self._format_datetime_for_azure
    # Base Document fields; these are written explicitly below.
    base_fields = {'id', 'content', 'embedding', 'timestamp', 'metadata'}

    def serialize_for_json(obj):
        """Recursively convert datetime objects to Azure-formatted strings for JSON."""
        if isinstance(obj, datetime):
            return format_datetime_for_azure(obj)
        if isinstance(obj, dict):
            return {k: serialize_for_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [serialize_for_json(item) for item in obj]
        return obj

    search_documents = []
    added_ids = []

    for doc in documents:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")

        if generate_embeddings and doc.embedding is None:
            raise ValueError(
                f"Document {doc.id} lacks embedding. "
                "Generate embeddings externally before adding."
            )

        search_doc = {
            "id": doc.id,
            # Use the configured field names so the upload matches the same
            # index fields that search() reads from.
            self.content_field_name: doc.content,
            self.vector_field_name: doc.embedding if doc.embedding is not None else [],
            "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
            # Full document (incl. metadata and custom fields) for round-tripping.
            "document_data": json.dumps(serialize_for_json(doc.to_dict())),
        }

        # Custom (non-base) dataclass fields become top-level, filterable fields.
        if dataclasses.is_dataclass(doc):
            for field in dataclasses.fields(doc):
                if field.name in base_fields:
                    continue
                field_value = getattr(doc, field.name, None)
                if field_value is None:
                    continue
                if isinstance(field_value, datetime):
                    search_doc[field.name] = format_datetime_for_azure(field_value)
                elif isinstance(field_value, (list, dict)):
                    # Collections are stored as JSON strings; nested datetimes
                    # are converted first so json.dumps does not fail.
                    search_doc[field.name] = json.dumps(serialize_for_json(field_value))
                else:
                    search_doc[field.name] = field_value

        search_documents.append(search_doc)
        added_ids.append(doc.id)

    result = self.search_client.upload_documents(documents=search_documents)

    failures = [r for r in result if not r.succeeded]
    if failures:
        logger.warning(f"{len(failures)} documents failed to upload")
        for failure in failures:
            logger.warning(f"Upload failed for document {failure.key}: {failure.error_message}")

    # Only report documents that actually made it into the index.
    failed_keys = {failure.key for failure in failures}
    return [doc_id for doc_id in added_ids if doc_id not in failed_keys]

Add documents to Azure AI Search.

Args:
documents: List of Document objects to add (can be base or custom subclasses)
generate_embeddings: If True, every document must already carry an embedding. Embeddings are not generated here, and a missing one raises ValueError.

Returns: List of document IDs that were added

Note: Custom document fields are automatically serialized and stored. All fields remain accessible after retrieval.

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector"
) -> List[SearchResult]:
    """
    Query the Azure AI Search index.

    Args:
        query: Text query; mandatory for "keyword" and "hybrid"
        query_embedding: Query vector; mandatory for "vector" and "hybrid"
        top_k: Maximum number of hits to return
        filters: Field filters translated by ``_build_filter_expression``
            (e.g., {"source": "doc.pdf"})
        search_type: One of "vector", "keyword", or "hybrid"

    Returns:
        SearchResult objects in service relevance order, with 0-based ranks.

    Raises:
        ValueError: On an unknown search_type or a missing query/embedding.
    """
    import json

    if search_type not in ["vector", "keyword", "hybrid"]:
        raise ValueError(f"Invalid search_type: {search_type}")

    uses_vector = search_type != "keyword"
    uses_text = search_type != "vector"

    if uses_vector and query_embedding is None:
        raise ValueError(f"{search_type} search requires query_embedding")
    if uses_text and query is None:
        raise ValueError(f"{search_type} search requires query text")

    # Assemble the request; vector and hybrid share the same vector query.
    request = {
        "search_text": query if uses_text else None,
        "filter": self._build_filter_expression(filters),
        "top": top_k,
    }
    if uses_vector:
        request["vector_queries"] = [
            VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name,
            )
        ]
    else:
        # Pure keyword (BM25) search does not ask for a total count.
        request["include_total_count"] = False

    hits = self.search_client.search(**request)

    converted = []
    for position, hit in enumerate(hits):
        payload = json.loads(hit.get("document_data", "{}") or "{}")

        if not payload:
            # No serialized blob: rebuild from the top-level fields, skipping
            # service-internal "@" fields and the vector itself.
            payload = {}
            for name, val in hit.items():
                if name.startswith("@") or name == self.vector_field_name:
                    continue
                payload[name] = val
            if self.content_field_name != "content" and self.content_field_name in payload:
                payload["content"] = payload.pop(self.content_field_name)

        converted.append(SearchResult(
            document=self.document_type.from_dict(payload),
            score=hit.get("@search.score", 0.0),
            rank=position,
        ))

    return converted

Search Azure AI Search index.

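Args:
query: Text query (required for keyword/hybrid search)
query_embedding: Vector embedding (required for vector/hybrid search)
top_k: Number of results to return
filters: Field filters on the id, content, and timestamp fields or on custom document_type fields (e.g., {"court": "Supreme Court"}). Keys outside the document schema, including metadata keys, are skipped with a warning.
search_type: "vector", "keyword", or "hybrid"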

Returns: List of SearchResult objects, ordered by relevance

def delete_documents(self, document_ids: List[str]) -> int:
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Remove the given documents from the Azure AI Search index.

    Args:
        document_ids: IDs of the documents to remove; an empty list is a no-op
            and makes no service call.

    Returns:
        How many delete actions the service reported as succeeded.
    """
    # Nothing to send: skip the round-trip entirely.
    if not document_ids:
        return 0

    # A delete action only needs the document key.
    actions = []
    for key in document_ids:
        actions.append({"id": key})

    outcomes = self.search_client.delete_documents(documents=actions)

    succeeded = 0
    for outcome in outcomes:
        if outcome.succeeded:
            succeeded += 1
    return succeeded

Delete documents from Azure AI Search.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None, timestamp: Optional[datetime.datetime] = None, **custom_fields) -> bool:
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime] = None,
    **custom_fields
) -> bool:
    """
    Update an existing document in Azure AI Search.

    The stored document is fetched with ``get_document``. The requested
    changes are applied both to the top-level index fields and to the
    serialized ``document_data`` JSON blob. The result is then written back
    with ``merge_or_upload_documents``.

    Args:
        document_id: ID (key) of the document to update
        content: New content (if provided). Written to ``document_data`` and
            to the configured ``content_field_name`` index field.
        embedding: New embedding (if provided). Written to ``document_data``
            and to the configured ``vector_field_name`` index field.
        metadata: New metadata, merged into the existing metadata dict
            (metadata is kept only inside ``document_data``).
        timestamp: New timestamp (if provided)
        **custom_fields: Custom fields to update. Every key is written to
            ``document_data``. Keys that are non-base fields of a dataclass
            ``document_type`` are also written to the matching top-level index
            field, serialized the same way ``add_documents`` serializes them.

    Returns:
        True if the service reported the update as successful. False if the
        document was not found or any step of the update failed (the error
        is logged).
    """
    import dataclasses
    import json

    base_fields = {"id", "content", "embedding", "timestamp", "metadata"}

    def to_json_safe(value):
        # json.dumps cannot encode datetimes; store them in the Azure format.
        if isinstance(value, datetime):
            return self._format_datetime_for_azure(value)
        if isinstance(value, dict):
            return {k: to_json_safe(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_json_safe(v) for v in value]
        return value

    # Custom fields that add_documents stores as top-level (filterable) fields.
    indexed_custom_fields = set()
    if dataclasses.is_dataclass(self.document_type):
        indexed_custom_fields = {
            f.name for f in dataclasses.fields(self.document_type)
        } - base_fields

    try:
        existing_doc = self.search_client.get_document(key=document_id)
        # `or "{}"` also covers a null document_data value.
        doc_data = json.loads(existing_doc.get("document_data") or "{}")

        if content is not None:
            doc_data["content"] = content
            existing_doc[self.content_field_name] = content

        if embedding is not None:
            doc_data["embedding"] = embedding
            existing_doc[self.vector_field_name] = embedding

        if timestamp is not None:
            formatted_timestamp = self._format_datetime_for_azure(timestamp)
            doc_data["timestamp"] = formatted_timestamp
            existing_doc["timestamp"] = formatted_timestamp

        if metadata is not None:
            # A stored metadata value of None is treated as an empty dict.
            merged_metadata = doc_data.get("metadata") or {}
            merged_metadata.update(metadata)
            doc_data["metadata"] = merged_metadata

        for key, value in custom_fields.items():
            doc_data[key] = value
            if key not in indexed_custom_fields:
                continue
            # Keep the top-level index field in sync so filters see the new value.
            if isinstance(value, datetime):
                existing_doc[key] = self._format_datetime_for_azure(value)
            elif isinstance(value, (list, dict)):
                existing_doc[key] = json.dumps(to_json_safe(value))
            else:
                existing_doc[key] = value

        # Skip the blob when the fetched document has no document_data field.
        if "document_data" in existing_doc:
            existing_doc["document_data"] = json.dumps(to_json_safe(doc_data))

        result = self.search_client.merge_or_upload_documents(documents=[existing_doc])
        return result[0].succeeded

    except Exception as e:
        logger.error(f"Failed to update document {document_id}: {e}")
        return False

Update an existing document in Azure AI Search.

Args:
document_id: ID of the document to update
content: New content (if provided)
embedding: New embedding (if provided)
metadata: New metadata (merged with existing)
timestamp: New timestamp (if provided)
**custom_fields: Any custom fields to update

Returns: True if the update was successful, False if the document was not found or the update failed for any other reason

def get_document( self, document_id: str) -> Optional[Document]:
505    def get_document(self, document_id: str) -> Optional[Document]:
506        """
507        Retrieve a document by ID from Azure AI Search.
508        
509        Args:
510            document_id: ID of the document to retrieve
511        
512        Returns:
513            Document if found, None otherwise (returns configured document_type)
514        """
515        try:
516            import json
517            result = self.search_client.get_document(key=document_id)
518            
519            # Deserialize full document data using configured document_type
520            doc_data = json.loads(result.get("document_data", "{}"))
521            return self.document_type.from_dict(doc_data)
522        
523        except Exception as e:
524            logger.debug(f"Document {document_id} not found: {e}")
525            return None

Retrieve a document by ID from Azure AI Search.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise (returns configured document_type)

def get_document_by_parent( self, document_id: str) -> List[Document]:
527    def get_document_by_parent(self, document_id: str) -> List[Document]:
528        """
529        Retrieve all chunks belonging to a parent document.
530
531        Args:
532            document_id: The value of the ``documentId`` field shared across
533                        all chunks of the same parent document.
534
535        Returns:
536            List of Document objects sorted by ``pageNumber``, or an empty
537            list if no matching chunks are found.
538        """
539        results = self.search(
540            query="*",
541            search_type="keyword",
542            filters={"documentId": document_id},
543            top_k=1000,
544        )
545        chunks = [r.document for r in results]
546        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
547        return chunks

Retrieve all chunks belonging to a parent document.

Args: document_id: The value of the documentId field shared across all chunks of the same parent document.

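Returns: List of Document objects sorted by pageNumber (chunks without a pageNumber sort as 0), or an empty list if no matching chunks are found. At most 1,000 chunks are returned, because the underlying search uses top_k=1000.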

def count(self) -> int:
549    def count(self) -> int:
550        """
551        Get the total number of documents in the index.
552        
553        Returns:
554            Total document count
555        """
556        # Execute a search to get total count
557        results = self.search_client.search(
558            search_text="*",
559            include_total_count=True,
560            top=0
561        )
562        
563        # Get count from results
564        count = getattr(results, 'get_count', lambda: 0)()
565        return count if count is not None else 0

Get the total number of documents in the index.

Returns: Total document count

def clear(self) -> None:
567    def clear(self) -> None:
568        """
569        Remove all documents from the index.
570        
571        Warning: This operation is irreversible.
572        """
573        # Search for all document IDs
574        results = self.search_client.search(
575            search_text="*",
576            select=["id"],
577            top=1000  # Adjust batch size as needed
578        )
579        
580        # Collect all IDs
581        doc_ids = [r["id"] for r in results]
582        
583        # Delete in batches
584        if doc_ids:
585            self.delete_documents(doc_ids)
586            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")

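Remove documents from the index, up to 1,000 per call.

Warning: This operation is irreversible. Only the first page of search results (at most 1,000 document IDs) is fetched and deleted per call, so clearing a larger index requires calling this method repeatedly until count() returns 0.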

class AzureCosmosDBVectorStore(gmf_forge_ai_data.vector_stores.BaseVectorStore):
 32class AzureCosmosDBVectorStore(BaseVectorStore):
 33    """
 34    Azure Cosmos DB NoSQL API vector store.
 35
 36    Stores documents in a Cosmos DB NoSQL container and queries them using
 37    Cosmos DB's vector distance functions and SQL-like query language.
 38
 39    Features:
 40    - Vector search using Cosmos DB vector indexing (cosine distance)
 41    - Full-text keyword search via SQL CONTAINS / LIKE queries
 42    - Hybrid search (vector + keyword, equal weighting)
 43    - Metadata filtering via SQL WHERE clauses
 44    - Custom ``document_type`` support (any Document subclass)
 45    - Automatic container creation with vector embedding policy
 46
 47    Usage::
 48
 49        from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore
 50
 51        store = AzureCosmosDBVectorStore(
 52            endpoint="https://your-account.documents.azure.com:443/",
 53            key="your-key",
 54            database_name="rag_db",
 55            container_name="documents",
 56            embedding_dimension=1536,
 57        )
 58        store.add_documents(docs_with_embeddings)
 59        results = store.search(query_embedding=embedding, top_k=5)
 60    """
 61
 62    def __init__(
 63        self,
 64        endpoint: str,
 65        key: str,
 66        database_name: str,
 67        container_name: str,
 68        embedding_dimension: int = 1536,
 69        document_type: type = Document,
 70        ssl_cert_path: Optional[str] = None,
 71    ) -> None:
 72        """
 73        Initialise the Cosmos DB NoSQL vector store.
 74
 75        The database and container must be pre-provisioned using
 76        ``CosmosDBIndexBuilder`` before first use.  The constructor only
 77        establishes client connections — it never creates or modifies
 78        databases or containers.
 79
 80        Args:
 81            endpoint: Cosmos DB account endpoint
 82                (e.g. ``https://your-account.documents.azure.com:443/``).
 83            key: Cosmos DB account key or resource token.
 84            database_name: Name of the Cosmos DB database.
 85            container_name: Name of the container.
 86            embedding_dimension: Dimension of vector embeddings (default 1536).
 87            document_type: Document class to use when deserialising results.
 88                Pass a custom subclass to support additional typed fields.
 89            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 90                verification.  Useful in corporate environments with custom
 91                certificate authorities.
 92        """
 93        try:
 94            from azure.cosmos import CosmosClient  # noqa: F401
 95        except ImportError as exc:
 96            raise ImportError(
 97                "azure-cosmos is required for AzureCosmosDBVectorStore. "
 98                "Install it with:  pip install azure-cosmos"
 99            ) from exc
100
101        self.embedding_dimension = embedding_dimension
102        self.document_type = document_type
103        self._container_name = container_name
104
105        # Build client with optional custom SSL cert
106        import ssl as _ssl
107        connection_kwargs: Dict[str, Any] = {}
108        if ssl_cert_path:
109            ssl_context = _ssl.create_default_context(cafile=ssl_cert_path)
110            connection_kwargs["connection_verify"] = ssl_cert_path
111
112        from azure.cosmos import CosmosClient
113        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)
114
115        # Database and container — pre-provisioned by developer
116        self._database = self._client.get_database_client(database_name)
117        self._container = self._database.get_container_client(container_name)
118
119    # ------------------------------------------------------------------
120    # Serialisation helpers
121    # ------------------------------------------------------------------
122
123    @staticmethod
124    def _serialize_value(obj: Any) -> Any:
125        """Recursively convert datetime objects to ISO strings."""
126        if isinstance(obj, datetime):
127            return obj.isoformat()
128        if isinstance(obj, dict):
129            return {k: AzureCosmosDBVectorStore._serialize_value(v) for k, v in obj.items()}
130        if isinstance(obj, (list, tuple)):
131            return [AzureCosmosDBVectorStore._serialize_value(i) for i in obj]
132        return obj
133
134    def _to_cosmos(self, doc: Document) -> Dict[str, Any]:
135        """Convert a Document to a Cosmos DB item dict."""
136        doc_dict = doc.to_dict()
137        doc_dict_serialisable = self._serialize_value(doc_dict)
138        item = {
139            "id": doc.id,
140            "content": doc.content,
141            "embedding": doc.embedding or [],
142            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
143            "metadata": doc.metadata,
144            "document_data": json.dumps(doc_dict_serialisable),
145        }
146        # Promote custom fields (e.g. category, author, year) to top-level
147        # so they are filterable via SQL WHERE clauses
148        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
149        for key, value in doc_dict_serialisable.items():
150            if key not in base_keys:
151                item[key] = value
152        return item
153
154    def _from_cosmos(self, cosmos_item: Dict[str, Any]) -> Document:
155        """Reconstruct a Document (or subclass) from a Cosmos DB item."""
156        doc_data = json.loads(cosmos_item.get("document_data", "{}"))
157        return self.document_type.from_dict(doc_data)
158
159    # ------------------------------------------------------------------
160    # Filter helpers
161    # ------------------------------------------------------------------
162
163    @staticmethod
164    def _build_where_clause(filters: Optional[Dict[str, Any]]) -> str:
165        """Translate a user-facing filter dict to a SQL WHERE clause."""
166        if not filters:
167            return ""
168        op_map = {">=": ">=", "<=": "<=", ">": ">", "<": "<", "!=": "!="}
169        conditions: List[str] = []
170        for key, value in filters.items():
171            if isinstance(value, tuple) and len(value) == 2:
172                op, val = value
173                sql_op = op_map.get(op, "=")
174                if isinstance(val, str):
175                    conditions.append(f"c.{key} {sql_op} '{val}'")
176                elif isinstance(val, datetime):
177                    conditions.append(f"c.{key} {sql_op} '{val.isoformat()}'")
178                else:
179                    conditions.append(f"c.{key} {sql_op} {val}")
180            else:
181                if isinstance(value, str):
182                    conditions.append(f"c.{key} = '{value}'")
183                elif isinstance(value, bool):
184                    conditions.append(f"c.{key} = {'true' if value else 'false'}")
185                elif isinstance(value, datetime):
186                    conditions.append(f"c.{key} = '{value.isoformat()}'")
187                else:
188                    conditions.append(f"c.{key} = {value}")
189        return " AND ".join(conditions)
190
191    # ------------------------------------------------------------------
192    # BaseVectorStore interface
193    # ------------------------------------------------------------------
194
195    def add_documents(
196        self,
197        documents: List[Document],
198        generate_embeddings: bool = True,
199    ) -> List[str]:
200        """
201        Upsert documents into the Cosmos DB container.
202
203        Args:
204            documents: Documents to add.  Each must have ``id`` and ``content``
205                set.  If *generate_embeddings* is ``True`` (default) every
206                document must also carry a pre-computed ``embedding``.
207            generate_embeddings: When ``True`` the method validates that all
208                documents already have embeddings rather than generating them
209                itself (generation is the caller's responsibility).
210
211        Returns:
212            List of document IDs that were successfully upserted.
213        """
214        if not documents:
215            return []
216
217        added_ids: List[str] = []
218
219        for doc in documents:
220            if not doc.id or not doc.content:
221                raise ValueError(f"Document must have id and content: {doc}")
222
223            if generate_embeddings and doc.embedding is None:
224                raise ValueError(
225                    f"Document '{doc.id}' has no embedding.  "
226                    "Generate embeddings before calling add_documents()."
227                )
228
229            cosmos_item = self._to_cosmos(doc)
230            self._container.upsert_item(cosmos_item)
231            added_ids.append(doc.id)
232
233        logger.info(
234            "Upserted %d document(s) into Cosmos DB container '%s'",
235            len(added_ids),
236            self._container_name,
237        )
238        return added_ids
239
240    def search(
241        self,
242        query: Optional[str] = None,
243        query_embedding: Optional[List[float]] = None,
244        top_k: int = 5,
245        filters: Optional[Dict[str, Any]] = None,
246        search_type: str = "vector",
247    ) -> List[SearchResult]:
248        """
249        Search the Cosmos DB container.
250
251        Args:
252            query: Plain-text query string.  Required for ``keyword`` and
253                ``hybrid`` search types.
254            query_embedding: Pre-computed query vector.  Required for
255                ``vector`` and ``hybrid`` search types.
256            top_k: Maximum number of results to return.
257            filters: Optional filter dict.  Supports equality
258                (``{"key": value}``) and range (``{"key": (">=", value)}``)
259                operators.
260            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.
261
262        Returns:
263            Ranked list of :class:`SearchResult` objects.
264        """
265        if search_type not in ("vector", "keyword", "hybrid"):
266            raise ValueError(
267                f"Invalid search_type '{search_type}'. "
268                "Must be 'vector', 'keyword', or 'hybrid'."
269            )
270        if search_type in ("vector", "hybrid") and query_embedding is None:
271            raise ValueError(f"search_type='{search_type}' requires query_embedding")
272        if search_type in ("keyword", "hybrid") and query is None:
273            raise ValueError(f"search_type='{search_type}' requires query text")
274
275        if search_type == "vector":
276            return self._vector_search(query_embedding, top_k, filters)
277        if search_type == "keyword":
278            return self._keyword_search(query, top_k, filters)
279        return self._hybrid_search(query, query_embedding, top_k, filters)
280
281    # --- search helpers -----------------------------------------------
282
283    def _vector_search(
284        self,
285        query_embedding: List[float],
286        top_k: int,
287        filters: Optional[Dict[str, Any]],
288    ) -> List[SearchResult]:
289        where_clause = self._build_where_clause(filters)
290        where_sql = f"WHERE {where_clause}" if where_clause else ""
291
292        query_text = (
293            f"SELECT TOP @top_k c.id, c.document_data, "
294            f"VectorDistance(c.embedding, @embedding) AS score "
295            f"FROM c {where_sql} "
296            f"ORDER BY VectorDistance(c.embedding, @embedding)"
297        )
298        parameters = [
299            {"name": "@top_k", "value": top_k},
300            {"name": "@embedding", "value": query_embedding},
301        ]
302
303        items = list(self._container.query_items(
304            query=query_text,
305            parameters=parameters,
306            enable_cross_partition_query=True,
307        ))
308
309        results = []
310        for rank, item in enumerate(items):
311            doc = self._from_cosmos(item)
312            # VectorDistance with cosine returns distance (0 = identical);
313            # convert to similarity score (1 = identical)
314            distance = float(item.get("score", 1.0))
315            similarity = 1.0 - distance
316            results.append(SearchResult(document=doc, score=similarity, rank=rank))
317        return results
318
319    def _keyword_search(
320        self,
321        query: str,
322        top_k: int,
323        filters: Optional[Dict[str, Any]],
324    ) -> List[SearchResult]:
325        # Build keyword conditions using CONTAINS for each query term
326        terms = query.split()
327        keyword_conditions = " OR ".join(
328            f"CONTAINS(LOWER(c.content), '{term.lower()}')" for term in terms
329        )
330        where_clause = self._build_where_clause(filters)
331        filter_sql = f" AND {where_clause}" if where_clause else ""
332
333        # Fetch ALL matching documents — scoring and top_k slicing happen in Python.
334        # Applying SELECT TOP before Python scoring would miss high-scoring documents
335        # if Cosmos DB returns lower-scoring ones first (order without ORDER BY is undefined).
336        query_text = (
337            f"SELECT c.id, c.document_data, c.content "
338            f"FROM c WHERE ({keyword_conditions}){filter_sql}"
339        )
340
341        items = list(self._container.query_items(
342            query=query_text,
343            parameters=[],
344            enable_cross_partition_query=True,
345        ))
346
347        # Score by number of matching terms (simple BM25-like relevance)
348        results = []
349        for item in items:
350            content_lower = item.get("content", "").lower()
351            match_count = sum(1 for t in terms if t.lower() in content_lower)
352            score = match_count / len(terms) if terms else 0.0
353            results.append(SearchResult(
354                document=self._from_cosmos(item),
355                score=score,
356                rank=0,
357            ))
358
359        results.sort(key=lambda r: r.score, reverse=True)
360        for rank, r in enumerate(results):
361            r.rank = rank
362        return results[:top_k]
363
364    def _hybrid_search(
365        self,
366        query: str,
367        query_embedding: List[float],
368        top_k: int,
369        filters: Optional[Dict[str, Any]],
370    ) -> List[SearchResult]:
371        """Combine vector and keyword results with equal (50/50) weighting."""
372        vector_results = self._vector_search(query_embedding, top_k, filters)
373        keyword_results = self._keyword_search(query, top_k, filters)
374
375        def _max(results: List[SearchResult]) -> float:
376            return max((r.score for r in results), default=1.0) or 1.0
377
378        v_max = _max(vector_results)
379        k_max = _max(keyword_results)
380
381        combined: Dict[str, float] = {}
382        doc_map: Dict[str, Document] = {}
383
384        for r in vector_results:
385            combined[r.document.id] = 0.5 * (r.score / v_max)
386            doc_map[r.document.id] = r.document
387
388        for r in keyword_results:
389            combined[r.document.id] = (
390                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
391            )
392            doc_map.setdefault(r.document.id, r.document)
393
394        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
395        return [
396            SearchResult(
397                document=doc_map[doc_id],
398                score=combined[doc_id],
399                rank=rank,
400            )
401            for rank, doc_id in enumerate(sorted_ids)
402        ]
403
404    # ------------------------------------------------------------------
405    # Document lifecycle
406    # ------------------------------------------------------------------
407
408    def delete_documents(self, document_ids: List[str]) -> int:
409        """
410        Delete documents by ID.
411
412        Args:
413            document_ids: IDs of the documents to remove.
414
415        Returns:
416            Number of documents actually deleted.
417        """
418        deleted = 0
419        for doc_id in document_ids:
420            try:
421                self._container.delete_item(item=doc_id, partition_key=doc_id)
422                deleted += 1
423            except Exception:
424                logger.debug("Document '%s' not found for deletion", doc_id)
425        return deleted
426
427    def update_document(
428        self,
429        document_id: str,
430        content: Optional[str] = None,
431        embedding: Optional[List[float]] = None,
432        metadata: Optional[Dict[str, Any]] = None,
433    ) -> bool:
434        """
435        Update an existing document.
436
437        Args:
438            document_id: ID of the document to update.
439            content: New text content (optional).
440            embedding: New embedding vector (optional).
441            metadata: Metadata key-value pairs to *merge* into existing metadata
442                (optional).
443
444        Returns:
445            ``True`` if the document was found and updated, ``False`` otherwise.
446        """
447        try:
448            item = self._container.read_item(item=document_id, partition_key=document_id)
449        except Exception:
450            return False
451
452        doc = self._from_cosmos(item)
453
454        if content is not None:
455            doc.content = content
456        if embedding is not None:
457            doc.embedding = embedding
458        if metadata is not None:
459            doc.metadata.update(metadata)
460
461        self._container.upsert_item(self._to_cosmos(doc))
462        return True
463
464    def get_document(self, document_id: str) -> Optional[Document]:
465        """
466        Retrieve a single document by ID.
467
468        Args:
469            document_id: The document's unique identifier.
470
471        Returns:
472            The :class:`Document` (or subclass) if found, otherwise ``None``.
473        """
474        try:
475            item = self._container.read_item(item=document_id, partition_key=document_id)
476            return self._from_cosmos(item)
477        except Exception:
478            return None
479
480    def count(self) -> int:
481        """Return the total number of documents in the container."""
482        query_text = "SELECT VALUE COUNT(1) FROM c"
483        items = list(self._container.query_items(
484            query=query_text,
485            enable_cross_partition_query=True,
486        ))
487        return items[0] if items else 0
488
489    def clear(self) -> None:
490        """
491        Remove **all** documents from the container.
492
493        Warning:
494            This operation is irreversible.
495        """
496        query_text = "SELECT c.id FROM c"
497        items = list(self._container.query_items(
498            query=query_text,
499            enable_cross_partition_query=True,
500        ))
501        for item in items:
502            try:
503                self._container.delete_item(item=item["id"], partition_key=item["id"])
504            except Exception:
505                pass
506        logger.warning(
507            "Cleared all documents from Cosmos DB container '%s'",
508            self._container_name,
509        )
510
511    def close(self) -> None:
512        """Close the underlying Cosmos DB client."""
513        # azure-cosmos CosmosClient doesn't require explicit close,
514        # but we provide the method for interface consistency
515        pass

Azure Cosmos DB NoSQL API vector store.

Stores documents in a Cosmos DB NoSQL container and queries them using Cosmos DB's vector distance functions and SQL-like query language.

Features:

  • Vector search using Cosmos DB vector indexing (cosine distance)
  • Keyword search via case-insensitive SQL CONTAINS queries (a document matches if it contains any query term)
  • Hybrid search (vector + keyword, equal weighting)
  • Metadata filtering via SQL WHERE clauses
  • Custom document_type support (any Document subclass)
  • No automatic provisioning: the database and container must be created beforehand (e.g. with CosmosDBIndexBuilder); the store never creates or modifies them

Usage::

from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore

store = AzureCosmosDBVectorStore(
    endpoint="https://your-account.documents.azure.com:443/",
    key="your-key",
    database_name="rag_db",
    container_name="documents",
    embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
AzureCosmosDBVectorStore( endpoint: str, key: str, database_name: str, container_name: str, embedding_dimension: int = 1536, document_type: type = <class 'Document'>, ssl_cert_path: Optional[str] = None)
 62    def __init__(
 63        self,
 64        endpoint: str,
 65        key: str,
 66        database_name: str,
 67        container_name: str,
 68        embedding_dimension: int = 1536,
 69        document_type: type = Document,
 70        ssl_cert_path: Optional[str] = None,
 71    ) -> None:
 72        """
 73        Initialise the Cosmos DB NoSQL vector store.
 74
 75        The database and container must be pre-provisioned using
 76        ``CosmosDBIndexBuilder`` before first use.  The constructor only
 77        establishes client connections — it never creates or modifies
 78        databases or containers.
 79
 80        Args:
 81            endpoint: Cosmos DB account endpoint
 82                (e.g. ``https://your-account.documents.azure.com:443/``).
 83            key: Cosmos DB account key or resource token.
 84            database_name: Name of the Cosmos DB database.
 85            container_name: Name of the container.
 86            embedding_dimension: Dimension of vector embeddings (default 1536).
 87            document_type: Document class to use when deserialising results.
 88                Pass a custom subclass to support additional typed fields.
 89            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 90                verification.  Useful in corporate environments with custom
 91                certificate authorities.
 92        """
 93        try:
 94            from azure.cosmos import CosmosClient  # noqa: F401
 95        except ImportError as exc:
 96            raise ImportError(
 97                "azure-cosmos is required for AzureCosmosDBVectorStore. "
 98                "Install it with:  pip install azure-cosmos"
 99            ) from exc
100
101        self.embedding_dimension = embedding_dimension
102        self.document_type = document_type
103        self._container_name = container_name
104
105        # Build client with optional custom SSL cert
106        import ssl as _ssl
107        connection_kwargs: Dict[str, Any] = {}
108        if ssl_cert_path:
109            ssl_context = _ssl.create_default_context(cafile=ssl_cert_path)
110            connection_kwargs["connection_verify"] = ssl_cert_path
111
112        from azure.cosmos import CosmosClient
113        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)
114
115        # Database and container — pre-provisioned by developer
116        self._database = self._client.get_database_client(database_name)
117        self._container = self._database.get_container_client(container_name)

Initialise the Cosmos DB NoSQL vector store.

The database and container must be pre-provisioned using CosmosDBIndexBuilder before first use. The constructor only establishes client connections — it never creates or modifies databases or containers.

Args: endpoint: Cosmos DB account endpoint (e.g. https://your-account.documents.azure.com:443/). key: Cosmos DB account key or resource token. database_name: Name of the Cosmos DB database. container_name: Name of the container. embedding_dimension: Dimension of vector embeddings (default 1536). document_type: Document class to use when deserialising results. Pass a custom subclass to support additional typed fields. ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS verification. Useful in corporate environments with custom certificate authorities.

embedding_dimension
document_type
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
195    def add_documents(
196        self,
197        documents: List[Document],
198        generate_embeddings: bool = True,
199    ) -> List[str]:
200        """
201        Upsert documents into the Cosmos DB container.
202
203        Args:
204            documents: Documents to add.  Each must have ``id`` and ``content``
205                set.  If *generate_embeddings* is ``True`` (default) every
206                document must also carry a pre-computed ``embedding``.
207            generate_embeddings: When ``True`` the method validates that all
208                documents already have embeddings rather than generating them
209                itself (generation is the caller's responsibility).
210
211        Returns:
212            List of document IDs that were successfully upserted.
213        """
214        if not documents:
215            return []
216
217        added_ids: List[str] = []
218
219        for doc in documents:
220            if not doc.id or not doc.content:
221                raise ValueError(f"Document must have id and content: {doc}")
222
223            if generate_embeddings and doc.embedding is None:
224                raise ValueError(
225                    f"Document '{doc.id}' has no embedding.  "
226                    "Generate embeddings before calling add_documents()."
227                )
228
229            cosmos_item = self._to_cosmos(doc)
230            self._container.upsert_item(cosmos_item)
231            added_ids.append(doc.id)
232
233        logger.info(
234            "Upserted %d document(s) into Cosmos DB container '%s'",
235            len(added_ids),
236            self._container_name,
237        )
238        return added_ids

Upsert documents into the Cosmos DB container.

Args: documents: Documents to add. Each must have id and content set. If generate_embeddings is True (default) every document must also carry a pre-computed embedding. generate_embeddings: When True the method validates that all documents already have embeddings rather than generating them itself (generation is the caller's responsibility).

Returns: List of document IDs that were successfully upserted.

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
240    def search(
241        self,
242        query: Optional[str] = None,
243        query_embedding: Optional[List[float]] = None,
244        top_k: int = 5,
245        filters: Optional[Dict[str, Any]] = None,
246        search_type: str = "vector",
247    ) -> List[SearchResult]:
248        """
249        Search the Cosmos DB container.
250
251        Args:
252            query: Plain-text query string.  Required for ``keyword`` and
253                ``hybrid`` search types.
254            query_embedding: Pre-computed query vector.  Required for
255                ``vector`` and ``hybrid`` search types.
256            top_k: Maximum number of results to return.
257            filters: Optional filter dict.  Supports equality
258                (``{"key": value}``) and range (``{"key": (">=", value)}``)
259                operators.
260            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.
261
262        Returns:
263            Ranked list of :class:`SearchResult` objects.
264        """
265        if search_type not in ("vector", "keyword", "hybrid"):
266            raise ValueError(
267                f"Invalid search_type '{search_type}'. "
268                "Must be 'vector', 'keyword', or 'hybrid'."
269            )
270        if search_type in ("vector", "hybrid") and query_embedding is None:
271            raise ValueError(f"search_type='{search_type}' requires query_embedding")
272        if search_type in ("keyword", "hybrid") and query is None:
273            raise ValueError(f"search_type='{search_type}' requires query text")
274
275        if search_type == "vector":
276            return self._vector_search(query_embedding, top_k, filters)
277        if search_type == "keyword":
278            return self._keyword_search(query, top_k, filters)
279        return self._hybrid_search(query, query_embedding, top_k, filters)

Search the Cosmos DB container.

Args: query: Plain-text query string. Required for keyword and hybrid search types. query_embedding: Pre-computed query vector. Required for vector and hybrid search types. top_k: Maximum number of results to return. filters: Optional filter dict. Supports equality ({"key": value}) and range ({"key": (">=", value)}) operators. search_type: One of "vector", "keyword", or "hybrid". Note: keyword query terms and filter keys/values are interpolated directly into the SQL query text rather than passed as query parameters, so do not pass untrusted input.

Returns: Ranked list of SearchResult objects.

def delete_documents(self, document_ids: List[str]) -> int:
408    def delete_documents(self, document_ids: List[str]) -> int:
409        """
410        Delete documents by ID.
411
412        Args:
413            document_ids: IDs of the documents to remove.
414
415        Returns:
416            Number of documents actually deleted.
417        """
418        deleted = 0
419        for doc_id in document_ids:
420            try:
421                self._container.delete_item(item=doc_id, partition_key=doc_id)
422                deleted += 1
423            except Exception:
424                logger.debug("Document '%s' not found for deletion", doc_id)
425        return deleted

Delete documents by ID.

Args: document_ids: IDs of the documents to remove.

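Returns: Number of documents actually deleted. An ID whose deletion raises any error (including a missing document) is logged at debug level and skipped rather than raised.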

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
427    def update_document(
428        self,
429        document_id: str,
430        content: Optional[str] = None,
431        embedding: Optional[List[float]] = None,
432        metadata: Optional[Dict[str, Any]] = None,
433    ) -> bool:
434        """
435        Update an existing document.
436
437        Args:
438            document_id: ID of the document to update.
439            content: New text content (optional).
440            embedding: New embedding vector (optional).
441            metadata: Metadata key-value pairs to *merge* into existing metadata
442                (optional).
443
444        Returns:
445            ``True`` if the document was found and updated, ``False`` otherwise.
446        """
447        try:
448            item = self._container.read_item(item=document_id, partition_key=document_id)
449        except Exception:
450            return False
451
452        doc = self._from_cosmos(item)
453
454        if content is not None:
455            doc.content = content
456        if embedding is not None:
457            doc.embedding = embedding
458        if metadata is not None:
459            doc.metadata.update(metadata)
460
461        self._container.upsert_item(self._to_cosmos(doc))
462        return True

Update an existing document.

Args: document_id: ID of the document to update. content: New text content (optional). embedding: New embedding vector (optional). metadata: Metadata key-value pairs to merge into existing metadata (optional).

Returns: True if the document was found and updated, False otherwise.

def get_document( self, document_id: str) -> Optional[Document]:
464    def get_document(self, document_id: str) -> Optional[Document]:
465        """
466        Retrieve a single document by ID.
467
468        Args:
469            document_id: The document's unique identifier.
470
471        Returns:
472            The :class:`Document` (or subclass) if found, otherwise ``None``.
473        """
474        try:
475            item = self._container.read_item(item=document_id, partition_key=document_id)
476            return self._from_cosmos(item)
477        except Exception:
478            return None

Retrieve a single document by ID.

Args: document_id: The document's unique identifier.

Returns: The Document (or subclass) if found, otherwise None.

def count(self) -> int:
480    def count(self) -> int:
481        """Return the total number of documents in the container."""
482        query_text = "SELECT VALUE COUNT(1) FROM c"
483        items = list(self._container.query_items(
484            query=query_text,
485            enable_cross_partition_query=True,
486        ))
487        return items[0] if items else 0

Return the total number of documents in the container.

def clear(self) -> None:
489    def clear(self) -> None:
490        """
491        Remove **all** documents from the container.
492
493        Warning:
494            This operation is irreversible.
495        """
496        query_text = "SELECT c.id FROM c"
497        items = list(self._container.query_items(
498            query=query_text,
499            enable_cross_partition_query=True,
500        ))
501        for item in items:
502            try:
503                self._container.delete_item(item=item["id"], partition_key=item["id"])
504            except Exception:
505                pass
506        logger.warning(
507            "Cleared all documents from Cosmos DB container '%s'",
508            self._container_name,
509        )

Remove all documents from the container.

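Warning: This operation is irreversible. Errors raised while deleting individual items are silently ignored, so some documents may remain afterwards.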

def close(self) -> None:
511    def close(self) -> None:
512        """Close the underlying Cosmos DB client."""
513        # azure-cosmos CosmosClient doesn't require explicit close,
514        # but we provide the method for interface consistency
515        pass

No-op provided for interface consistency; the underlying Cosmos DB client is not explicitly closed.

class MongoDBVectorStore(gmf_forge_ai_data.vector_stores.BaseVectorStore):
 35class MongoDBVectorStore(BaseVectorStore):
 36    """
 37    MongoDB Atlas Vector Search vector store.
 38
 39    Stores documents in a MongoDB collection and queries them using the
 40    ``$vectorSearch`` aggregation stage.  Hybrid search is implemented by
 41    running vector and text searches separately and combining their normalised
 42    scores with a 50 / 50 weight.
 43
 44    Features:
 45    - Atlas Vector Search (approximate KNN, configurable candidates)
 46    - Full-text keyword search via ``$text`` index
 47    - Hybrid search (vector + keyword, equal weighting)
 48    - Metadata pre-filtering on vector search
 49    - Custom ``document_type`` support (any Document subclass)
 50
 51    Note:
 52        The Atlas Vector Search index (``$vectorSearch``) must be created
 53        outside of this class — through the MongoDB Atlas UI, Atlas CLI, or
 54        Atlas Admin API.  The ``vector_index_name`` parameter must match the
 55        name you gave that index.
 56
 57    Usage::
 58
 59        from gmf_forge_ai_data.vector_stores import MongoDBVectorStore
 60
 61        store = MongoDBVectorStore(
 62            connection_string="mongodb+srv://...",
 63            database_name="rag_db",
 64            collection_name="documents",
 65            vector_index_name="vector_index",
 66            embedding_dimension=1536,
 67        )
 68        store.add_documents(docs_with_embeddings)
 69        results = store.search(query_embedding=embedding, top_k=5)
 70    """
 71
 72    def __init__(
 73        self,
 74        connection_string: str,
 75        database_name: str,
 76        collection_name: str,
 77        vector_index_name: str = "vector_index",
 78        embedding_dimension: int = 1536,
 79        document_type: type = Document,
 80        ssl_cert_path: Optional[str] = None,
 81    ) -> None:
 82        """
 83        Initialise the MongoDB Atlas vector store.
 84
 85        Args:
 86            connection_string: MongoDB Atlas connection string
 87                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
 88            database_name: Name of the MongoDB database.
 89            collection_name: Name of the collection.
 90            vector_index_name: Name of the Atlas Vector Search index defined
 91                on this collection.  Defaults to ``"vector_index"``.
 92            embedding_dimension: Dimension of vector embeddings.  Must match
 93                the dimension configured in the Atlas Vector Search index
 94                (default 1536).
 95            document_type: Document class to use when deserialising results.
 96                Pass a custom subclass to support additional typed fields.
 97            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 98                verification.  Useful in corporate environments with custom
 99                certificate authorities.
100        """
101        try:
102            import pymongo  # noqa: F401 — checked at construction time
103        except ImportError as exc:
104            raise ImportError(
105                "pymongo is required for MongoDBVectorStore. "
106                "Install it with:  pip install pymongo"
107            ) from exc
108
109        self.embedding_dimension = embedding_dimension
110        self.document_type = document_type
111        self.vector_index_name = vector_index_name
112        self._collection_name = collection_name
113
114        client_kwargs: Dict[str, Any] = {}
115        if ssl_cert_path:
116            client_kwargs["tlsCAFile"] = ssl_cert_path
117
118        import pymongo
119
120        self._client: pymongo.MongoClient = pymongo.MongoClient(
121            connection_string, **client_kwargs
122        )
123        self._db = self._client[database_name]
124        self._collection = self._db[collection_name]
125
126    # ------------------------------------------------------------------
127    # Serialisation helpers
128    # ------------------------------------------------------------------
129
130    @staticmethod
131    def _serialize_value(obj: Any) -> Any:
132        """Recursively convert datetime objects to ISO strings."""
133        if isinstance(obj, datetime):
134            return obj.isoformat()
135        if isinstance(obj, dict):
136            return {k: MongoDBVectorStore._serialize_value(v) for k, v in obj.items()}
137        if isinstance(obj, (list, tuple)):
138            return [MongoDBVectorStore._serialize_value(i) for i in obj]
139        return obj
140
141    def _to_mongo(self, doc: Document) -> Dict[str, Any]:
142        """Convert a Document to a MongoDB document dict."""
143        doc_dict = doc.to_dict()
144        doc_dict_serialisable = self._serialize_value(doc_dict)
145        mongo_doc = {
146            "_id": doc.id,
147            "id": doc.id,
148            "content": doc.content,
149            "embedding": doc.embedding or [],
150            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
151            "metadata": doc.metadata,
152            "document_data": json.dumps(doc_dict_serialisable),
153        }
154        # Promote custom fields (e.g. category, published_year, field, institution)
155        # to top level so they are filterable via $vectorSearch filter and $match.
156        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
157        for key, value in doc_dict_serialisable.items():
158            if key not in base_keys:
159                mongo_doc[key] = value
160        return mongo_doc
161
162    def _from_mongo(self, mongo_doc: Dict[str, Any]) -> Document:
163        """Reconstruct a Document (or subclass) from a MongoDB document."""
164        doc_data = json.loads(mongo_doc.get("document_data", "{}"))
165        return self.document_type.from_dict(doc_data)
166
167    # ------------------------------------------------------------------
168    # Filter helpers
169    # ------------------------------------------------------------------
170
171    @staticmethod
172    def _build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
173        """Translate a user-facing filter dict to a MongoDB query expression."""
174        if not filters:
175            return None
176        op_map = {">=": "$gte", "<=": "$lte", ">": "$gt", "<": "$lt", "!=": "$ne"}
177        match: Dict[str, Any] = {}
178        for key, value in filters.items():
179            if isinstance(value, tuple) and len(value) == 2:
180                op, val = value
181                match[key] = {op_map[op]: val} if op in op_map else val
182            else:
183                match[key] = value
184        return match
185
186    # ------------------------------------------------------------------
187    # BaseVectorStore interface
188    # ------------------------------------------------------------------
189
190    def add_documents(
191        self,
192        documents: List[Document],
193        generate_embeddings: bool = True,
194    ) -> List[str]:
195        """
196        Upsert documents into the MongoDB collection.
197
198        Args:
199            documents: Documents to add.  Each must have ``id`` and ``content``
200                set.  If *generate_embeddings* is ``True`` (default) every
201                document must also carry a pre-computed ``embedding``.
202            generate_embeddings: When ``True`` the method validates that all
203                documents already have embeddings rather than generating them
204                itself (generation is the caller's responsibility).
205
206        Returns:
207            List of document IDs that were successfully upserted.
208        """
209        if not documents:
210            return []
211
212        added_ids: List[str] = []
213
214        for doc in documents:
215            if not doc.id or not doc.content:
216                raise ValueError(f"Document must have id and content: {doc}")
217
218            if generate_embeddings and doc.embedding is None:
219                raise ValueError(
220                    f"Document '{doc.id}' has no embedding.  "
221                    "Generate embeddings before calling add_documents()."
222                )
223
224            mongo_doc = self._to_mongo(doc)
225            self._collection.replace_one(
226                {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
227            )
228            added_ids.append(doc.id)
229
230        logger.info(
231            "Upserted %d document(s) into MongoDB collection '%s'",
232            len(added_ids),
233            self._collection_name,
234        )
235        return added_ids
236
237    def search(
238        self,
239        query: Optional[str] = None,
240        query_embedding: Optional[List[float]] = None,
241        top_k: int = 5,
242        filters: Optional[Dict[str, Any]] = None,
243        search_type: str = "vector",
244    ) -> List[SearchResult]:
245        """
246        Search the MongoDB collection.
247
248        Args:
249            query: Plain-text query string.  Required for ``keyword`` and
250                ``hybrid`` search types.
251            query_embedding: Pre-computed query vector.  Required for
252                ``vector`` and ``hybrid`` search types.
253            top_k: Maximum number of results to return.
254            filters: Optional metadata filter dict.  Supports equality
255                (``{"key": value}``) and range (``{"key": (">=", value)}``)
256                operators.
257            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.
258
259        Returns:
260            Ranked list of :class:`SearchResult` objects.
261        """
262        if search_type not in ("vector", "keyword", "hybrid"):
263            raise ValueError(
264                f"Invalid search_type '{search_type}'. "
265                "Must be 'vector', 'keyword', or 'hybrid'."
266            )
267        if search_type in ("vector", "hybrid") and query_embedding is None:
268            raise ValueError(f"search_type='{search_type}' requires query_embedding")
269        if search_type in ("keyword", "hybrid") and query is None:
270            raise ValueError(f"search_type='{search_type}' requires query text")
271
272        if search_type == "vector":
273            return self._vector_search(query_embedding, top_k, filters)
274        if search_type == "keyword":
275            return self._keyword_search(query, top_k, filters)
276        return self._hybrid_search(query, query_embedding, top_k, filters)
277
278    # --- search helpers -----------------------------------------------
279
280    def _vector_search(
281        self,
282        query_embedding: List[float],
283        top_k: int,
284        filters: Optional[Dict[str, Any]],
285    ) -> List[SearchResult]:
286        """Execute an Atlas ``$vectorSearch`` aggregation pipeline."""
287        vector_stage: Dict[str, Any] = {
288            "index": self.vector_index_name,
289            "path": "embedding",
290            "queryVector": query_embedding,
291            # numCandidates controls recall vs. latency trade-off.
292            # A value of top_k * 10 (min 100) is a sensible default.
293            "numCandidates": max(top_k * 10, 100),
294            "limit": top_k,
295        }
296
297        pre_filter = self._build_filter(filters)
298        if pre_filter:
299            vector_stage["filter"] = pre_filter
300
301        pipeline = [
302            {"$vectorSearch": vector_stage},
303            {
304                "$project": {
305                    "_id": 1,
306                    "document_data": 1,
307                    "score": {"$meta": "vectorSearchScore"},
308                }
309            },
310        ]
311
312        try:
313            return [
314                SearchResult(
315                    document=self._from_mongo(raw),
316                    score=float(raw.get("score", 0.0)),
317                    rank=rank,
318                )
319                for rank, raw in enumerate(self._collection.aggregate(pipeline))
320            ]
321        except Exception as exc:
322            msg = str(exc)
323            if "localhost:28000" in msg or "HostUnreachable" in msg or "PlanExecutor" in msg:
324                raise RuntimeError(
325                    f"Atlas Vector Search index '{self.vector_index_name}' is not available. "
326                    "Create it in the Atlas UI (collection → Search Indexes → Create Search Index → "
327                    "Atlas Vector Search) with field='embedding', "
328                    f"numDimensions={self.embedding_dimension}, similarity=cosine, "
329                    f"indexName='{self.vector_index_name}'."
330                ) from exc
331            raise
332
333    def _keyword_search(
334        self,
335        query: str,
336        top_k: int,
337        filters: Optional[Dict[str, Any]],
338    ) -> List[SearchResult]:
339        """Execute a full-text search via the MongoDB ``$text`` operator.
340
341        Requires a text index on the ``content`` field to be created by the
342        developer before use::
343
344            db.collection.createIndex({ "content": "text" })
345
346        or via the Atlas UI / CLI / Admin API.
347        """
348        mongo_filter: Dict[str, Any] = {"$text": {"$search": query}}
349        pre_filter = self._build_filter(filters)
350        if pre_filter:
351            mongo_filter.update(pre_filter)
352
353        cursor = (
354            self._collection.find(
355                mongo_filter,
356                {"document_data": 1, "textScore": {"$meta": "textScore"}},
357            )
358            .sort([("textScore", {"$meta": "textScore"})])
359            .limit(top_k)
360        )
361
362        try:
363            return [
364                SearchResult(
365                    document=self._from_mongo(raw),
366                    score=float(raw.get("textScore", 0.0)),
367                    rank=rank,
368                )
369                for rank, raw in enumerate(cursor)
370            ]
371        except Exception as exc:
372            msg = str(exc)
373            if "text index required" in msg.lower() or "$text" in msg:
374                raise RuntimeError(
375                    "No text index found on the collection. "
376                    "Create one before using keyword or hybrid search: "
377                    f"db.{self._collection.name}.createIndex({{ 'content': 'text' }})"
378                ) from exc
379            raise
380
381    def _hybrid_search(
382        self,
383        query: str,
384        query_embedding: List[float],
385        top_k: int,
386        filters: Optional[Dict[str, Any]],
387    ) -> List[SearchResult]:
388        """Combine vector and keyword results with equal (50/50) weighting."""
389        vector_results = self._vector_search(query_embedding, top_k, filters)
390        keyword_results = self._keyword_search(query, top_k, filters)
391
392        def _max(results: List[SearchResult]) -> float:
393            return max((r.score for r in results), default=1.0) or 1.0
394
395        v_max = _max(vector_results)
396        k_max = _max(keyword_results)
397
398        combined: Dict[str, float] = {}
399        doc_map: Dict[str, Document] = {}
400
401        for r in vector_results:
402            combined[r.document.id] = 0.5 * (r.score / v_max)
403            doc_map[r.document.id] = r.document
404
405        for r in keyword_results:
406            combined[r.document.id] = (
407                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
408            )
409            doc_map.setdefault(r.document.id, r.document)
410
411        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
412        return [
413            SearchResult(
414                document=doc_map[doc_id],
415                score=combined[doc_id],
416                rank=rank,
417            )
418            for rank, doc_id in enumerate(sorted_ids)
419        ]
420
421    # ------------------------------------------------------------------
422    # Document lifecycle
423    # ------------------------------------------------------------------
424
425    def delete_documents(self, document_ids: List[str]) -> int:
426        """
427        Delete documents by ID.
428
429        Args:
430            document_ids: IDs of the documents to remove.
431
432        Returns:
433            Number of documents actually deleted.
434        """
435        result = self._collection.delete_many({"_id": {"$in": document_ids}})
436        return result.deleted_count
437
438    def update_document(
439        self,
440        document_id: str,
441        content: Optional[str] = None,
442        embedding: Optional[List[float]] = None,
443        metadata: Optional[Dict[str, Any]] = None,
444    ) -> bool:
445        """
446        Update an existing document.
447
448        Args:
449            document_id: ID of the document to update.
450            content: New text content (optional).
451            embedding: New embedding vector (optional).
452            metadata: Metadata key-value pairs to *merge* into existing metadata
453                (optional).
454
455        Returns:
456            ``True`` if the document was found and updated, ``False`` otherwise.
457        """
458        raw = self._collection.find_one({"_id": document_id})
459        if raw is None:
460            return False
461
462        doc = self._from_mongo(raw)
463
464        if content is not None:
465            doc.content = content
466        if embedding is not None:
467            doc.embedding = embedding
468        if metadata is not None:
469            doc.metadata.update(metadata)
470
471        self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
472        return True
473
474    def get_document(self, document_id: str) -> Optional[Document]:
475        """
476        Retrieve a single document by ID.
477
478        Args:
479            document_id: The document's unique identifier.
480
481        Returns:
482            The :class:`Document` (or subclass) if found, otherwise ``None``.
483        """
484        raw = self._collection.find_one({"_id": document_id})
485        return None if raw is None else self._from_mongo(raw)
486
487    def count(self) -> int:
488        """Return the total number of documents in the collection."""
489        return self._collection.count_documents({})
490
491    def clear(self) -> None:
492        """
493        Remove **all** documents from the collection.
494
495        Warning:
496            This operation is irreversible.
497        """
498        self._collection.delete_many({})
499        logger.warning(
500            "Cleared all documents from MongoDB collection '%s'",
501            self._collection_name,
502        )
503
504    def close(self) -> None:
505        """Close the underlying MongoDB client connection."""
506        self._client.close()

MongoDB Atlas Vector Search vector store.

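Stores documents in a MongoDB collection and queries them using the $vectorSearch aggregation stage. Hybrid search is implemented by running vector and text searches separately and combining their scores with a 50 / 50 weight. Each result list is first normalised by dividing by its own highest score. A document returned by only one of the two searches therefore receives only that search's half of the combined score.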

Features:

  • Atlas Vector Search (approximate KNN; the candidate pool is fixed at numCandidates = max(top_k × 10, 100))
  • Full-text keyword search via $text index
  • Hybrid search (vector + keyword, equal weighting)
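  • Filtering on both vector search ($vectorSearch pre-filter) and keyword search. Filter keys are matched against stored fields verbatim: custom Document fields are promoted to the top level, while metadata entries live under metadata.<key>. Fields used in vector-search filters must be indexed as filter fields in the Atlas Vector Search index.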
  • Custom document_type support (any Document subclass)

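Note: The Atlas Vector Search index ($vectorSearch) must be created outside of this class — through the MongoDB Atlas UI, Atlas CLI, or Atlas Admin API. The vector_index_name parameter must match the name you gave that index. Keyword and hybrid search additionally require a text index on the content field (e.g. db.<collection>.createIndex({ "content": "text" })), which must also be created outside of this class.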

Usage::

from gmf_forge_ai_data.vector_stores import MongoDBVectorStore

store = MongoDBVectorStore(
    connection_string="mongodb+srv://...",
    database_name="rag_db",
    collection_name="documents",
    vector_index_name="vector_index",
    embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
MongoDBVectorStore( connection_string: str, database_name: str, collection_name: str, vector_index_name: str = 'vector_index', embedding_dimension: int = 1536, document_type: type = <class 'Document'>, ssl_cert_path: Optional[str] = None)
 72    def __init__(
 73        self,
 74        connection_string: str,
 75        database_name: str,
 76        collection_name: str,
 77        vector_index_name: str = "vector_index",
 78        embedding_dimension: int = 1536,
 79        document_type: type = Document,
 80        ssl_cert_path: Optional[str] = None,
 81    ) -> None:
 82        """
 83        Initialise the MongoDB Atlas vector store.
 84
 85        Args:
 86            connection_string: MongoDB Atlas connection string
 87                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
 88            database_name: Name of the MongoDB database.
 89            collection_name: Name of the collection.
 90            vector_index_name: Name of the Atlas Vector Search index defined
 91                on this collection.  Defaults to ``"vector_index"``.
 92            embedding_dimension: Dimension of vector embeddings.  Must match
 93                the dimension configured in the Atlas Vector Search index
 94                (default 1536).
 95            document_type: Document class to use when deserialising results.
 96                Pass a custom subclass to support additional typed fields.
 97            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 98                verification.  Useful in corporate environments with custom
 99                certificate authorities.
100        """
101        try:
102            import pymongo  # noqa: F401 — checked at construction time
103        except ImportError as exc:
104            raise ImportError(
105                "pymongo is required for MongoDBVectorStore. "
106                "Install it with:  pip install pymongo"
107            ) from exc
108
109        self.embedding_dimension = embedding_dimension
110        self.document_type = document_type
111        self.vector_index_name = vector_index_name
112        self._collection_name = collection_name
113
114        client_kwargs: Dict[str, Any] = {}
115        if ssl_cert_path:
116            client_kwargs["tlsCAFile"] = ssl_cert_path
117
118        import pymongo
119
120        self._client: pymongo.MongoClient = pymongo.MongoClient(
121            connection_string, **client_kwargs
122        )
123        self._db = self._client[database_name]
124        self._collection = self._db[collection_name]

Initialise the MongoDB Atlas vector store.

Args: connection_string: MongoDB Atlas connection string (e.g. mongodb+srv://user:pass@cluster.mongodb.net/). database_name: Name of the MongoDB database. collection_name: Name of the collection. vector_index_name: Name of the Atlas Vector Search index defined on this collection. Defaults to "vector_index". embedding_dimension: Dimension of vector embeddings. Must match the dimension configured in the Atlas Vector Search index (default 1536). This class does not validate it; the value is only reported in the error raised when the vector index is unavailable. document_type: Document class to use when deserialising results. Pass a custom subclass to support additional typed fields. ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS verification. Useful in corporate environments with custom certificate authorities.

embedding_dimension
document_type
vector_index_name
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
190    def add_documents(
191        self,
192        documents: List[Document],
193        generate_embeddings: bool = True,
194    ) -> List[str]:
195        """
196        Upsert documents into the MongoDB collection.
197
198        Args:
199            documents: Documents to add.  Each must have ``id`` and ``content``
200                set.  If *generate_embeddings* is ``True`` (default) every
201                document must also carry a pre-computed ``embedding``.
202            generate_embeddings: When ``True`` the method validates that all
203                documents already have embeddings rather than generating them
204                itself (generation is the caller's responsibility).
205
206        Returns:
207            List of document IDs that were successfully upserted.
208        """
209        if not documents:
210            return []
211
212        added_ids: List[str] = []
213
214        for doc in documents:
215            if not doc.id or not doc.content:
216                raise ValueError(f"Document must have id and content: {doc}")
217
218            if generate_embeddings and doc.embedding is None:
219                raise ValueError(
220                    f"Document '{doc.id}' has no embedding.  "
221                    "Generate embeddings before calling add_documents()."
222                )
223
224            mongo_doc = self._to_mongo(doc)
225            self._collection.replace_one(
226                {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
227            )
228            added_ids.append(doc.id)
229
230        logger.info(
231            "Upserted %d document(s) into MongoDB collection '%s'",
232            len(added_ids),
233            self._collection_name,
234        )
235        return added_ids

Upsert documents into the MongoDB collection.

Args: documents: Documents to add. Each must have id and content set. If generate_embeddings is True (default) every document must also carry a pre-computed embedding. generate_embeddings: When True the method validates that all documents already have embeddings rather than generating them itself (generation is the caller's responsibility). Despite its name, this flag never generates embeddings. When it is False, a document without an embedding is stored with an empty embedding list.

Returns: List of document IDs that were successfully upserted.

def search(self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Run a vector, keyword, or hybrid query against the collection.

    The request is validated first and then routed to the matching private
    helper (``_vector_search``, ``_keyword_search`` or ``_hybrid_search``).

    Args:
        query: Free-text query.  Mandatory when *search_type* is
            ``"keyword"`` or ``"hybrid"``.
        query_embedding: Pre-computed query vector.  Mandatory when
            *search_type* is ``"vector"`` or ``"hybrid"``.
        top_k: Upper bound on the number of results returned.
        filters: Optional metadata filter dict, forwarded unchanged to the
            selected helper.  Supports equality (``{"key": value}``) and
            range (``{"key": (">=", value)}``) operators.
        search_type: ``"vector"`` (default), ``"keyword"`` or ``"hybrid"``.

    Returns:
        Ranked list of :class:`SearchResult` objects.

    Raises:
        ValueError: If *search_type* is not one of the three supported
            values, or if an input it requires is ``None``.  The embedding
            requirement is checked before the query-text requirement.
    """
    # "hybrid" needs both inputs; "vector" and "keyword" need exactly one.
    uses_vector = search_type in ("vector", "hybrid")
    uses_text = search_type in ("keyword", "hybrid")

    # A value belonging to neither group is not a supported search type.
    if not (uses_vector or uses_text):
        raise ValueError(
            f"Invalid search_type '{search_type}'. "
            "Must be 'vector', 'keyword', or 'hybrid'."
        )
    if uses_vector and query_embedding is None:
        raise ValueError(f"search_type='{search_type}' requires query_embedding")
    if uses_text and query is None:
        raise ValueError(f"search_type='{search_type}' requires query text")

    if not uses_text:
        # Only "vector" needs an embedding but no text.
        return self._vector_search(query_embedding, top_k, filters)
    if not uses_vector:
        # Only "keyword" needs text but no embedding.
        return self._keyword_search(query, top_k, filters)
    return self._hybrid_search(query, query_embedding, top_k, filters)

Search the MongoDB collection.

Args: query: Plain-text query string. Required for keyword and hybrid search types. query_embedding: Pre-computed query vector. Required for vector and hybrid search types. top_k: Maximum number of results to return. filters: Optional metadata filter dict. Supports equality ({"key": value}) and range ({"key": (">=", value)}) operators. search_type: One of "vector", "keyword", or "hybrid".

Returns: Ranked list of SearchResult objects.

def delete_documents(self, document_ids: List[str]) -> int:
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID.

    Args:
        document_ids: IDs of the documents to remove.  Although annotated
            as a list, any iterable of IDs (tuple, set, generator, ...) is
            accepted.  It is materialised into a list before being sent to
            MongoDB.

    Returns:
        Number of documents actually deleted.  ``0`` when *document_ids*
        is empty, in which case no request is sent to the server.

    Raises:
        TypeError: If a single ``str``/``bytes`` is passed instead of a
            collection of IDs.
    """
    # A bare string is iterable; list("abc") would silently turn it into
    # the IDs "a", "b", "c" and delete the wrong documents.
    if isinstance(document_ids, (str, bytes)):
        raise TypeError(
            "document_ids must be a collection of IDs, not a single string"
        )

    # ``$in`` needs a BSON array; a list always encodes as one, whereas
    # e.g. sets or generators are rejected by PyMongo's BSON encoder.
    ids = list(document_ids)
    if not ids:
        # Nothing can match, so skip the server round-trip entirely.
        return 0

    result = self._collection.delete_many({"_id": {"$in": ids}})
    deleted = result.deleted_count
    logger.info(
        "Deleted %d document(s) from MongoDB collection '%s'",
        deleted,
        self._collection_name,
    )
    return deleted

Delete documents by ID.

Args: document_ids: IDs of the documents to remove.

Returns: Number of documents actually deleted.

def update_document(self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Update an existing document.

    The stored record is fetched, converted with ``_from_mongo``, modified
    in memory, converted back with ``_to_mongo`` and written with
    ``replace_one``.  Arguments left as ``None`` leave that field unchanged.

    Args:
        document_id: ID of the document to update.
        content: New text content (optional).
        embedding: New embedding vector (optional).
        metadata: Metadata key-value pairs to *merge* into existing metadata
            (optional).

    Returns:
        ``True`` if the document was found and the replacement matched a
        stored document.  ``False`` otherwise, which includes the case where
        the document was deleted between the read and the write.
    """
    raw = self._collection.find_one({"_id": document_id})
    if raw is None:
        return False

    doc = self._from_mongo(raw)

    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = embedding
    if metadata is not None:
        doc.metadata.update(metadata)

    # NOTE(review): read-modify-write over two round-trips is not atomic;
    # a concurrent writer between find_one and replace_one can be
    # overwritten.  At minimum, only report success when the replace
    # actually matched a stored document.
    result = self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
    return result.matched_count > 0

Update an existing document.

Args: document_id: ID of the document to update. content: New text content (optional). embedding: New embedding vector (optional). metadata: Metadata key-value pairs to merge into existing metadata (optional).

Returns: True if the document was found and updated, False otherwise.

def get_document(self, document_id: str) -> Optional[Document]:
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Fetch a single document by its ID.

    Args:
        document_id: The document's unique identifier, matched against the
            MongoDB ``_id`` field.

    Returns:
        The stored record converted via ``_from_mongo``, or ``None`` when
        no record has that ID.
    """
    record = self._collection.find_one({"_id": document_id})
    if record is None:
        return None
    return self._from_mongo(record)

Retrieve a single document by ID.

Args: document_id: The document's unique identifier.

Returns: The Document (or subclass) if found, otherwise None.

def count(self) -> int:
def count(self) -> int:
    """Return how many documents the collection holds (empty filter = match all)."""
    match_everything = {}
    return self._collection.count_documents(match_everything)

Return the total number of documents in the collection.

def clear(self) -> None:
def clear(self) -> None:
    """
    Delete every document in the collection.

    ``delete_many`` is called with an empty filter, which matches all
    documents.  A warning naming the collection is logged afterwards.

    Warning:
        Irreversible: the removed documents cannot be restored through
        this store.
    """
    collection = self._collection
    collection.delete_many({})
    logger.warning("Cleared all documents from MongoDB collection '%s'", self._collection_name)

Remove all documents from the collection.

Warning: This operation is irreversible.

def close(self) -> None:
def close(self) -> None:
    """Shut down the MongoDB client held in ``self._client``."""
    client = self._client
    client.close()

Close the underlying MongoDB client connection.