gmf_forge_ai_data

GMF Forge AI Data Layer Package.

This package provides data layer components for RAG applications including:

  • Embeddings (Azure OpenAI)
  • Document processing and chunking
  • Retrieval strategies
  • Query optimization
  • Context management
  • Indexing utilities
  • Data connectors

For detailed documentation, see the README.md file.

  1"""
  2GMF Forge AI Data Layer Package.
  3
  4This package provides data layer components for RAG applications including:
  5- Embeddings (Azure OpenAI)
  6- Document processing and chunking
  7- Retrieval strategies
  8- Query optimization
  9- Context management
 10- Indexing utilities
 11- Data connectors
 12
 13For detailed documentation, see the README.md file.
 14"""
 15
 16from gmf_forge_ai_data.version import __version__
 17
 18# Embeddings module is available
 19from .embeddings import (
 20    EmbeddingProvider,
 21    AzureOpenAIEmbeddings,
 22    BatchEmbeddings,
 23)
 24
 25# Chunking module is available
 26from .chunkers import (
 27    Chunk,
 28    BaseChunker,
 29    FixedSizeChunker,
 30    SemanticChunker,
 31    RecursiveChunker,
 32    SentenceChunker,
 33    MarkdownChunker,
 34    CodeChunker,
 35)
 36
 37# Vector stores module is available
 38from .vector_stores import (
 39    BaseVectorStore,
 40    Document,
 41    SearchResult,
 42    InMemoryVectorStore,
 43    AzureAISearchVectorStore,
 44    AzureCosmosDBVectorStore,
 45    MongoDBVectorStore,
 46)
 47
 48# Retrieval strategies module is available
 49from .retrieval import (
 50    BaseRetriever,
 51    RetrievalQuery,
 52    VectorRetriever,
 53    KeywordRetriever,
 54    HybridRetriever,
 55    MMRRetriever,
 56    ParentDocumentRetriever,
 57    EnsembleRetriever,
 58    HierarchicalRetriever,
 59    GraphRetriever,
 60    SQLRetriever,
 61    SQLSchema,
 62    SQLQuery,
 63    MultiIndexRetriever,
 64    SourceConfig,
 65)
 66
 67# Data connectors module is available
 68from .connectors import (
 69    BaseConnector,
 70    FilesystemConnector,
 71    SharePointConnector,
 72    BlobStorageConnector,
 73)
 74
 75# Indexing builders module is available
 76from .indexing import (
 77    BaseIndexBuilder,
 78    AzureAISearchIndexBuilder,
 79    CosmosDBIndexBuilder,
 80    MongoDBIndexBuilder,
 81)
 82
 83# Context processing module is available
 84from .context import (
 85    RelevanceFilter,
 86    ContextDeduplicator,
 87    ContextReranker,
 88    ContextCompressor,
 89    ContextWindowManager,
 90    WindowedContext,
 91)
 92
 93# Query optimization module is available
 94from .query import (
 95    QueryDecomposer,
 96    DecomposedQuery,
 97    QueryRouter,
 98    RouteDecision,
 99    QueryExpander,
100    ExpandedQuery,
101    QueryRewriter,
102    RewrittenQuery,
103    HyDEGenerator,
104    HypotheticalDocument,
105)
106
107# Document Intelligence layout module is available
108from .layout import (
109    DocumentIntelligenceLayout,
110    LayoutResult,
111)
112
113__all__ = [
114    "__version__",
115    # Embeddings
116    "EmbeddingProvider",
117    "AzureOpenAIEmbeddings",
118    "BatchEmbeddings",
119    # Chunking
120    "Chunk",
121    "BaseChunker",
122    "FixedSizeChunker",
123    "SemanticChunker",
124    "RecursiveChunker",
125    "SentenceChunker",
126    "MarkdownChunker",
127    "CodeChunker",
128    # Vector Stores
129    "BaseVectorStore",
130    "Document",
131    "SearchResult",
132    "InMemoryVectorStore",
133    "AzureAISearchVectorStore",
134    "AzureCosmosDBVectorStore",
135    "MongoDBVectorStore",
136    # Retrieval Strategies
137    "BaseRetriever",
138    "RetrievalQuery",
139    "VectorRetriever",
140    "KeywordRetriever",
141    "HybridRetriever",
142    "MMRRetriever",
143    "ParentDocumentRetriever",
144    "EnsembleRetriever",
145    "HierarchicalRetriever",
146    "GraphRetriever",
147    "SQLRetriever",
148    "SQLSchema",
149    "SQLQuery",
150    "MultiIndexRetriever",
151    "SourceConfig",
152    # Data Connectors
153    "BaseConnector",
154    "FilesystemConnector",
155    "SharePointConnector",
156    "BlobStorageConnector",
157    # Indexing Builders
158    "BaseIndexBuilder",
159    "AzureAISearchIndexBuilder",
160    "CosmosDBIndexBuilder",
161    "MongoDBIndexBuilder",
162    # Context Processing
163    "RelevanceFilter",
164    "ContextDeduplicator",
165    "ContextReranker",
166    "ContextCompressor",
167    "ContextWindowManager",
168    "WindowedContext",
169    # Query Optimization
170    "QueryDecomposer",
171    "DecomposedQuery",
172    "QueryRouter",
173    "RouteDecision",
174    "QueryExpander",
175    "ExpandedQuery",
176    "QueryRewriter",
177    "RewrittenQuery",
178    "HyDEGenerator",
179    "HypotheticalDocument",
180    # Document Intelligence Layout
181    "DocumentIntelligenceLayout",
182    "LayoutResult",
183]
__version__ = '1.0.0'
class EmbeddingProvider(abc.ABC):
 13class EmbeddingProvider(ABC):
 14    """
 15    Abstract base class for embedding providers.
 16    
 17    All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should
 18    inherit from this class and implement its abstract methods.
 19    """
 20
 21    @abstractmethod
 22    def embed_text(self, text: str) -> List[float]:
 23        """
 24        Generate embeddings for a single text string.
 25        
 26        Args:
 27            text: The input text to embed
 28            
 29        Returns:
 30            A list of floats representing the embedding vector
 31            
 32        Raises:
 33            ValueError: If text is empty or None
 34            Exception: Provider-specific errors (rate limits, auth, etc.)
 35        """
 36        pass
 37
 38    @abstractmethod
 39    def embed_batch(self, texts: List[str]) -> List[List[float]]:
 40        """
 41        Generate embeddings for multiple texts in a batch.
 42        
 43        This method should be more efficient than calling embed_text() repeatedly
 44        as it can leverage the provider's batch API capabilities.
 45        
 46        Args:
 47            texts: List of text strings to embed
 48            
 49        Returns:
 50            List of embedding vectors, one for each input text
 51            
 52        Raises:
 53            ValueError: If texts is empty or contains invalid entries
 54            Exception: Provider-specific errors (rate limits, auth, etc.)
 55        """
 56        pass
 57
 58    @abstractmethod
 59    def get_embedding_dimension(self) -> int:
 60        """
 61        Get the dimensionality of the embedding vectors.
 62        
 63        Returns:
 64            Integer dimension of the embedding space
 65            (e.g., 1536 for text-embedding-3-large)
 66        """
 67        pass
 68
 69    @abstractmethod
 70    def get_model_name(self) -> str:
 71        """
 72        Get the name/identifier of the embedding model being used.
 73        
 74        Returns:
 75            String identifier for the model (e.g., "text-embedding-3-large")
 76        """
 77        pass
 78
 79    def validate_text(self, text: str) -> None:
 80        """
 81        Validate input text before embedding.
 82        
 83        Args:
 84            text: The text to validate
 85            
 86        Raises:
 87            ValueError: If text is None, empty, or not a string
 88        """
 89        if text is None:
 90            raise ValueError("Text cannot be None")
 91        if not isinstance(text, str):
 92            raise ValueError(f"Text must be a string, got {type(text)}")
 93        if not text.strip():
 94            raise ValueError("Text cannot be empty or whitespace only")
 95
 96    def validate_batch(self, texts: List[str]) -> None:
 97        """
 98        Validate a batch of texts before embedding.
 99        
100        Args:
101            texts: List of texts to validate
102            
103        Raises:
104            ValueError: If texts is invalid or contains invalid entries
105        """
106        if texts is None:
107            raise ValueError("Texts list cannot be None")
108        if not isinstance(texts, list):
109            raise ValueError(f"Texts must be a list, got {type(texts)}")
110        if len(texts) == 0:
111            raise ValueError("Texts list cannot be empty")
112        
113        for i, text in enumerate(texts):
114            try:
115                self.validate_text(text)
116            except ValueError as e:
117                raise ValueError(f"Invalid text at index {i}: {str(e)}")

Abstract base class for embedding providers.

All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should inherit from this class and implement its abstract methods.

@abstractmethod
def embed_text(self, text: str) -> List[float]:
21    @abstractmethod
22    def embed_text(self, text: str) -> List[float]:
23        """
24        Generate embeddings for a single text string.
25        
26        Args:
27            text: The input text to embed
28            
29        Returns:
30            A list of floats representing the embedding vector
31            
32        Raises:
33            ValueError: If text is empty or None
34            Exception: Provider-specific errors (rate limits, auth, etc.)
35        """
36        pass

Generate embeddings for a single text string.

Args: text: The input text to embed

Returns: A list of floats representing the embedding vector

Raises: ValueError: If text is empty or None. Exception: Provider-specific errors (rate limits, auth, etc.).

@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
38    @abstractmethod
39    def embed_batch(self, texts: List[str]) -> List[List[float]]:
40        """
41        Generate embeddings for multiple texts in a batch.
42        
43        This method should be more efficient than calling embed_text() repeatedly
44        as it can leverage the provider's batch API capabilities.
45        
46        Args:
47            texts: List of text strings to embed
48            
49        Returns:
50            List of embedding vectors, one for each input text
51            
52        Raises:
53            ValueError: If texts is empty or contains invalid entries
54            Exception: Provider-specific errors (rate limits, auth, etc.)
55        """
56        pass

Generate embeddings for multiple texts in a batch.

This method should be more efficient than calling embed_text() repeatedly as it can leverage the provider's batch API capabilities.

Args: texts: List of text strings to embed

Returns: List of embedding vectors, one for each input text

Raises: ValueError: If texts is empty or contains invalid entries. Exception: Provider-specific errors (rate limits, auth, etc.).

@abstractmethod
def get_embedding_dimension(self) -> int:
58    @abstractmethod
59    def get_embedding_dimension(self) -> int:
60        """
61        Get the dimensionality of the embedding vectors.
62        
63        Returns:
64            Integer dimension of the embedding space
65            (e.g., 1536 for text-embedding-3-large)
66        """
67        pass

Get the dimensionality of the embedding vectors.

Returns: Integer dimension of the embedding space (e.g., 3072 for text-embedding-3-large, 1536 for text-embedding-3-small)

@abstractmethod
def get_model_name(self) -> str:
69    @abstractmethod
70    def get_model_name(self) -> str:
71        """
72        Get the name/identifier of the embedding model being used.
73        
74        Returns:
75            String identifier for the model (e.g., "text-embedding-3-large")
76        """
77        pass

Get the name/identifier of the embedding model being used.

Returns: String identifier for the model (e.g., "text-embedding-3-large")

def validate_text(self, text: str) -> None:
79    def validate_text(self, text: str) -> None:
80        """
81        Validate input text before embedding.
82        
83        Args:
84            text: The text to validate
85            
86        Raises:
87            ValueError: If text is None, empty, or not a string
88        """
89        if text is None:
90            raise ValueError("Text cannot be None")
91        if not isinstance(text, str):
92            raise ValueError(f"Text must be a string, got {type(text)}")
93        if not text.strip():
94            raise ValueError("Text cannot be empty or whitespace only")

Validate input text before embedding.

Args: text: The text to validate

Raises: ValueError: If text is None, empty, or not a string

def validate_batch(self, texts: List[str]) -> None:
 96    def validate_batch(self, texts: List[str]) -> None:
 97        """
 98        Validate a batch of texts before embedding.
 99        
100        Args:
101            texts: List of texts to validate
102            
103        Raises:
104            ValueError: If texts is invalid or contains invalid entries
105        """
106        if texts is None:
107            raise ValueError("Texts list cannot be None")
108        if not isinstance(texts, list):
109            raise ValueError(f"Texts must be a list, got {type(texts)}")
110        if len(texts) == 0:
111            raise ValueError("Texts list cannot be empty")
112        
113        for i, text in enumerate(texts):
114            try:
115                self.validate_text(text)
116            except ValueError as e:
117                raise ValueError(f"Invalid text at index {i}: {str(e)}")

Validate a batch of texts before embedding.

Args: texts: List of texts to validate

Raises: ValueError: If texts is invalid or contains invalid entries

class AzureOpenAIEmbeddings(gmf_forge_ai_data.EmbeddingProvider):
 32class AzureOpenAIEmbeddings(EmbeddingProvider):
 33    """
 34    Azure OpenAI embedding provider using text-embedding-3-large.
 35    
 36    This implementation provides:
 37    - SSL certificate handling for internal Azure endpoints
 38    - Automatic retry logic with exponential backoff
 39    - Optional logging via shared-core BasicLogger
 40    - Simple developer-friendly API
 41    
 42    Example (API key):
 43        >>> embedder = AzureOpenAIEmbeddings(
 44        ...     endpoint="https://your-resource.openai.azure.com",
 45        ...     api_key="your-api-key",
 46        ...     deployment_name="text-embedding-3-large"
 47        ... )
 48        >>> vector = embedder.embed_text("Hello world")
 49        >>> vectors = embedder.embed_batch(["Text 1", "Text 2"])
 50
 51    Example (managed identity / token provider):
 52        >>> from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 53        >>> token_provider = get_bearer_token_provider(
 54        ...     DefaultAzureCredential(),
 55        ...     "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
 56        ... )
 57        >>> embedder = AzureOpenAIEmbeddings(
 58        ...     endpoint="https://your-resource.openai.azure.com",
 59        ...     deployment_name="text-embedding-3-large",
 60        ...     token_provider=token_provider
 61        ... )
 62    """
 63
 64    def __init__(
 65        self,
 66        endpoint: str,
 67        deployment_name: str,
 68        api_key: Optional[str] = None,
 69        token_provider: Optional[Callable[[], str]] = None,
 70        model: str = "text-embedding-3-large",
 71        api_version: str = "2024-02-01",
 72        max_retries: int = 3,
 73        retry_delay: float = 1.0,
 74        ssl_cert_path: Optional[str] = None,
 75        logger: Optional['BasicLogger'] = None,
 76    ):
 77        """
 78        Initialize Azure OpenAI embeddings provider.
 79
 80        Exactly one of ``api_key`` or ``token_provider`` must be supplied.
 81
 82        Args:
 83            endpoint: Azure OpenAI endpoint URL
 84            deployment_name: Name of the deployment in Azure
 85            api_key: Azure OpenAI API key. Use for local development or
 86                when managed identity is not available.
 87            token_provider: Zero-argument callable that returns a bearer token
 88                string. Use for managed identity / workload identity scenarios.
 89                The callable must request the **Cognitive Services** scope::
 90
 91                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 92                    token_provider = get_bearer_token_provider(
 93                        DefaultAzureCredential(),
 94                        "https://cognitiveservices.azure.com/.default"
 95                    )
 96
 97                Note: this scope is different from Azure AI Search
 98                (``https://search.azure.com/.default``) — each service
 99                requires its own token_provider.
100            model: Model identifier (default: text-embedding-3-large)
101            api_version: Azure OpenAI API version (default: 2024-02-01)
102            max_retries: Maximum number of retry attempts (default: 3)
103            retry_delay: Initial delay between retries in seconds (default: 1.0)
104            ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
105                          If None, uses default SSL verification.
106            logger: Optional BasicLogger instance for observability
107        """
108        if not api_key and not token_provider:
109            raise ValueError(
110                "Either api_key or token_provider must be supplied to AzureOpenAIEmbeddings."
111            )
112        self.endpoint = endpoint
113        self.deployment_name = deployment_name
114        self.model = model
115        self.max_retries = max_retries
116        self.retry_delay = retry_delay
117        self.logger = logger
118        self.ssl_cert_path = ssl_cert_path
119        
120        # Log initialization details
121        if self.logger:
122            self.logger.info(
123                "Initializing AzureOpenAIEmbeddings",
124                endpoint=endpoint,
125                deployment=deployment_name,
126                model=model,
127                ssl_cert_provided=ssl_cert_path is not None
128            )
129        
130        # Handle SSL certificate if provided (matches shared-core pattern)
131        http_client = None
132        if ssl_cert_path and os.path.exists(ssl_cert_path):
133            try:
134                # Use httpx Client with verify parameter (same as shared-core AzureOpenAIProvider)
135                http_client = httpx.Client(verify=str(ssl_cert_path), timeout=30.0)
136                
137                if self.logger:
138                    self.logger.info(f"✓ Using SSL certificate: {ssl_cert_path}")
139            except Exception as e:
140                if self.logger:
141                    self.logger.warning(f"Failed to configure SSL certificate: {e}")
142        elif ssl_cert_path:
143            if self.logger:
144                self.logger.warning(f"SSL certificate path provided but file not found: {ssl_cert_path}")
145        else:
146            if self.logger:
147                self.logger.debug("No SSL certificate provided, using default SSL verification")
148        
149        # Initialize Azure OpenAI client
150        if token_provider:
151            self.client = AzureOpenAI(
152                azure_ad_token_provider=token_provider,
153                api_version=api_version,
154                azure_endpoint=endpoint,
155                http_client=http_client,
156            )
157        else:
158            self.client = AzureOpenAI(
159                api_key=api_key,
160                api_version=api_version,
161                azure_endpoint=endpoint,
162                http_client=http_client,
163            )
164        
165        # Dimension is detected lazily from the first API response
166        self._dimension: Optional[int] = None
167        
168        if self.logger:
169            self.logger.info(
170                f"✓ Initialized AzureOpenAIEmbeddings with model={model}, "
171                f"deployment={deployment_name}"
172            )
173
174    def _call_with_retry(self, func, *args, **kwargs):
175        """
176        Execute a function with exponential backoff retry logic.
177        
178        Args:
179            func: Function to call
180            *args: Positional arguments for the function
181            **kwargs: Keyword arguments for the function
182            
183        Returns:
184            Result of the function call
185            
186        Raises:
187            Exception: If all retries are exhausted
188        """
189        last_exception = None
190        delay = self.retry_delay
191        
192        for attempt in range(self.max_retries):
193            try:
194                return func(*args, **kwargs)
195            except Exception as e:
196                last_exception = e
197                error_type = type(e).__name__
198                error_msg = str(e)
199                
200                if attempt < self.max_retries - 1:
201                    if self.logger:
202                        self.logger.warning(
203                            f"Attempt {attempt + 1} failed",
204                            error_type=error_type,
205                            error_message=error_msg,
206                            retry_in=f"{delay}s"
207                        )
208                    time.sleep(delay)
209                    delay *= 2  # Exponential backoff
210                else:
211                    if self.logger:
212                        self.logger.error(
213                            f"All {self.max_retries} attempts failed",
214                            error_type=error_type,
215                            error_message=error_msg
216                        )
217        
218        raise last_exception
219
220    def embed_text(self, text: str) -> List[float]:
221        """
222        Generate embeddings for a single text string.
223        
224        Args:
225            text: The input text to embed
226            
227        Returns:
228            A list of floats representing the embedding vector
229            
230        Raises:
231            ValueError: If text is invalid
232            Exception: Azure OpenAI API errors
233        """
234        self.validate_text(text)
235        
236        if self.logger:
237            self.logger.debug(f"Embedding single text (length={len(text)})")
238        
239        def _embed():
240            response: CreateEmbeddingResponse = self.client.embeddings.create(
241                input=text,
242                model=self.deployment_name,
243            )
244            return response.data[0].embedding
245        
246        embedding = self._call_with_retry(_embed)
247
248        if self._dimension is None:
249            self._dimension = len(embedding)
250
251        if self.logger:
252            self.logger.debug(f"Generated embedding with dimension={len(embedding)}")
253
254        return embedding
255
256    def embed_batch(self, texts: List[str]) -> List[List[float]]:
257        """
258        Generate embeddings for multiple texts in a batch.
259        
260        Azure OpenAI supports batch requests which are more efficient than
261        individual calls. This method automatically handles the batch API.
262        
263        Args:
264            texts: List of text strings to embed
265            
266        Returns:
267            List of embedding vectors, one for each input text
268            
269        Raises:
270            ValueError: If texts is invalid
271            Exception: Azure OpenAI API errors
272        """
273        self.validate_batch(texts)
274        
275        if self.logger:
276            self.logger.info(f"Embedding batch of {len(texts)} texts")
277        
278        def _embed():
279            response: CreateEmbeddingResponse = self.client.embeddings.create(
280                input=texts,
281                model=self.deployment_name,
282            )
283            # Ensure results are in the correct order
284            sorted_data = sorted(response.data, key=lambda x: x.index)
285            return [item.embedding for item in sorted_data]
286        
287        embeddings = self._call_with_retry(_embed)
288
289        if self._dimension is None and embeddings:
290            self._dimension = len(embeddings[0])
291
292        if self.logger:
293            self.logger.info(
294                f"Generated {len(embeddings)} embeddings successfully"
295            )
296
297        return embeddings
298
299    def get_embedding_dimension(self) -> Optional[int]:
300        """
301        Get the dimensionality of the embedding vectors.
302
303        Returns the actual dimension observed from the first API response.
304        Returns None if no embedding has been generated yet.
305        """
306        return self._dimension
307
308    def get_model_name(self) -> str:
309        """
310        Get the name of the embedding model.
311        
312        Returns:
313            The model identifier
314        """
315        return self.model

Azure OpenAI embedding provider using text-embedding-3-large.

This implementation provides:

  • SSL certificate handling for internal Azure endpoints
  • Automatic retry logic with exponential backoff
  • Optional logging via shared-core BasicLogger
  • Simple developer-friendly API

Example (API key):

>>> embedder = AzureOpenAIEmbeddings(
...     endpoint="https://your-resource.openai.azure.com",
...     api_key="your-api-key",
...     deployment_name="text-embedding-3-large"
... )
>>> vector = embedder.embed_text("Hello world")
>>> vectors = embedder.embed_batch(["Text 1", "Text 2"])

Example (managed identity / token provider):

>>> from azure.identity import DefaultAzureCredential, get_bearer_token_provider
>>> token_provider = get_bearer_token_provider(
...     DefaultAzureCredential(),
...     "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
... )
>>> embedder = AzureOpenAIEmbeddings(
...     endpoint="https://your-resource.openai.azure.com",
...     deployment_name="text-embedding-3-large",
...     token_provider=token_provider
... )

AzureOpenAIEmbeddings( endpoint: str, deployment_name: str, api_key: Optional[str] = None, token_provider: Optional[Callable[[], str]] = None, model: str = 'text-embedding-3-large', api_version: str = '2024-02-01', max_retries: int = 3, retry_delay: float = 1.0, ssl_cert_path: Optional[str] = None, logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None)
 64    def __init__(
 65        self,
 66        endpoint: str,
 67        deployment_name: str,
 68        api_key: Optional[str] = None,
 69        token_provider: Optional[Callable[[], str]] = None,
 70        model: str = "text-embedding-3-large",
 71        api_version: str = "2024-02-01",
 72        max_retries: int = 3,
 73        retry_delay: float = 1.0,
 74        ssl_cert_path: Optional[str] = None,
 75        logger: Optional['BasicLogger'] = None,
 76    ):
 77        """
 78        Initialize Azure OpenAI embeddings provider.
 79
 80        Exactly one of ``api_key`` or ``token_provider`` must be supplied.
 81
 82        Args:
 83            endpoint: Azure OpenAI endpoint URL
 84            deployment_name: Name of the deployment in Azure
 85            api_key: Azure OpenAI API key. Use for local development or
 86                when managed identity is not available.
 87            token_provider: Zero-argument callable that returns a bearer token
 88                string. Use for managed identity / workload identity scenarios.
 89                The callable must request the **Cognitive Services** scope::
 90
 91                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 92                    token_provider = get_bearer_token_provider(
 93                        DefaultAzureCredential(),
 94                        "https://cognitiveservices.azure.com/.default"
 95                    )
 96
 97                Note: this scope is different from Azure AI Search
 98                (``https://search.azure.com/.default``) — each service
 99                requires its own token_provider.
100            model: Model identifier (default: text-embedding-3-large)
101            api_version: Azure OpenAI API version (default: 2024-02-01)
102            max_retries: Maximum number of retry attempts (default: 3)
103            retry_delay: Initial delay between retries in seconds (default: 1.0)
104            ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
105                          If None, uses default SSL verification.
106            logger: Optional BasicLogger instance for observability
107        """
108        if not api_key and not token_provider:
109            raise ValueError(
110                "Either api_key or token_provider must be supplied to AzureOpenAIEmbeddings."
111            )
112        self.endpoint = endpoint
113        self.deployment_name = deployment_name
114        self.model = model
115        self.max_retries = max_retries
116        self.retry_delay = retry_delay
117        self.logger = logger
118        self.ssl_cert_path = ssl_cert_path
119        
120        # Log initialization details
121        if self.logger:
122            self.logger.info(
123                "Initializing AzureOpenAIEmbeddings",
124                endpoint=endpoint,
125                deployment=deployment_name,
126                model=model,
127                ssl_cert_provided=ssl_cert_path is not None
128            )
129        
130        # Handle SSL certificate if provided (matches shared-core pattern)
131        http_client = None
132        if ssl_cert_path and os.path.exists(ssl_cert_path):
133            try:
134                # Use httpx Client with verify parameter (same as shared-core AzureOpenAIProvider)
135                http_client = httpx.Client(verify=str(ssl_cert_path), timeout=30.0)
136                
137                if self.logger:
138                    self.logger.info(f"✓ Using SSL certificate: {ssl_cert_path}")
139            except Exception as e:
140                if self.logger:
141                    self.logger.warning(f"Failed to configure SSL certificate: {e}")
142        elif ssl_cert_path:
143            if self.logger:
144                self.logger.warning(f"SSL certificate path provided but file not found: {ssl_cert_path}")
145        else:
146            if self.logger:
147                self.logger.debug("No SSL certificate provided, using default SSL verification")
148        
149        # Initialize Azure OpenAI client
150        if token_provider:
151            self.client = AzureOpenAI(
152                azure_ad_token_provider=token_provider,
153                api_version=api_version,
154                azure_endpoint=endpoint,
155                http_client=http_client,
156            )
157        else:
158            self.client = AzureOpenAI(
159                api_key=api_key,
160                api_version=api_version,
161                azure_endpoint=endpoint,
162                http_client=http_client,
163            )
164        
165        # Dimension is detected lazily from the first API response
166        self._dimension: Optional[int] = None
167        
168        if self.logger:
169            self.logger.info(
170                f"✓ Initialized AzureOpenAIEmbeddings with model={model}, "
171                f"deployment={deployment_name}"
172            )

Initialize Azure OpenAI embeddings provider.

At least one of api_key or token_provider must be supplied. If both are supplied, token_provider is used and api_key is ignored.

Args:
endpoint: Azure OpenAI endpoint URL
deployment_name: Name of the deployment in Azure
api_key: Azure OpenAI API key. Use for local development or when managed identity is not available.
token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Cognitive Services scope::

        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://cognitiveservices.azure.com/.default"
        )

    Note: this scope is different from Azure AI Search
    (``https://search.azure.com/.default``) — each service
    requires its own token_provider.
model: Model identifier (default: text-embedding-3-large)
api_version: Azure OpenAI API version (default: 2024-02-01)
max_retries: Maximum number of retry attempts (default: 3)
retry_delay: Initial delay between retries in seconds (default: 1.0)
ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
              If None, uses default SSL verification.
logger: Optional BasicLogger instance for observability
endpoint
deployment_name
model
max_retries
retry_delay
logger
ssl_cert_path
def embed_text(self, text: str) -> List[float]:
220    def embed_text(self, text: str) -> List[float]:
221        """
222        Generate embeddings for a single text string.
223        
224        Args:
225            text: The input text to embed
226            
227        Returns:
228            A list of floats representing the embedding vector
229            
230        Raises:
231            ValueError: If text is invalid
232            Exception: Azure OpenAI API errors
233        """
234        self.validate_text(text)
235        
236        if self.logger:
237            self.logger.debug(f"Embedding single text (length={len(text)})")
238        
239        def _embed():
240            response: CreateEmbeddingResponse = self.client.embeddings.create(
241                input=text,
242                model=self.deployment_name,
243            )
244            return response.data[0].embedding
245        
246        embedding = self._call_with_retry(_embed)
247
248        if self._dimension is None:
249            self._dimension = len(embedding)
250
251        if self.logger:
252            self.logger.debug(f"Generated embedding with dimension={len(embedding)}")
253
254        return embedding

Generate embeddings for a single text string.

Args: text: The input text to embed

Returns: A list of floats representing the embedding vector

Raises:
ValueError: If text is invalid.
Exception: Azure OpenAI API errors.

def embed_batch(self, texts: List[str]) -> List[List[float]]:
256    def embed_batch(self, texts: List[str]) -> List[List[float]]:
257        """
258        Generate embeddings for multiple texts in a batch.
259        
260        Azure OpenAI supports batch requests which are more efficient than
261        individual calls. This method automatically handles the batch API.
262        
263        Args:
264            texts: List of text strings to embed
265            
266        Returns:
267            List of embedding vectors, one for each input text
268            
269        Raises:
270            ValueError: If texts is invalid
271            Exception: Azure OpenAI API errors
272        """
273        self.validate_batch(texts)
274        
275        if self.logger:
276            self.logger.info(f"Embedding batch of {len(texts)} texts")
277        
278        def _embed():
279            response: CreateEmbeddingResponse = self.client.embeddings.create(
280                input=texts,
281                model=self.deployment_name,
282            )
283            # Ensure results are in the correct order
284            sorted_data = sorted(response.data, key=lambda x: x.index)
285            return [item.embedding for item in sorted_data]
286        
287        embeddings = self._call_with_retry(_embed)
288
289        if self._dimension is None and embeddings:
290            self._dimension = len(embeddings[0])
291
292        if self.logger:
293            self.logger.info(
294                f"Generated {len(embeddings)} embeddings successfully"
295            )
296
297        return embeddings

Generate embeddings for multiple texts in a batch.

Azure OpenAI supports batch requests which are more efficient than individual calls. This method automatically handles the batch API.

Args: texts: List of text strings to embed

Returns: List of embedding vectors, one for each input text

Raises:
ValueError: If texts is invalid.
Exception: Azure OpenAI API errors.

def get_embedding_dimension(self) -> Optional[int]:
299    def get_embedding_dimension(self) -> Optional[int]:
300        """
301        Get the dimensionality of the embedding vectors.
302
303        Returns the actual dimension observed from the first API response.
304        Returns None if no embedding has been generated yet.
305        """
306        return self._dimension

Get the dimensionality of the embedding vectors.

Returns the actual dimension observed from the first API response. Returns None if no embedding has been generated yet.

def get_model_name(self) -> str:
308    def get_model_name(self) -> str:
309        """
310        Get the name of the embedding model.
311        
312        Returns:
313            The model identifier
314        """
315        return self.model

Get the name of the embedding model.

Returns: The model identifier

class BatchEmbeddings(gmf_forge_ai_data.EmbeddingProvider):
 25class BatchEmbeddings(EmbeddingProvider):
 26    """
 27    Wrapper for efficient batch processing of embeddings.
 28    
 29    This class wraps any EmbeddingProvider and adds:
 30    - Automatic batching with configurable batch size
 31    - Progress tracking for large document collections
 32    - Retry logic per batch
 33    - Memory-efficient processing
 34    
 35    Example:
 36        >>> base_embedder = AzureOpenAIEmbeddings(...)
 37        >>> batch_embedder = BatchEmbeddings(
 38        ...     provider=base_embedder,
 39        ...     batch_size=100,
 40        ...     show_progress=True
 41        ... )
 42        >>> # Process 10,000 documents efficiently
 43        >>> embeddings = batch_embedder.embed_batch(large_text_list)
 44    """
 45
 46    def __init__(
 47        self,
 48        provider: EmbeddingProvider,
 49        batch_size: int = 100,
 50        show_progress: bool = True,
 51        progress_callback: Optional[Callable[[int, int], None]] = None,
 52        logger: Optional['BasicLogger'] = None,
 53    ):
 54        """
 55        Initialize batch embedding wrapper.
 56        
 57        Args:
 58            provider: The underlying EmbeddingProvider to wrap
 59            batch_size: Number of texts to process per batch (default: 100)
 60            show_progress: Whether to print progress messages (default: True)
 61            progress_callback: Optional callback function(current, total) for
 62                             custom progress tracking
 63            logger: Optional BasicLogger instance for observability
 64        """
 65        self.provider = provider
 66        self.batch_size = batch_size
 67        self.show_progress = show_progress
 68        self.progress_callback = progress_callback
 69        self.logger = logger
 70        
 71        if self.logger:
 72            self.logger.info(
 73                f"Initialized BatchEmbeddings wrapper with batch_size={batch_size}"
 74            )
 75
 76    def embed_text(self, text: str) -> List[float]:
 77        """
 78        Generate embeddings for a single text (delegates to wrapped provider).
 79        
 80        Args:
 81            text: The input text to embed
 82            
 83        Returns:
 84            A list of floats representing the embedding vector
 85        """
 86        return self.provider.embed_text(text)
 87
 88    def embed_batch(self, texts: List[str]) -> List[List[float]]:
 89        """
 90        Generate embeddings for multiple texts with automatic batching.
 91        
 92        This method splits large text lists into smaller batches and processes
 93        them sequentially to avoid hitting provider limits and manage memory.
 94        
 95        Args:
 96            texts: List of text strings to embed
 97            
 98        Returns:
 99            List of embedding vectors, one for each input text
100            
101        Raises:
102            ValueError: If texts is invalid
103            Exception: Provider-specific errors
104        """
105        self.validate_batch(texts)
106        
107        total_texts = len(texts)
108        
109        if self.logger:
110            self.logger.info(
111                f"Processing {total_texts} texts in batches of {self.batch_size}"
112            )
113        
114        # If batch size is larger than input, just use the provider directly
115        if total_texts <= self.batch_size:
116            if self.logger:
117                self.logger.debug("Batch size larger than input, processing in one call")
118            return self.provider.embed_batch(texts)
119        
120        # Process in batches
121        all_embeddings = []
122        num_batches = (total_texts + self.batch_size - 1) // self.batch_size
123        
124        for batch_idx in range(num_batches):
125            start_idx = batch_idx * self.batch_size
126            end_idx = min(start_idx + self.batch_size, total_texts)
127            batch_texts = texts[start_idx:end_idx]
128            
129            if self.show_progress and self.logger:
130                progress = (batch_idx + 1) / num_batches * 100
131                self.logger.info(
132                    f"Processing batch {batch_idx + 1}/{num_batches}",
133                    progress_pct=round(progress, 1),
134                    texts_done=end_idx,
135                    texts_total=total_texts,
136                )
137            
138            if self.progress_callback:
139                self.progress_callback(end_idx, total_texts)
140            
141            if self.logger:
142                self.logger.debug(
143                    f"Processing batch {batch_idx + 1}/{num_batches}: "
144                    f"texts[{start_idx}:{end_idx}]"
145                )
146            
147            # Embed the batch
148            batch_embeddings = self.provider.embed_batch(batch_texts)
149            all_embeddings.extend(batch_embeddings)
150        
151        if self.show_progress and self.logger:
152            self.logger.info(f"Completed embedding {total_texts} texts")
153        
154        if self.logger:
155            self.logger.info(
156                f"Successfully processed all {total_texts} texts in {num_batches} batches"
157            )
158        
159        return all_embeddings
160
161    def get_embedding_dimension(self) -> int:
162        """
163        Get the dimensionality of the embedding vectors.
164        
165        Returns:
166            The dimension from the wrapped provider
167        """
168        return self.provider.get_embedding_dimension()
169
170    def get_model_name(self) -> str:
171        """
172        Get the name of the embedding model.
173        
174        Returns:
175            The model name from the wrapped provider
176        """
177        return self.provider.get_model_name()
178
179    def set_batch_size(self, batch_size: int) -> None:
180        """
181        Update the batch size for processing.
182        
183        Args:
184            batch_size: New batch size to use
185            
186        Raises:
187            ValueError: If batch_size is not positive
188        """
189        if batch_size <= 0:
190            raise ValueError(f"Batch size must be positive, got {batch_size}")
191        
192        self.batch_size = batch_size
193        
194        if self.logger:
195            self.logger.info(f"Updated batch size to {batch_size}")
196
197    def get_batch_size(self) -> int:
198        """
199        Get the current batch size.
200        
201        Returns:
202            Current batch size setting
203        """
204        return self.batch_size

Wrapper for efficient batch processing of embeddings.

This class wraps any EmbeddingProvider and adds:

  • Automatic batching with configurable batch size
  • Progress tracking for large document collections
  • Per-batch calls to the wrapped provider (any retry logic is the wrapped provider's own; this wrapper does not retry)
  • Memory-efficient processing

Example:

    >>> base_embedder = AzureOpenAIEmbeddings(...)
    >>> batch_embedder = BatchEmbeddings(
    ...     provider=base_embedder,
    ...     batch_size=100,
    ...     show_progress=True
    ... )
    >>> # Process 10,000 documents efficiently
    >>> embeddings = batch_embedder.embed_batch(large_text_list)

BatchEmbeddings( provider: EmbeddingProvider, batch_size: int = 100, show_progress: bool = True, progress_callback: Optional[Callable[[int, int], NoneType]] = None, logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None)
46    def __init__(
47        self,
48        provider: EmbeddingProvider,
49        batch_size: int = 100,
50        show_progress: bool = True,
51        progress_callback: Optional[Callable[[int, int], None]] = None,
52        logger: Optional['BasicLogger'] = None,
53    ):
54        """
55        Initialize batch embedding wrapper.
56        
57        Args:
58            provider: The underlying EmbeddingProvider to wrap
59            batch_size: Number of texts to process per batch (default: 100)
60            show_progress: Whether to print progress messages (default: True)
61            progress_callback: Optional callback function(current, total) for
62                             custom progress tracking
63            logger: Optional BasicLogger instance for observability
64        """
65        self.provider = provider
66        self.batch_size = batch_size
67        self.show_progress = show_progress
68        self.progress_callback = progress_callback
69        self.logger = logger
70        
71        if self.logger:
72            self.logger.info(
73                f"Initialized BatchEmbeddings wrapper with batch_size={batch_size}"
74            )

Initialize batch embedding wrapper.

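Args:
provider: The underlying EmbeddingProvider to wrap.
batch_size: Number of texts to process per batch (default: 100).
show_progress: Whether to emit per-batch progress messages (default: True). Progress is written through logger only, so nothing is emitted when no logger is supplied.
progress_callback: Optional callback function(current, total) for custom progress tracking. When the input is split into multiple batches, it is called just before each batch is sent. The arguments are the end index of that batch and the total number of texts.
logger: Optional BasicLogger instance for observability.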

provider
batch_size
show_progress
progress_callback
logger
def embed_text(self, text: str) -> List[float]:
76    def embed_text(self, text: str) -> List[float]:
77        """
78        Generate embeddings for a single text (delegates to wrapped provider).
79        
80        Args:
81            text: The input text to embed
82            
83        Returns:
84            A list of floats representing the embedding vector
85        """
86        return self.provider.embed_text(text)

Generate embeddings for a single text (delegates to wrapped provider).

Args: text: The input text to embed

Returns: A list of floats representing the embedding vector

def embed_batch(self, texts: List[str]) -> List[List[float]]:
 88    def embed_batch(self, texts: List[str]) -> List[List[float]]:
 89        """
 90        Generate embeddings for multiple texts with automatic batching.
 91        
 92        This method splits large text lists into smaller batches and processes
 93        them sequentially to avoid hitting provider limits and manage memory.
 94        
 95        Args:
 96            texts: List of text strings to embed
 97            
 98        Returns:
 99            List of embedding vectors, one for each input text
100            
101        Raises:
102            ValueError: If texts is invalid
103            Exception: Provider-specific errors
104        """
105        self.validate_batch(texts)
106        
107        total_texts = len(texts)
108        
109        if self.logger:
110            self.logger.info(
111                f"Processing {total_texts} texts in batches of {self.batch_size}"
112            )
113        
114        # If batch size is larger than input, just use the provider directly
115        if total_texts <= self.batch_size:
116            if self.logger:
117                self.logger.debug("Batch size larger than input, processing in one call")
118            return self.provider.embed_batch(texts)
119        
120        # Process in batches
121        all_embeddings = []
122        num_batches = (total_texts + self.batch_size - 1) // self.batch_size
123        
124        for batch_idx in range(num_batches):
125            start_idx = batch_idx * self.batch_size
126            end_idx = min(start_idx + self.batch_size, total_texts)
127            batch_texts = texts[start_idx:end_idx]
128            
129            if self.show_progress and self.logger:
130                progress = (batch_idx + 1) / num_batches * 100
131                self.logger.info(
132                    f"Processing batch {batch_idx + 1}/{num_batches}",
133                    progress_pct=round(progress, 1),
134                    texts_done=end_idx,
135                    texts_total=total_texts,
136                )
137            
138            if self.progress_callback:
139                self.progress_callback(end_idx, total_texts)
140            
141            if self.logger:
142                self.logger.debug(
143                    f"Processing batch {batch_idx + 1}/{num_batches}: "
144                    f"texts[{start_idx}:{end_idx}]"
145                )
146            
147            # Embed the batch
148            batch_embeddings = self.provider.embed_batch(batch_texts)
149            all_embeddings.extend(batch_embeddings)
150        
151        if self.show_progress and self.logger:
152            self.logger.info(f"Completed embedding {total_texts} texts")
153        
154        if self.logger:
155            self.logger.info(
156                f"Successfully processed all {total_texts} texts in {num_batches} batches"
157            )
158        
159        return all_embeddings

Generate embeddings for multiple texts with automatic batching.

This method splits large text lists into smaller batches and processes them sequentially to avoid hitting provider limits and manage memory.

Args: texts: List of text strings to embed

Returns: List of embedding vectors, one for each input text

Raises:
ValueError: If texts is invalid.
Exception: Provider-specific errors.

def get_embedding_dimension(self) -> int:
161    def get_embedding_dimension(self) -> int:
162        """
163        Get the dimensionality of the embedding vectors.
164        
165        Returns:
166            The dimension from the wrapped provider
167        """
168        return self.provider.get_embedding_dimension()

Get the dimensionality of the embedding vectors.

Returns: The dimension from the wrapped provider

def get_model_name(self) -> str:
170    def get_model_name(self) -> str:
171        """
172        Get the name of the embedding model.
173        
174        Returns:
175            The model name from the wrapped provider
176        """
177        return self.provider.get_model_name()

Get the name of the embedding model.

Returns: The model name from the wrapped provider

def set_batch_size(self, batch_size: int) -> None:
179    def set_batch_size(self, batch_size: int) -> None:
180        """
181        Update the batch size for processing.
182        
183        Args:
184            batch_size: New batch size to use
185            
186        Raises:
187            ValueError: If batch_size is not positive
188        """
189        if batch_size <= 0:
190            raise ValueError(f"Batch size must be positive, got {batch_size}")
191        
192        self.batch_size = batch_size
193        
194        if self.logger:
195            self.logger.info(f"Updated batch size to {batch_size}")

Update the batch size for processing.

Args: batch_size: New batch size to use

Raises: ValueError: If batch_size is not positive

def get_batch_size(self) -> int:
197    def get_batch_size(self) -> int:
198        """
199        Get the current batch size.
200        
201        Returns:
202            Current batch size setting
203        """
204        return self.batch_size

Get the current batch size.

Returns: Current batch size setting

@dataclass
class Chunk:
14@dataclass
15class Chunk:
16    """
17    Represents a chunk of text with associated metadata.
18    
19    Attributes:
20        text: The actual text content of the chunk
21        metadata: Dictionary of metadata (source, page number, etc.)
22        start_pos: Character position where chunk starts in original text
23        end_pos: Character position where chunk ends in original text
24        chunk_id: Unique identifier for the chunk
25    """
26    text: str
27    metadata: Dict[str, Any] = field(default_factory=dict)
28    start_pos: int = 0
29    end_pos: int = 0
30    chunk_id: str = ""
31    
32    def __post_init__(self):
33        """Validate chunk after initialization."""
34        if not self.text:
35            raise ValueError("Chunk text cannot be empty")
36        if self.end_pos < self.start_pos:
37            raise ValueError(f"end_pos ({self.end_pos}) cannot be less than start_pos ({self.start_pos})")
38    
39    def __len__(self) -> int:
40        """Return the length of the chunk text."""
41        return len(self.text)
42    
43    def __str__(self) -> str:
44        """Return a string representation of the chunk."""
45        preview = self.text[:50] + "..." if len(self.text) > 50 else self.text
46        return f"Chunk(id={self.chunk_id}, len={len(self.text)}, text='{preview}')"

Represents a chunk of text with associated metadata.

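Attributes:
text: The actual text content of the chunk (must be non-empty).
metadata: Dictionary of metadata (source, page number, etc.).
start_pos: Character position where the chunk starts in the original text.
end_pos: Character position where the chunk ends in the original text (must not be less than start_pos).
chunk_id: Unique identifier for the chunk.

Constructing a Chunk with empty text, or with end_pos less than start_pos, raises ValueError.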

Chunk( text: str, metadata: Dict[str, Any] = <factory>, start_pos: int = 0, end_pos: int = 0, chunk_id: str = '')
text: str
metadata: Dict[str, Any]
start_pos: int = 0
end_pos: int = 0
chunk_id: str = ''
class BaseChunker(abc.ABC):
 49class BaseChunker(ABC):
 50    """
 51    Abstract base class for all text chunking strategies.
 52    
 53    All chunking implementations must inherit from this class and implement
 54    the chunk() method.
 55    """
 56    
 57    def __init__(self, logger=None):
 58        """
 59        Initialize the base chunker.
 60        
 61        Args:
 62            logger: Optional BasicLogger instance for structured logging
 63        """
 64        self.logger = logger
 65    
 66    @abstractmethod
 67    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 68        """
 69        Split text into chunks according to the chunking strategy.
 70        
 71        Args:
 72            text: The input text to chunk
 73            metadata: Optional metadata to attach to each chunk
 74            
 75        Returns:
 76            List of Chunk objects
 77            
 78        Raises:
 79            ValueError: If text is empty or None
 80        """
 81        pass
 82    
 83    def validate_text(self, text: str) -> None:
 84        """
 85        Validate input text before chunking.
 86        
 87        Args:
 88            text: The text to validate
 89            
 90        Raises:
 91            ValueError: If text is None, empty, or not a string
 92        """
 93        if text is None:
 94            raise ValueError("Text cannot be None")
 95        if not isinstance(text, str):
 96            raise ValueError(f"Text must be a string, got {type(text)}")
 97        if not text.strip():
 98            raise ValueError("Text cannot be empty or whitespace only")
 99    
100    def _generate_chunk_id(self, index: int, metadata: Optional[Dict[str, Any]] = None) -> str:
101        """
102        Generate a unique chunk ID.
103        
104        Args:
105            index: The index of the chunk in the sequence
106            metadata: Optional metadata that may contain source information
107            
108        Returns:
109            A unique chunk identifier
110        """
111        if metadata and "source" in metadata:
112            return f"{metadata['source']}_chunk_{index}"
113        return f"chunk_{index}"

Abstract base class for all text chunking strategies.

All chunking implementations must inherit from this class and implement the chunk() method.

BaseChunker(logger=None)
57    def __init__(self, logger=None):
58        """
59        Initialize the base chunker.
60        
61        Args:
62            logger: Optional BasicLogger instance for structured logging
63        """
64        self.logger = logger

Initialize the base chunker.

Args: logger: Optional BasicLogger instance for structured logging

logger
@abstractmethod
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
66    @abstractmethod
67    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
68        """
69        Split text into chunks according to the chunking strategy.
70        
71        Args:
72            text: The input text to chunk
73            metadata: Optional metadata to attach to each chunk
74            
75        Returns:
76            List of Chunk objects
77            
78        Raises:
79            ValueError: If text is empty or None
80        """
81        pass

Split text into chunks according to the chunking strategy.

Args:
text: The input text to chunk.
metadata: Optional metadata to attach to each chunk.

Returns: List of Chunk objects

Raises: ValueError: If text is empty or None

def validate_text(self, text: str) -> None:
83    def validate_text(self, text: str) -> None:
84        """
85        Validate input text before chunking.
86        
87        Args:
88            text: The text to validate
89            
90        Raises:
91            ValueError: If text is None, empty, or not a string
92        """
93        if text is None:
94            raise ValueError("Text cannot be None")
95        if not isinstance(text, str):
96            raise ValueError(f"Text must be a string, got {type(text)}")
97        if not text.strip():
98            raise ValueError("Text cannot be empty or whitespace only")

Validate input text before chunking.

Args: text: The text to validate

Raises: ValueError: If text is None, empty, or not a string

class FixedSizeChunker(gmf_forge_ai_data.BaseChunker):
 15class FixedSizeChunker(BaseChunker):
 16    """
 17    Chunks text into fixed-size token-based segments with optional overlap.
 18    
 19    This is one of the most common chunking strategies for LLM applications,
 20    ensuring chunks don't exceed model context windows or embedding limits.
 21    """
 22    
 23    def __init__(
 24        self,
 25        chunk_size: int = 512,
 26        chunk_overlap: int = 50,
 27        encoding_name: str = "cl100k_base",
 28        logger=None
 29    ):
 30        """
 31        Initialize the fixed-size chunker.
 32        
 33        Args:
 34            chunk_size: Maximum number of tokens per chunk (default: 512)
 35            chunk_overlap: Number of tokens to overlap between chunks (default: 50)
 36            encoding_name: tiktoken encoding name (default: "cl100k_base" for GPT-3.5/4)
 37                          Options: "cl100k_base" (GPT-3.5, GPT-4)
 38                                   "p50k_base" (CodeX, GPT-3)
 39                                   "r50k_base" (GPT-2)
 40            logger: Optional BasicLogger instance
 41            
 42        Raises:
 43            ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size
 44        """
 45        super().__init__(logger)
 46        
 47        if chunk_size <= 0:
 48            raise ValueError("chunk_size must be greater than 0")
 49        if chunk_overlap < 0:
 50            raise ValueError("chunk_overlap must be non-negative")
 51        if chunk_overlap >= chunk_size:
 52            raise ValueError("chunk_overlap must be less than chunk_size")
 53        
 54        self.chunk_size = chunk_size
 55        self.chunk_overlap = chunk_overlap
 56        self.encoding_name = encoding_name
 57        
 58        try:
 59            self.encoding = tiktoken.get_encoding(encoding_name)
 60        except Exception as e:
 61            raise ValueError(f"Failed to load tiktoken encoding '{encoding_name}': {e}")
 62        
 63        if self.logger:
 64            self.logger.info(
 65                "Initialized FixedSizeChunker",
 66                chunk_size=chunk_size,
 67                chunk_overlap=chunk_overlap,
 68                encoding=encoding_name
 69            )
 70    
 71    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 72        """
 73        Split text into fixed-size token-based chunks with overlap.
 74        
 75        Args:
 76            text: The input text to chunk
 77            metadata: Optional metadata to attach to each chunk
 78            
 79        Returns:
 80            List of Chunk objects
 81            
 82        Raises:
 83            ValueError: If text is invalid
 84        """
 85        self.validate_text(text)
 86        
 87        if metadata is None:
 88            metadata = {}
 89        
 90        if self.logger:
 91            self.logger.debug(f"Chunking text of length {len(text)} characters")
 92        
 93        # Encode text to tokens
 94        tokens = self.encoding.encode(text)
 95        total_tokens = len(tokens)
 96        
 97        if self.logger:
 98            self.logger.debug(f"Text tokenized to {total_tokens} tokens")
 99        
100        chunks = []
101        start_idx = 0
102        chunk_index = 0
103        
104        while start_idx < total_tokens:
105            # Calculate end index for this chunk
106            end_idx = min(start_idx + self.chunk_size, total_tokens)
107            
108            # Extract token slice
109            chunk_tokens = tokens[start_idx:end_idx]
110            
111            # Decode tokens back to text
112            chunk_text = self.encoding.decode(chunk_tokens)
113            
114            # Find character positions in original text
115            # This is approximate since token boundaries don't always align with char boundaries
116            char_start = len(self.encoding.decode(tokens[:start_idx]))
117            char_end = char_start + len(chunk_text)
118            
119            # Create chunk
120            chunk = Chunk(
121                text=chunk_text,
122                metadata=metadata.copy(),
123                start_pos=char_start,
124                end_pos=char_end,
125                chunk_id=self._generate_chunk_id(chunk_index, metadata)
126            )
127            
128            # Add token count to metadata
129            chunk.metadata["token_count"] = len(chunk_tokens)
130            chunk.metadata["chunking_strategy"] = "fixed_size"
131            
132            chunks.append(chunk)
133            chunk_index += 1
134            
135            # Move start index forward, accounting for overlap
136            start_idx += self.chunk_size - self.chunk_overlap
137            
138            # Break if we've reached the end
139            if end_idx >= total_tokens:
140                break
141        
142        if self.logger:
143            self.logger.info(
144                f"Created {len(chunks)} chunks",
145                total_tokens=total_tokens,
146                avg_tokens_per_chunk=total_tokens / len(chunks) if chunks else 0
147            )
148        
149        return chunks
150    
151    def count_tokens(self, text: str) -> int:
152        """
153        Count the number of tokens in a text string.
154        
155        Args:
156            text: The text to count tokens for
157            
158        Returns:
159            Number of tokens
160        """
161        return len(self.encoding.encode(text))

Chunks text into fixed-size token-based segments with optional overlap.

This is one of the most common chunking strategies for LLM applications, ensuring chunks don't exceed model context windows or embedding limits.

FixedSizeChunker( chunk_size: int = 512, chunk_overlap: int = 50, encoding_name: str = 'cl100k_base', logger=None)
23    def __init__(
24        self,
25        chunk_size: int = 512,
26        chunk_overlap: int = 50,
27        encoding_name: str = "cl100k_base",
28        logger=None
29    ):
30        """
31        Initialize the fixed-size chunker.
32        
33        Args:
34            chunk_size: Maximum number of tokens per chunk (default: 512)
35            chunk_overlap: Number of tokens to overlap between chunks (default: 50)
36            encoding_name: tiktoken encoding name (default: "cl100k_base" for GPT-3.5/4)
37                          Options: "cl100k_base" (GPT-3.5, GPT-4)
38                                   "p50k_base" (CodeX, GPT-3)
39                                   "r50k_base" (GPT-2)
40            logger: Optional BasicLogger instance
41            
42        Raises:
43            ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size
44        """
45        super().__init__(logger)
46        
47        if chunk_size <= 0:
48            raise ValueError("chunk_size must be greater than 0")
49        if chunk_overlap < 0:
50            raise ValueError("chunk_overlap must be non-negative")
51        if chunk_overlap >= chunk_size:
52            raise ValueError("chunk_overlap must be less than chunk_size")
53        
54        self.chunk_size = chunk_size
55        self.chunk_overlap = chunk_overlap
56        self.encoding_name = encoding_name
57        
58        try:
59            self.encoding = tiktoken.get_encoding(encoding_name)
60        except Exception as e:
61            raise ValueError(f"Failed to load tiktoken encoding '{encoding_name}': {e}")
62        
63        if self.logger:
64            self.logger.info(
65                "Initialized FixedSizeChunker",
66                chunk_size=chunk_size,
67                chunk_overlap=chunk_overlap,
68                encoding=encoding_name
69            )

Initialize the fixed-size chunker.

Args:
chunk_size: Maximum number of tokens per chunk (default: 512).
chunk_overlap: Number of tokens to overlap between consecutive chunks (default: 50).
encoding_name: tiktoken encoding name (default: "cl100k_base" for GPT-3.5/4). Options: "cl100k_base" (GPT-3.5, GPT-4), "p50k_base" (Codex, GPT-3), "r50k_base" (GPT-2).
logger: Optional BasicLogger instance.

Raises:
ValueError: If chunk_size <= 0, chunk_overlap < 0, chunk_overlap >= chunk_size, or the tiktoken encoding cannot be loaded.

chunk_size
chunk_overlap
encoding_name
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 71    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 72        """
 73        Split text into fixed-size token-based chunks with overlap.
 74        
 75        Args:
 76            text: The input text to chunk
 77            metadata: Optional metadata to attach to each chunk
 78            
 79        Returns:
 80            List of Chunk objects
 81            
 82        Raises:
 83            ValueError: If text is invalid
 84        """
 85        self.validate_text(text)
 86        
 87        if metadata is None:
 88            metadata = {}
 89        
 90        if self.logger:
 91            self.logger.debug(f"Chunking text of length {len(text)} characters")
 92        
 93        # Encode text to tokens
 94        tokens = self.encoding.encode(text)
 95        total_tokens = len(tokens)
 96        
 97        if self.logger:
 98            self.logger.debug(f"Text tokenized to {total_tokens} tokens")
 99        
100        chunks = []
101        start_idx = 0
102        chunk_index = 0
103        
104        while start_idx < total_tokens:
105            # Calculate end index for this chunk
106            end_idx = min(start_idx + self.chunk_size, total_tokens)
107            
108            # Extract token slice
109            chunk_tokens = tokens[start_idx:end_idx]
110            
111            # Decode tokens back to text
112            chunk_text = self.encoding.decode(chunk_tokens)
113            
114            # Find character positions in original text
115            # This is approximate since token boundaries don't always align with char boundaries
116            char_start = len(self.encoding.decode(tokens[:start_idx]))
117            char_end = char_start + len(chunk_text)
118            
119            # Create chunk
120            chunk = Chunk(
121                text=chunk_text,
122                metadata=metadata.copy(),
123                start_pos=char_start,
124                end_pos=char_end,
125                chunk_id=self._generate_chunk_id(chunk_index, metadata)
126            )
127            
128            # Add token count to metadata
129            chunk.metadata["token_count"] = len(chunk_tokens)
130            chunk.metadata["chunking_strategy"] = "fixed_size"
131            
132            chunks.append(chunk)
133            chunk_index += 1
134            
135            # Move start index forward, accounting for overlap
136            start_idx += self.chunk_size - self.chunk_overlap
137            
138            # Break if we've reached the end
139            if end_idx >= total_tokens:
140                break
141        
142        if self.logger:
143            self.logger.info(
144                f"Created {len(chunks)} chunks",
145                total_tokens=total_tokens,
146                avg_tokens_per_chunk=total_tokens / len(chunks) if chunks else 0
147            )
148        
149        return chunks

Split text into fixed-size token-based chunks with overlap.

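Args: text: The input text to chunk. metadata: Optional metadata to attach to each chunk; every chunk receives its own copy, to which "token_count" and "chunking_strategy" ("fixed_size") are added.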

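Returns: List of Chunk objects. Consecutive chunks overlap by chunk_overlap tokens. start_pos and end_pos are approximate character offsets, because token boundaries do not always align with character boundaries.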

Raises: ValueError: If text is invalid

def count_tokens(self, text: str) -> int:
151    def count_tokens(self, text: str) -> int:
152        """
153        Count the number of tokens in a text string.
154        
155        Args:
156            text: The text to count tokens for
157            
158        Returns:
159            Number of tokens
160        """
161        return len(self.encoding.encode(text))

Count the number of tokens in a text string.

Args: text: The text to count tokens for

Returns: Number of tokens in text under this chunker's tiktoken encoding (encoding_name).

class SemanticChunker(gmf_forge_ai_data.BaseChunker):
 15class SemanticChunker(BaseChunker):
 16    """
 17    Chunks text based on semantic similarity and sentence boundaries.
 18    
 19    This chunker respects sentence boundaries and can optionally group
 20    semantically related sentences together using embeddings.
 21    
 22    By default uses simple regex for sentence detection. For better accuracy,
 23    pass a custom sentence_tokenizer (e.g., nltk.sent_tokenize).
 24    """
 25    
 26    def __init__(
 27        self,
 28        sentence_tokenizer: Callable[[str], List[str]],
 29        max_chunk_size: int = 1000,
 30        min_chunk_size: int = 100,
 31        similarity_threshold: float = 0.5,
 32        logger=None
 33    ):
 34        """
 35        Initialize the semantic chunker.
 36        
 37        Args:
 38            sentence_tokenizer: REQUIRED callable that splits text into sentences.
 39                              Use nltk.sent_tokenize (recommended), spacy, or custom function.
 40                              Example: nltk.sent_tokenize
 41                              Must return List[str] of sentences.
 42            max_chunk_size: Maximum characters per chunk (default: 1000)
 43            min_chunk_size: Minimum characters per chunk (default: 100)
 44            similarity_threshold: Threshold for semantic similarity (0.0-1.0, default: 0.5)
 45                                 Not used in basic implementation
 46            logger: Optional BasicLogger instance
 47            
 48        Raises:
 49            ValueError: If sentence_tokenizer is not provided
 50        """
 51        super().__init__(logger)
 52        
 53        if sentence_tokenizer is None:
 54            raise ValueError(
 55                "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
 56                "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
 57                "Example: SemanticChunker(sentence_tokenizer=nltk.sent_tokenize)"
 58            )
 59        
 60        if max_chunk_size <= 0:
 61            raise ValueError("max_chunk_size must be greater than 0")
 62        if min_chunk_size < 0:
 63            raise ValueError("min_chunk_size must be non-negative")
 64        if min_chunk_size >= max_chunk_size:
 65            raise ValueError("min_chunk_size must be less than max_chunk_size")
 66        if not (0.0 <= similarity_threshold <= 1.0):
 67            raise ValueError("similarity_threshold must be between 0.0 and 1.0")
 68        
 69        self.sentence_tokenizer = sentence_tokenizer
 70        self.max_chunk_size = max_chunk_size
 71        self.min_chunk_size = min_chunk_size
 72        self.similarity_threshold = similarity_threshold
 73        
 74        if self.logger:
 75            self.logger.info(
 76                "Initialized SemanticChunker",
 77                max_chunk_size=max_chunk_size,
 78                min_chunk_size=min_chunk_size,
 79                tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom'
 80            )
 81    
 82    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 83        """
 84        Split text into chunks based on sentence boundaries and semantic similarity.
 85        
 86        Args:
 87            text: The input text to chunk
 88            metadata: Optional metadata to attach to each chunk
 89            
 90        Returns:
 91            List of Chunk objects
 92            
 93        Raises:
 94            ValueError: If text is invalid
 95        """
 96        self.validate_text(text)
 97        
 98        if metadata is None:
 99            metadata = {}
100        
101        if self.logger:
102            self.logger.debug(f"Chunking text of length {len(text)} characters")
103        
104        # Split into sentences
105        sentences = self._split_into_sentences(text)
106        
107        if self.logger:
108            self.logger.debug(f"Split text into {len(sentences)} sentences")
109        
110        # Group sentences into chunks
111        chunks = []
112        current_chunk_sentences = []
113        current_chunk_size = 0
114        chunk_index = 0
115        char_position = 0
116        
117        for sentence_text, start, end in sentences:
118            sentence_len = len(sentence_text)
119            
120            # Check if adding this sentence would exceed max size
121            if current_chunk_size + sentence_len > self.max_chunk_size and current_chunk_sentences:
122                # Create chunk from accumulated sentences
123                chunk_text = " ".join(current_chunk_sentences)
124                chunk_start = char_position - current_chunk_size
125                
126                chunk = Chunk(
127                    text=chunk_text,
128                    metadata=metadata.copy(),
129                    start_pos=chunk_start,
130                    end_pos=char_position,
131                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
132                )
133                chunk.metadata["sentence_count"] = len(current_chunk_sentences)
134                chunk.metadata["chunking_strategy"] = "semantic"
135                
136                chunks.append(chunk)
137                chunk_index += 1
138                
139                # Start new chunk
140                current_chunk_sentences = []
141                current_chunk_size = 0
142            
143            # Add sentence to current chunk
144            current_chunk_sentences.append(sentence_text)
145            current_chunk_size += sentence_len + 1  # +1 for space
146            char_position = end
147        
148       # Create final chunk if there are remaining sentences
149        if current_chunk_sentences:
150            chunk_text = " ".join(current_chunk_sentences)
151            chunk_start = char_position - current_chunk_size
152            
153            chunk = Chunk(
154                text=chunk_text,
155                metadata=metadata.copy(),
156                start_pos=chunk_start,
157                end_pos=char_position,
158                chunk_id=self._generate_chunk_id(chunk_index, metadata)
159            )
160            chunk.metadata["sentence_count"] = len(current_chunk_sentences)
161            chunk.metadata["chunking_strategy"] = "semantic"
162            
163            chunks.append(chunk)
164        
165        if self.logger:
166            self.logger.info(
167                f"Created {len(chunks)} semantic chunks",
168                total_sentences=len(sentences),
169                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
170            )
171        
172        return chunks
173    
174    def _split_into_sentences(self, text: str) -> List[tuple]:
175        """
176        Split text into sentences using provided tokenizer.
177        
178        Args:
179            text: The text to split
180            
181        Returns:
182            List of tuples (sentence_text, start_pos, end_pos)
183            
184        Raises:
185            RuntimeError: If sentence tokenizer fails
186        """
187        try:
188            sentence_texts = self.sentence_tokenizer(text)
189            
190            # Calculate positions for each sentence
191            sentences = []
192            pos = 0
193            for sent_text in sentence_texts:
194                # Find the sentence in the original text
195                idx = text.find(sent_text, pos)
196                if idx != -1:
197                    start = idx
198                    end = idx + len(sent_text)
199                    sentences.append((sent_text, start, end))
200                    pos = end
201                else:
202                    # Fallback: estimate position
203                    sentences.append((sent_text, pos, pos + len(sent_text)))
204                    pos += len(sent_text)
205            
206            return sentences if sentences else [(text.strip(), 0, len(text))]
207            
208        except Exception as e:
209            raise RuntimeError(
210                f"Sentence tokenizer failed: {e}\n"
211                f"Please ensure your tokenizer function is working correctly."
212            ) from e

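Chunks text based on sentence boundaries (semantic-similarity grouping is not yet implemented).

This chunker respects sentence boundaries. It packs consecutive sentences, joined with single spaces, into chunks of up to max_chunk_size characters; a single sentence longer than that still forms a chunk of its own. Grouping semantically related sentences with embeddings is not implemented yet: similarity_threshold is validated but unused.

Sentence detection is delegated to the required sentence_tokenizer callable (e.g., nltk.sent_tokenize); there is no built-in regex fallback.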

SemanticChunker( sentence_tokenizer: Callable[[str], List[str]], max_chunk_size: int = 1000, min_chunk_size: int = 100, similarity_threshold: float = 0.5, logger=None)
26    def __init__(
27        self,
28        sentence_tokenizer: Callable[[str], List[str]],
29        max_chunk_size: int = 1000,
30        min_chunk_size: int = 100,
31        similarity_threshold: float = 0.5,
32        logger=None
33    ):
34        """
35        Initialize the semantic chunker.
36        
37        Args:
38            sentence_tokenizer: REQUIRED callable that splits text into sentences.
39                              Use nltk.sent_tokenize (recommended), spacy, or custom function.
40                              Example: nltk.sent_tokenize
41                              Must return List[str] of sentences.
42            max_chunk_size: Maximum characters per chunk (default: 1000)
43            min_chunk_size: Minimum characters per chunk (default: 100)
44            similarity_threshold: Threshold for semantic similarity (0.0-1.0, default: 0.5)
45                                 Not used in basic implementation
46            logger: Optional BasicLogger instance
47            
48        Raises:
49            ValueError: If sentence_tokenizer is not provided
50        """
51        super().__init__(logger)
52        
53        if sentence_tokenizer is None:
54            raise ValueError(
55                "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
56                "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
57                "Example: SemanticChunker(sentence_tokenizer=nltk.sent_tokenize)"
58            )
59        
60        if max_chunk_size <= 0:
61            raise ValueError("max_chunk_size must be greater than 0")
62        if min_chunk_size < 0:
63            raise ValueError("min_chunk_size must be non-negative")
64        if min_chunk_size >= max_chunk_size:
65            raise ValueError("min_chunk_size must be less than max_chunk_size")
66        if not (0.0 <= similarity_threshold <= 1.0):
67            raise ValueError("similarity_threshold must be between 0.0 and 1.0")
68        
69        self.sentence_tokenizer = sentence_tokenizer
70        self.max_chunk_size = max_chunk_size
71        self.min_chunk_size = min_chunk_size
72        self.similarity_threshold = similarity_threshold
73        
74        if self.logger:
75            self.logger.info(
76                "Initialized SemanticChunker",
77                max_chunk_size=max_chunk_size,
78                min_chunk_size=min_chunk_size,
79                tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom'
80            )

Initialize the semantic chunker.

Args: sentence_tokenizer: REQUIRED callable that splits text into sentences and returns them as a List[str]; use nltk.sent_tokenize (recommended), spaCy, or a custom function. max_chunk_size: Maximum characters per chunk (default: 1000). min_chunk_size: Minimum characters per chunk (default: 100); validated against max_chunk_size but not currently enforced when chunking. similarity_threshold: Threshold for semantic similarity, 0.0-1.0 (default: 0.5); validated but not used by the basic implementation. logger: Optional BasicLogger instance.

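Raises: ValueError: If sentence_tokenizer is None, max_chunk_size <= 0, min_chunk_size < 0, min_chunk_size >= max_chunk_size, or similarity_threshold is outside [0.0, 1.0].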

sentence_tokenizer
max_chunk_size
min_chunk_size
similarity_threshold
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 82    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 83        """
 84        Split text into chunks based on sentence boundaries and semantic similarity.
 85        
 86        Args:
 87            text: The input text to chunk
 88            metadata: Optional metadata to attach to each chunk
 89            
 90        Returns:
 91            List of Chunk objects
 92            
 93        Raises:
 94            ValueError: If text is invalid
 95        """
 96        self.validate_text(text)
 97        
 98        if metadata is None:
 99            metadata = {}
100        
101        if self.logger:
102            self.logger.debug(f"Chunking text of length {len(text)} characters")
103        
104        # Split into sentences
105        sentences = self._split_into_sentences(text)
106        
107        if self.logger:
108            self.logger.debug(f"Split text into {len(sentences)} sentences")
109        
110        # Group sentences into chunks
111        chunks = []
112        current_chunk_sentences = []
113        current_chunk_size = 0
114        chunk_index = 0
115        char_position = 0
116        
117        for sentence_text, start, end in sentences:
118            sentence_len = len(sentence_text)
119            
120            # Check if adding this sentence would exceed max size
121            if current_chunk_size + sentence_len > self.max_chunk_size and current_chunk_sentences:
122                # Create chunk from accumulated sentences
123                chunk_text = " ".join(current_chunk_sentences)
124                chunk_start = char_position - current_chunk_size
125                
126                chunk = Chunk(
127                    text=chunk_text,
128                    metadata=metadata.copy(),
129                    start_pos=chunk_start,
130                    end_pos=char_position,
131                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
132                )
133                chunk.metadata["sentence_count"] = len(current_chunk_sentences)
134                chunk.metadata["chunking_strategy"] = "semantic"
135                
136                chunks.append(chunk)
137                chunk_index += 1
138                
139                # Start new chunk
140                current_chunk_sentences = []
141                current_chunk_size = 0
142            
143            # Add sentence to current chunk
144            current_chunk_sentences.append(sentence_text)
145            current_chunk_size += sentence_len + 1  # +1 for space
146            char_position = end
147        
148       # Create final chunk if there are remaining sentences
149        if current_chunk_sentences:
150            chunk_text = " ".join(current_chunk_sentences)
151            chunk_start = char_position - current_chunk_size
152            
153            chunk = Chunk(
154                text=chunk_text,
155                metadata=metadata.copy(),
156                start_pos=chunk_start,
157                end_pos=char_position,
158                chunk_id=self._generate_chunk_id(chunk_index, metadata)
159            )
160            chunk.metadata["sentence_count"] = len(current_chunk_sentences)
161            chunk.metadata["chunking_strategy"] = "semantic"
162            
163            chunks.append(chunk)
164        
165        if self.logger:
166            self.logger.info(
167                f"Created {len(chunks)} semantic chunks",
168                total_sentences=len(sentences),
169                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
170            )
171        
172        return chunks

Split text into chunks at sentence boundaries (semantic similarity is not yet used to group sentences).

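Args: text: The input text to chunk. metadata: Optional metadata to attach to each chunk; every chunk receives its own copy, to which "sentence_count" and "chunking_strategy" ("semantic") are added.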

Returns: List of Chunk objects

Raises: ValueError: If text is invalid. RuntimeError: If the sentence tokenizer raises an exception.

class RecursiveChunker(gmf_forge_ai_data.BaseChunker):
 16class RecursiveChunker(BaseChunker):
 17    """
 18    Recursively chunks text using a hierarchy of separators.
 19    
 20    This chunker attempts to split at natural boundaries in order of preference:
 21    1. Double newlines (paragraphs)
 22    2. Single newlines (lines)
 23    3. Sentence boundaries
 24    4. Word boundaries
 25    5. Character boundaries (last resort)
 26    
 27    This preserves document structure while ensuring chunks don't exceed the maximum size.
 28    """
 29    
 30    def __init__(
 31        self,
 32        chunk_size: int = 1000,
 33        chunk_overlap: int = 100,
 34        separators: Optional[List[str]] = None,
 35        logger=None
 36    ):
 37        """
 38        Initialize the recursive chunker.
 39        
 40        Args:
 41            chunk_size: Target maximum characters per chunk (default: 1000)
 42            chunk_overlap: Characters to overlap between chunks (default: 100)
 43            separators: List of separators in priority order (default: standard hierarchy)
 44            logger: Optional BasicLogger instance
 45        """
 46        super().__init__(logger)
 47        
 48        if chunk_size <= 0:
 49            raise ValueError("chunk_size must be greater than 0")
 50        if chunk_overlap < 0:
 51            raise ValueError("chunk_overlap must be non-negative")
 52        if chunk_overlap >= chunk_size:
 53            raise ValueError("chunk_overlap must be less than chunk_size")
 54        
 55        self.chunk_size = chunk_size
 56        self.chunk_overlap = chunk_overlap
 57        
 58        # Default separator hierarchy if not provided
 59        if separators is None:
 60            self.separators = [
 61                "\n\n",  # Paragraphs
 62                "\n",    # Lines
 63                ". ",    # Sentences
 64                " ",     # Words
 65                ""       # Characters (no separator)
 66            ]
 67        else:
 68            self.separators = separators
 69        
 70        if self.logger:
 71            self.logger.info(
 72                "Initialized RecursiveChunker",
 73                chunk_size=chunk_size,
 74                chunk_overlap=chunk_overlap,
 75                num_separators=len(self.separators)
 76            )
 77    
 78    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 79        """
 80        Recursively split text into chunks using hierarchical separators.
 81        
 82        Args:
 83            text: The input text to chunk
 84            metadata: Optional metadata to attach to each chunk
 85            
 86        Returns:
 87            List of Chunk objects
 88            
 89        Raises:
 90            ValueError: If text is invalid
 91        """
 92        self.validate_text(text)
 93        
 94        if metadata is None:
 95            metadata = {}
 96        
 97        if self.logger:
 98            self.logger.debug(f"Recursively chunking text of length {len(text)} characters")
 99        
100        # Perform recursive splitting
101        text_chunks = self._split_text_recursively(text)
102        
103        # Convert text chunks to Chunk objects with metadata
104        chunks = []
105        char_position = 0
106        
107        for i, chunk_text in enumerate(text_chunks):
108            chunk = Chunk(
109                text=chunk_text,
110                metadata=metadata.copy(),
111                start_pos=char_position,
112                end_pos=char_position + len(chunk_text),
113                chunk_id=self._generate_chunk_id(i, metadata)
114            )
115            chunk.metadata["chunking_strategy"] = "recursive"
116            chunks.append(chunk)
117            
118            # Update position accounting for overlap
119            char_position += len(chunk_text) - self.chunk_overlap
120        
121        if self.logger:
122            self.logger.info(
123                f"Created {len(chunks)} recursive chunks",
124                avg_chunk_size=sum(len(c.text) for c in chunks) / len(chunks) if chunks else 0
125            )
126        
127        return chunks
128    
129    def _split_text_recursively(self, text: str) -> List[str]:
130        """
131        Recursively split text using the separator hierarchy.
132        
133        Args:
134            text: Text to split
135            
136        Returns:
137            List of text chunks
138        """
139        return self._split_text(text, self.separators)
140    
141    def _split_text(self, text: str, separators: List[str]) -> List[str]:
142        """
143        Split text using the given separators recursively.
144        
145        Args:
146            text: Text to split
147            separators: Remaining separators to try
148            
149        Returns:
150            List of text chunks
151        """
152        final_chunks = []
153        
154        # Choose separator (last one if list is exhausted)
155        separator = separators[-1] if separators else ""
156        
157        # Split by current separator
158        if separator:
159            splits = text.split(separator)
160        else:
161            # No separator: split by characters
162            splits = list(text)
163        
164        # Process each split segment
165        current_chunk = []
166        for split in splits:
167            # Add separator back (except for character-level splitting)
168            if separator and current_chunk:
169                split = separator + split
170            
171            # If this split is small enough, accumulate it
172            current_size = sum(len(s) for s in current_chunk)
173            
174            if current_size + len(split) <= self.chunk_size:
175                current_chunk.append(split)
176            else:
177                # Current chunk is ready
178                if current_chunk:
179                    merged = "".join(current_chunk) if not separator else separator.join(current_chunk)
180                    if separator == " ":
181                        merged = " ".join(c.strip() for c in current_chunk if c.strip())
182                    
183                    if merged.strip():
184                        final_chunks.append(merged)
185                    current_chunk = []
186                
187                # Check if this split itself needs to be broken down
188                if len(split) > self.chunk_size:
189                    if len(separators) > 1:
190                        # Try next separator in hierarchy
191                        sub_chunks = self._split_text(split, separators[1:])
192                        final_chunks.extend(sub_chunks)
193                    else:
194                        # Force split at chunk_size boundaries
195                        for i in range(0, len(split), self.chunk_size):
196                            sub_chunk = split[i:i + self.chunk_size]
197                            if sub_chunk.strip():
198                                final_chunks.append(sub_chunk)
199                else:
200                    current_chunk.append(split)
201        
202        # Add remaining chunk
203        if current_chunk:
204            merged = "".join(current_chunk) if not separator else separator.join(current_chunk)
205            if separator == " ":
206                merged = " ".join(c.strip() for c in current_chunk if c.strip())
207            
208            if merged.strip():
209                final_chunks.append(merged)
210        
211        return final_chunks

Recursively chunks text using a hierarchy of separators.

This chunker attempts to split at natural boundaries in order of preference:

  1. Double newlines (paragraphs)
  2. Single newlines (lines)
  3. Sentence boundaries
  4. Word boundaries
  5. Character boundaries (last resort)

This preserves document structure while ensuring chunks don't exceed the maximum size.

RecursiveChunker( chunk_size: int = 1000, chunk_overlap: int = 100, separators: Optional[List[str]] = None, logger=None)
30    def __init__(
31        self,
32        chunk_size: int = 1000,
33        chunk_overlap: int = 100,
34        separators: Optional[List[str]] = None,
35        logger=None
36    ):
37        """
38        Initialize the recursive chunker.
39        
40        Args:
41            chunk_size: Target maximum characters per chunk (default: 1000)
42            chunk_overlap: Characters to overlap between chunks (default: 100)
43            separators: List of separators in priority order (default: standard hierarchy)
44            logger: Optional BasicLogger instance
45        """
46        super().__init__(logger)
47        
48        if chunk_size <= 0:
49            raise ValueError("chunk_size must be greater than 0")
50        if chunk_overlap < 0:
51            raise ValueError("chunk_overlap must be non-negative")
52        if chunk_overlap >= chunk_size:
53            raise ValueError("chunk_overlap must be less than chunk_size")
54        
55        self.chunk_size = chunk_size
56        self.chunk_overlap = chunk_overlap
57        
58        # Default separator hierarchy if not provided
59        if separators is None:
60            self.separators = [
61                "\n\n",  # Paragraphs
62                "\n",    # Lines
63                ". ",    # Sentences
64                " ",     # Words
65                ""       # Characters (no separator)
66            ]
67        else:
68            self.separators = separators
69        
70        if self.logger:
71            self.logger.info(
72                "Initialized RecursiveChunker",
73                chunk_size=chunk_size,
74                chunk_overlap=chunk_overlap,
75                num_separators=len(self.separators)
76            )

Initialize the recursive chunker.

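Args: chunk_size: Target maximum characters per chunk (default: 1000). chunk_overlap: Characters to overlap between chunks (default: 100); must be less than chunk_size. Note that the _split_text implementation does not copy overlapping text into adjacent chunks; chunk_overlap is only subtracted when computing each chunk's start_pos. separators: List of separators in priority order (default: paragraph "\n\n", line "\n", sentence ". ", word " ", then character ""). logger: Optional BasicLogger instance.

Raises: ValueError: If chunk_size <= 0, chunk_overlap < 0, or chunk_overlap >= chunk_size.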

chunk_size
chunk_overlap
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 78    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 79        """
 80        Recursively split text into chunks using hierarchical separators.
 81        
 82        Args:
 83            text: The input text to chunk
 84            metadata: Optional metadata to attach to each chunk
 85            
 86        Returns:
 87            List of Chunk objects
 88            
 89        Raises:
 90            ValueError: If text is invalid
 91        """
 92        self.validate_text(text)
 93        
 94        if metadata is None:
 95            metadata = {}
 96        
 97        if self.logger:
 98            self.logger.debug(f"Recursively chunking text of length {len(text)} characters")
 99        
100        # Perform recursive splitting
101        text_chunks = self._split_text_recursively(text)
102        
103        # Convert text chunks to Chunk objects with metadata
104        chunks = []
105        char_position = 0
106        
107        for i, chunk_text in enumerate(text_chunks):
108            chunk = Chunk(
109                text=chunk_text,
110                metadata=metadata.copy(),
111                start_pos=char_position,
112                end_pos=char_position + len(chunk_text),
113                chunk_id=self._generate_chunk_id(i, metadata)
114            )
115            chunk.metadata["chunking_strategy"] = "recursive"
116            chunks.append(chunk)
117            
118            # Update position accounting for overlap
119            char_position += len(chunk_text) - self.chunk_overlap
120        
121        if self.logger:
122            self.logger.info(
123                f"Created {len(chunks)} recursive chunks",
124                avg_chunk_size=sum(len(c.text) for c in chunks) / len(chunks) if chunks else 0
125            )
126        
127        return chunks

Recursively split text into chunks using hierarchical separators.

Args: text: The input text to chunk. metadata: Optional metadata to attach to each chunk; every chunk receives its own copy, to which "chunking_strategy" ("recursive") is added.

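Returns: List of Chunk objects. start_pos and end_pos are not looked up in the original text; they are accumulated from the chunk lengths minus chunk_overlap, so treat them as approximate.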

Raises: ValueError: If text is invalid

class SentenceChunker(gmf_forge_ai_data.BaseChunker):
 15class SentenceChunker(BaseChunker):
 16    """
 17    Chunks text by grouping sentences together.
 18    
 19    This chunker preserves sentence boundaries while grouping multiple
 20    sentences into chunks based on character count or sentence count limits.
 21    
 22    By default uses simple regex for sentence detection. For better accuracy,
 23    pass a custom sentence_tokenizer (e.g., nltk.sent_tokenize).
 24    """
 25    
 26    def __init__(
 27        self,
 28        sentence_tokenizer: Callable[[str], List[str]],
 29        max_chunk_size: int = 1000,
 30        sentences_per_chunk: Optional[int] = None,
 31        logger=None
 32    ):
 33        """
 34        Initialize the sentence chunker.
 35        
 36        Args:
 37            sentence_tokenizer: REQUIRED callable that splits text into sentences.
 38                              Use nltk.sent_tokenize (recommended), spacy, or custom function.
 39                              Example: nltk.sent_tokenize
 40                              Must return List[str] of sentences.
 41            max_chunk_size: Maximum characters per chunk (default: 1000)
 42            sentences_per_chunk: Optional fixed number of sentences per chunk
 43                                If set, this takes priority over max_chunk_size
 44            logger: Optional BasicLogger instance
 45            
 46        Raises:
 47            ValueError: If sentence_tokenizer is not provided
 48        """
 49        super().__init__(logger)
 50        
 51        if sentence_tokenizer is None:
 52            raise ValueError(
 53                "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
 54                "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
 55                "Example: SentenceChunker(sentence_tokenizer=nltk.sent_tokenize)"
 56            )
 57        
 58        if max_chunk_size <= 0:
 59            raise ValueError("max_chunk_size must be greater than 0")
 60        if sentences_per_chunk is not None and sentences_per_chunk <= 0:
 61            raise ValueError("sentences_per_chunk must be greater than 0")
 62        
 63        self.sentence_tokenizer = sentence_tokenizer
 64        self.max_chunk_size = max_chunk_size
 65        self.sentences_per_chunk = sentences_per_chunk
 66        
 67        if self.logger:
 68            self.logger.info(
 69                "Initialized SentenceChunker",
 70                max_chunk_size=max_chunk_size,
 71                sentences_per_chunk=sentences_per_chunk,
 72                tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom'
 73            )
 74    
 75    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 76        """
 77        Split text into chunks based on sentence boundaries.
 78        
 79        Args:
 80            text: The input text to chunk
 81            metadata: Optional metadata to attach to each chunk
 82            
 83        Returns:
 84            List of Chunk objects
 85            
 86        Raises:
 87            ValueError: If text is invalid
 88        """
 89        self.validate_text(text)
 90        
 91        if metadata is None:
 92            metadata = {}
 93        
 94        if self.logger:
 95            self.logger.debug(f"Chunking text of length {len(text)} characters by sentences")
 96        
 97        # Split text into sentences
 98        sentences = self._split_sentences(text)
 99        
100        if self.logger:
101            self.logger.debug(f"Found {len(sentences)} sentences")
102        
103        chunks = []
104        current_sentences = []
105        current_size = 0
106        chunk_index = 0
107        char_position = 0
108        
109        for sentence, start, end in sentences:
110            sentence_len = len(sentence)
111            
112            # Check if we should create a new chunk
113            should_chunk = False
114            
115            if self.sentences_per_chunk:
116                # Fixed sentence count mode
117                should_chunk = len(current_sentences) >= self.sentences_per_chunk
118            else:
119                # Size-based mode
120                should_chunk = (
121                    current_size + sentence_len > self.max_chunk_size 
122                    and current_sentences  # Don't create empty chunks
123                )
124            
125            if should_chunk:
126                # Create chunk from accumulated sentences
127                chunk_text = " ".join(current_sentences)
128                chunk_start = char_position - current_size
129                
130                chunk = Chunk(
131                    text=chunk_text,
132                    metadata=metadata.copy(),
133                    start_pos=chunk_start,
134                    end_pos=char_position,
135                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
136                )
137                chunk.metadata["sentence_count"] = len(current_sentences)
138                chunk.metadata["chunking_strategy"] = "sentence"
139                
140                chunks.append(chunk)
141                chunk_index += 1
142                
143                # Reset for new chunk
144                current_sentences = []
145                current_size = 0
146            
147            # Add sentence to current chunk
148            current_sentences.append(sentence)
149            current_size += sentence_len + 1  # +1 for space
150            char_position = end
151        
152        # Create final chunk if there are remaining sentences
153        if current_sentences:
154            chunk_text = " ".join(current_sentences)
155            chunk_start = char_position - current_size
156            
157            chunk = Chunk(
158                text=chunk_text,
159                metadata=metadata.copy(),
160                start_pos=chunk_start,
161                end_pos=char_position,
162                chunk_id=self._generate_chunk_id(chunk_index, metadata)
163            )
164            chunk.metadata["sentence_count"] = len(current_sentences)
165            chunk.metadata["chunking_strategy"] = "sentence"
166            
167            chunks.append(chunk)
168        
169        if self.logger:
170            self.logger.info(
171                f"Created {len(chunks)} sentence-based chunks",
172                total_sentences=len(sentences),
173                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
174            )
175        
176        return chunks
177    
178    def _split_sentences(self, text: str) -> List[tuple]:
179        """
180        Split text into sentences with position tracking using provided tokenizer.
181        
182        Args:
183            text: Text to split into sentences
184            
185        Returns:
186            List of tuples (sentence_text, start_pos, end_pos)
187            
188        Raises:
189            RuntimeError: If sentence tokenizer fails
190        """
191        try:
192            sentence_texts = self.sentence_tokenizer(text)
193            
194            # Calculate positions for each sentence
195            sentences = []
196            pos = 0
197            for sent_text in sentence_texts:
198                # Find the sentence in the original text
199                idx = text.find(sent_text, pos)
200                if idx != -1:
201                    start = idx
202                    end = idx + len(sent_text)
203                    sentences.append((sent_text, start, end))
204                    pos = end
205                else:
206                    # Fallback: estimate position
207                    sentences.append((sent_text, pos, pos + len(sent_text)))
208                    pos += len(sent_text)
209            
210            return sentences if sentences else [(text.strip(), 0, len(text))]
211            
212        except Exception as e:
213            raise RuntimeError(
214                f"Sentence tokenizer failed: {e}\n"
215                f"Please ensure your tokenizer function is working correctly."
216            ) from e

Chunks text by grouping sentences together.

This chunker preserves sentence boundaries while grouping multiple sentences into chunks based on character count or sentence count limits.

Sentence detection is delegated to the required sentence_tokenizer callable (e.g., nltk.sent_tokenize); there is no built-in regex fallback, and constructing the chunker without a tokenizer raises ValueError.

SentenceChunker( sentence_tokenizer: Callable[[str], List[str]], max_chunk_size: int = 1000, sentences_per_chunk: Optional[int] = None, logger=None)
def __init__(
    self,
    sentence_tokenizer: Callable[[str], List[str]],
    max_chunk_size: int = 1000,
    sentences_per_chunk: Optional[int] = None,
    logger=None
):
    """
    Initialize the sentence chunker.

    Args:
        sentence_tokenizer: REQUIRED callable that splits text into sentences.
            Use nltk.sent_tokenize (recommended), spacy, or a custom function.
            Example: nltk.sent_tokenize
            Must return List[str] of sentences.
        max_chunk_size: Maximum characters per chunk (default: 1000). A single
            sentence longer than this still becomes its own, oversized chunk.
        sentences_per_chunk: Optional fixed number of sentences per chunk.
            If set, it takes priority and max_chunk_size is ignored.
        logger: Optional BasicLogger instance

    Raises:
        ValueError: If sentence_tokenizer is missing or not callable, or if
            max_chunk_size / sentences_per_chunk is not positive.
    """
    super().__init__(logger)

    if sentence_tokenizer is None:
        raise ValueError(
            "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
            "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
            "Example: SentenceChunker(sentence_tokenizer=nltk.sent_tokenize)"
        )
    # Fail fast on an unusable tokenizer instead of failing on the first chunk() call.
    if not callable(sentence_tokenizer):
        raise ValueError(
            "sentence_tokenizer must be callable, got "
            f"{type(sentence_tokenizer).__name__}"
        )

    if max_chunk_size <= 0:
        raise ValueError("max_chunk_size must be greater than 0")
    if sentences_per_chunk is not None and sentences_per_chunk <= 0:
        raise ValueError("sentences_per_chunk must be greater than 0")

    self.sentence_tokenizer = sentence_tokenizer
    self.max_chunk_size = max_chunk_size
    self.sentences_per_chunk = sentences_per_chunk

    if self.logger:
        self.logger.info(
            "Initialized SentenceChunker",
            max_chunk_size=max_chunk_size,
            sentences_per_chunk=sentences_per_chunk,
            tokenizer=getattr(sentence_tokenizer, '__name__', 'custom')
        )

Initialize the sentence chunker.

Args: sentence_tokenizer: REQUIRED callable that splits text into sentences, e.g. nltk.sent_tokenize (recommended), spaCy, or a custom function; it must return List[str] of sentences. max_chunk_size: Maximum characters per chunk (default: 1000). sentences_per_chunk: Optional fixed number of sentences per chunk; if set, it takes priority over max_chunk_size. logger: Optional BasicLogger instance.

Raises: ValueError: If sentence_tokenizer is not provided, or if max_chunk_size or sentences_per_chunk is not greater than 0.

sentence_tokenizer
max_chunk_size
sentences_per_chunk
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 75    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 76        """
 77        Split text into chunks based on sentence boundaries.
 78        
 79        Args:
 80            text: The input text to chunk
 81            metadata: Optional metadata to attach to each chunk
 82            
 83        Returns:
 84            List of Chunk objects
 85            
 86        Raises:
 87            ValueError: If text is invalid
 88        """
 89        self.validate_text(text)
 90        
 91        if metadata is None:
 92            metadata = {}
 93        
 94        if self.logger:
 95            self.logger.debug(f"Chunking text of length {len(text)} characters by sentences")
 96        
 97        # Split text into sentences
 98        sentences = self._split_sentences(text)
 99        
100        if self.logger:
101            self.logger.debug(f"Found {len(sentences)} sentences")
102        
103        chunks = []
104        current_sentences = []
105        current_size = 0
106        chunk_index = 0
107        char_position = 0
108        
109        for sentence, start, end in sentences:
110            sentence_len = len(sentence)
111            
112            # Check if we should create a new chunk
113            should_chunk = False
114            
115            if self.sentences_per_chunk:
116                # Fixed sentence count mode
117                should_chunk = len(current_sentences) >= self.sentences_per_chunk
118            else:
119                # Size-based mode
120                should_chunk = (
121                    current_size + sentence_len > self.max_chunk_size 
122                    and current_sentences  # Don't create empty chunks
123                )
124            
125            if should_chunk:
126                # Create chunk from accumulated sentences
127                chunk_text = " ".join(current_sentences)
128                chunk_start = char_position - current_size
129                
130                chunk = Chunk(
131                    text=chunk_text,
132                    metadata=metadata.copy(),
133                    start_pos=chunk_start,
134                    end_pos=char_position,
135                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
136                )
137                chunk.metadata["sentence_count"] = len(current_sentences)
138                chunk.metadata["chunking_strategy"] = "sentence"
139                
140                chunks.append(chunk)
141                chunk_index += 1
142                
143                # Reset for new chunk
144                current_sentences = []
145                current_size = 0
146            
147            # Add sentence to current chunk
148            current_sentences.append(sentence)
149            current_size += sentence_len + 1  # +1 for space
150            char_position = end
151        
152        # Create final chunk if there are remaining sentences
153        if current_sentences:
154            chunk_text = " ".join(current_sentences)
155            chunk_start = char_position - current_size
156            
157            chunk = Chunk(
158                text=chunk_text,
159                metadata=metadata.copy(),
160                start_pos=chunk_start,
161                end_pos=char_position,
162                chunk_id=self._generate_chunk_id(chunk_index, metadata)
163            )
164            chunk.metadata["sentence_count"] = len(current_sentences)
165            chunk.metadata["chunking_strategy"] = "sentence"
166            
167            chunks.append(chunk)
168        
169        if self.logger:
170            self.logger.info(
171                f"Created {len(chunks)} sentence-based chunks",
172                total_sentences=len(sentences),
173                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
174            )
175        
176        return chunks

Split text into chunks based on sentence boundaries.

Args: text: The input text to chunk; metadata: Optional metadata to attach to each chunk

Returns: List of Chunk objects

Raises: ValueError: If text is invalid

class MarkdownChunker(gmf_forge_ai_data.BaseChunker):
 15class MarkdownChunker(BaseChunker):
 16    """
 17    Chunks markdown text while respecting document structure.
 18    
 19    This chunker identifies markdown headers (# ## ###, etc.) and uses them
 20    as natural boundaries for chunking, preserving the document hierarchy.
 21    
 22    Uses regex-based parsing which works well for standard markdown. For complex
 23    markdown with extensions, consider pre-processing with mistune or markdown-it-py.
 24    """
 25    
 26    def __init__(
 27        self,
 28        max_chunk_size: int = 1500,
 29        combine_headers: bool = True,
 30        min_header_level: int = 1,
 31        logger=None
 32    ):
 33        """
 34        Initialize the markdown chunker.
 35        
 36        Args:
 37            max_chunk_size: Maximum characters per chunk (default: 1500)
 38            combine_headers: Whether to combine small sections under headers (default: True)
 39            min_header_level: Minimum header level to split at (1-6, default: 1)
 40            logger: Optional BasicLogger instance
 41        """
 42        super().__init__(logger)
 43        
 44        if max_chunk_size <= 0:
 45            raise ValueError("max_chunk_size must be greater than 0")
 46        if not (1 <= min_header_level <= 6):
 47            raise ValueError("min_header_level must be between 1 and 6")
 48        
 49        self.max_chunk_size = max_chunk_size
 50        self.combine_headers = combine_headers
 51        self.min_header_level = min_header_level
 52        
 53        # Regex pattern for markdown headers (both ATX and Setext styles)
 54        self.header_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)
 55        
 56        if self.logger:
 57            self.logger.info(
 58                "Initialized MarkdownChunker",
 59                max_chunk_size=max_chunk_size,
 60                combine_headers=combine_headers,
 61                min_header_level=min_header_level
 62            )
 63    
 64    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 65        """
 66        Split markdown text into chunks respecting header boundaries.
 67        
 68        Args:
 69            text: The markdown text to chunk
 70            metadata: Optional metadata to attach to each chunk
 71            
 72        Returns:
 73            List of Chunk objects
 74            
 75        Raises:
 76            ValueError: If text is invalid
 77        """
 78        self.validate_text(text)
 79        
 80        if metadata is None:
 81            metadata = {}
 82        
 83        if self.logger:
 84            self.logger.debug(f"Chunking markdown text of length {len(text)} characters")
 85        
 86        # Find all headers and their positions
 87        sections = self._parse_sections(text)
 88        
 89        if self.logger:
 90            self.logger.debug(f"Found {len(sections)} markdown sections")
 91        
 92        # Create chunks from sections
 93        chunks = []
 94        current_chunk_parts = []
 95        current_size = 0
 96        chunk_index = 0
 97        current_headers = []
 98        
 99        for level, header_text, content, start, end in sections:
100            section_text = content
101            section_size = len(section_text)
102            
103            # Check if we should start a new chunk
104            if (current_size + section_size > self.max_chunk_size 
105                and current_chunk_parts 
106                and level <= self.min_header_level):
107                
108                # Create chunk from accumulated sections
109                chunk_text = "\n\n".join(current_chunk_parts)
110                
111                chunk = Chunk(
112                    text=chunk_text,
113                    metadata=metadata.copy(),
114                    start_pos=start - current_size,
115                    end_pos=start,
116                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
117                )
118                chunk.metadata["headers"] = current_headers.copy()
119                chunk.metadata["chunking_strategy"] = "markdown"
120                
121                chunks.append(chunk)
122                chunk_index += 1
123                
124                # Reset for new chunk
125                current_chunk_parts = []
126                current_size = 0
127                current_headers = []
128            
129            # Add section to current chunk
130            current_chunk_parts.append(section_text)
131            current_size += section_size + 2  # +2 for \n\n
132            
133            # Track header hierarchy
134            if header_text:
135                current_headers.append({
136                    "level": level,
137                    "text": header_text
138                })
139        
140        # Create final chunk
141        if current_chunk_parts:
142            chunk_text = "\n\n".join(current_chunk_parts)
143            
144            chunk = Chunk(
145                text=chunk_text,
146                metadata=metadata.copy(),
147                start_pos=len(text) - current_size,
148                end_pos=len(text),
149                chunk_id=self._generate_chunk_id(chunk_index, metadata)
150            )
151            chunk.metadata["headers"] = current_headers
152            chunk.metadata["chunking_strategy"] = "markdown"
153            
154            chunks.append(chunk)
155        
156        if self.logger:
157            self.logger.info(
158                f"Created {len(chunks)} markdown chunks",
159                total_sections=len(sections),
160                avg_sections_per_chunk=len(sections) / len(chunks) if chunks else 0
161            )
162        
163        return chunks
164    
165    def _parse_sections(self, text: str) -> List[Tuple[int, str, str, int, int]]:
166        """
167        Parse markdown text into sections based on headers.
168        
169        Args:
170            text: Markdown text to parse
171            
172        Returns:
173            List of tuples (header_level, header_text, content, start_pos, end_pos)
174        """
175        sections = []
176        lines = text.split('\n')
177        current_content = []
178        current_header_level = 0
179        current_header_text = ""
180        section_start = 0
181        char_position = 0
182        
183        for i, line in enumerate(lines):
184            # Check if line is a header
185            header_match = self.header_pattern.match(line)
186            
187            if header_match:
188                # Save previous section if it exists
189                if current_content or current_header_text:
190                    content = '\n'.join(current_content)
191                    sections.append((
192                        current_header_level,
193                        current_header_text,
194                        content,
195                        section_start,
196                        char_position
197                    ))
198                
199                # Start new section
200                current_header_level = len(header_match.group(1))
201                current_header_text = header_match.group(2).strip()
202                current_content = [line]  # Include header in content
203                section_start = char_position
204            else:
205                current_content.append(line)
206            
207            char_position += len(line) + 1  # +1 for newline
208        
209        # Add final section
210        if current_content:
211            content = '\n'.join(current_content)
212            sections.append((
213                current_header_level,
214                current_header_text,
215                content,
216                section_start,
217                char_position
218            ))
219        
220        # If no sections found, treat entire text as one section
221        if not sections:
222            sections = [(0, "", text, 0, len(text))]
223        
224        return sections

Chunks markdown text while respecting document structure.

This chunker identifies markdown headers (# ## ###, etc.) and uses them as natural boundaries for chunking, preserving the document hierarchy.

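Uses regex-based parsing which works well for standard markdown. Only ATX-style headers (lines starting with #) are recognized; Setext-style headers (text underlined with === or ---) are not. For complex markdown with extensions, consider pre-processing with mistune or markdown-it-py.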

MarkdownChunker( max_chunk_size: int = 1500, combine_headers: bool = True, min_header_level: int = 1, logger=None)
def __init__(
    self,
    max_chunk_size: int = 1500,
    combine_headers: bool = True,
    min_header_level: int = 1,
    logger=None
):
    """
    Initialize the markdown chunker.

    Args:
        max_chunk_size: Maximum characters per chunk (default: 1500). This is a
            soft limit: chunk() only starts a new chunk at a section whose header
            level is <= min_header_level, so chunks may exceed it.
        combine_headers: Whether to combine small sections under headers
            (default: True). NOTE: stored, but not currently consulted by chunk().
        min_header_level: Minimum header level to split at (1-6, default: 1)
        logger: Optional BasicLogger instance

    Raises:
        ValueError: If max_chunk_size is not positive or min_header_level is
            outside 1-6.
    """
    super().__init__(logger)

    if max_chunk_size <= 0:
        raise ValueError("max_chunk_size must be greater than 0")
    if not (1 <= min_header_level <= 6):
        raise ValueError("min_header_level must be between 1 and 6")

    self.max_chunk_size = max_chunk_size
    self.combine_headers = combine_headers
    self.min_header_level = min_header_level

    # Regex for ATX-style markdown headers ("# Title" ... "###### Title").
    # Setext-style headers (underlined with === / ---) are NOT matched.
    self.header_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)

    if self.logger:
        self.logger.info(
            "Initialized MarkdownChunker",
            max_chunk_size=max_chunk_size,
            combine_headers=combine_headers,
            min_header_level=min_header_level
        )

Initialize the markdown chunker.

Args: max_chunk_size: Maximum characters per chunk (default: 1500). combine_headers: Whether to combine small sections under headers (default: True); currently stored but not used by chunk(). min_header_level: Minimum header level to split at (1-6, default: 1). logger: Optional BasicLogger instance.

max_chunk_size
combine_headers
min_header_level
header_pattern
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 64    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 65        """
 66        Split markdown text into chunks respecting header boundaries.
 67        
 68        Args:
 69            text: The markdown text to chunk
 70            metadata: Optional metadata to attach to each chunk
 71            
 72        Returns:
 73            List of Chunk objects
 74            
 75        Raises:
 76            ValueError: If text is invalid
 77        """
 78        self.validate_text(text)
 79        
 80        if metadata is None:
 81            metadata = {}
 82        
 83        if self.logger:
 84            self.logger.debug(f"Chunking markdown text of length {len(text)} characters")
 85        
 86        # Find all headers and their positions
 87        sections = self._parse_sections(text)
 88        
 89        if self.logger:
 90            self.logger.debug(f"Found {len(sections)} markdown sections")
 91        
 92        # Create chunks from sections
 93        chunks = []
 94        current_chunk_parts = []
 95        current_size = 0
 96        chunk_index = 0
 97        current_headers = []
 98        
 99        for level, header_text, content, start, end in sections:
100            section_text = content
101            section_size = len(section_text)
102            
103            # Check if we should start a new chunk
104            if (current_size + section_size > self.max_chunk_size 
105                and current_chunk_parts 
106                and level <= self.min_header_level):
107                
108                # Create chunk from accumulated sections
109                chunk_text = "\n\n".join(current_chunk_parts)
110                
111                chunk = Chunk(
112                    text=chunk_text,
113                    metadata=metadata.copy(),
114                    start_pos=start - current_size,
115                    end_pos=start,
116                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
117                )
118                chunk.metadata["headers"] = current_headers.copy()
119                chunk.metadata["chunking_strategy"] = "markdown"
120                
121                chunks.append(chunk)
122                chunk_index += 1
123                
124                # Reset for new chunk
125                current_chunk_parts = []
126                current_size = 0
127                current_headers = []
128            
129            # Add section to current chunk
130            current_chunk_parts.append(section_text)
131            current_size += section_size + 2  # +2 for \n\n
132            
133            # Track header hierarchy
134            if header_text:
135                current_headers.append({
136                    "level": level,
137                    "text": header_text
138                })
139        
140        # Create final chunk
141        if current_chunk_parts:
142            chunk_text = "\n\n".join(current_chunk_parts)
143            
144            chunk = Chunk(
145                text=chunk_text,
146                metadata=metadata.copy(),
147                start_pos=len(text) - current_size,
148                end_pos=len(text),
149                chunk_id=self._generate_chunk_id(chunk_index, metadata)
150            )
151            chunk.metadata["headers"] = current_headers
152            chunk.metadata["chunking_strategy"] = "markdown"
153            
154            chunks.append(chunk)
155        
156        if self.logger:
157            self.logger.info(
158                f"Created {len(chunks)} markdown chunks",
159                total_sections=len(sections),
160                avg_sections_per_chunk=len(sections) / len(chunks) if chunks else 0
161            )
162        
163        return chunks

Split markdown text into chunks respecting header boundaries.

Args: text: The markdown text to chunk; metadata: Optional metadata to attach to each chunk

Returns: List of Chunk objects

Raises: ValueError: If text is invalid

class CodeChunker(gmf_forge_ai_data.BaseChunker):
 15class CodeChunker(BaseChunker):
 16    """
 17    Chunks code while respecting function and class boundaries.
 18    
 19    This chunker identifies code structures (functions, classes, methods)
 20    and uses them as natural boundaries for chunking, preserving code context.
 21    Supports Python, JavaScript, TypeScript, Java, C#, and similar languages.
 22    """
 23    
 24    def __init__(
 25        self,
 26        max_chunk_size: int = 2000,
 27        language: str = "python",
 28        include_imports: bool = True,
 29        logger=None
 30    ):
 31        """
 32        Initialize the code chunker.
 33        
 34        Args:
 35            max_chunk_size: Maximum characters per chunk (default: 2000)
 36            language: Programming language ("python", "javascript", "java", etc.)
 37            include_imports: Whether to include imports/using statements in chunks
 38            logger: Optional BasicLogger instance
 39        """
 40        super().__init__(logger)
 41        
 42        if max_chunk_size <= 0:
 43            raise ValueError("max_chunk_size must be greater than 0")
 44        
 45        self.max_chunk_size = max_chunk_size
 46        self.language = language.lower()
 47        self.include_imports = include_imports
 48        
 49        # Define patterns for different languages
 50        self._setup_patterns()
 51        
 52        if self.logger:
 53            self.logger.info(
 54                "Initialized CodeChunker",
 55                max_chunk_size=max_chunk_size,
 56                language=language
 57            )
 58    
 59    def _setup_patterns(self):
 60        """Setup regex patterns based on language."""
 61        if self.language == "python":
 62            self.function_pattern = re.compile(
 63                r'^(async\s+)?def\s+\w+\s*\([^)]*\)\s*(->\s*[^:]+)?:',
 64                re.MULTILINE
 65            )
 66            self.class_pattern = re.compile(
 67                r'^class\s+\w+(\([^)]*\))?:\s*$',
 68                re.MULTILINE
 69            )
 70            self.import_pattern = re.compile(
 71                r'^(import\s+\S+|from\s+\S+\s+import\s+.+)$',
 72                re.MULTILINE
 73            )
 74        elif self.language in ["javascript", "typescript", "java", "csharp", "c++"]:
 75            self.function_pattern = re.compile(
 76                r'(public|private|protected|static|async)?\s*(function|void|int|string|bool|var|let|const)?\s+\w+\s*\([^)]*\)\s*\{',
 77                re.MULTILINE
 78            )
 79            self.class_pattern = re.compile(
 80                r'(export\s+)?(public\s+)?class\s+\w+(\s+extends\s+\w+)?(\s+implements\s+[\w,\s]+)?\s*\{',
 81                re.MULTILINE
 82            )
 83            self.import_pattern = re.compile(
 84                r'^(import\s+.*from\s+["\'].*["\'];?|using\s+\S+;|#include\s+<.*>)$',
 85                re.MULTILINE
 86            )
 87        else:
 88            # Generic patterns for unknown languages
 89            self.function_pattern = re.compile(r'^\s*(def|function|func|fun)\s+\w+', re.MULTILINE)
 90            self.class_pattern = re.compile(r'^\s*class\s+\w+', re.MULTILINE)
 91            self.import_pattern = re.compile(r'^(import|using|include)\s+', re.MULTILINE)
 92    
 93    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 94        """
 95        Split code into chunks respecting function and class boundaries.
 96        
 97        Args:
 98            text: The code text to chunk
 99            metadata: Optional metadata to attach to each chunk
100            
101        Returns:
102            List of Chunk objects
103            
104        Raises:
105            ValueError: If text is invalid
106        """
107        self.validate_text(text)
108        
109        if metadata is None:
110            metadata = {}
111        
112        if self.logger:
113            self.logger.debug(f"Chunking code of length {len(text)} characters")
114        
115        # Extract imports if needed
116        imports_text = ""
117        if self.include_imports:
118            imports_text = self._extract_imports(text)
119        
120        # Find all code blocks (functions and classes)
121        code_blocks = self._parse_code_blocks(text)
122        
123        if self.logger:
124            self.logger.debug(f"Found {len(code_blocks)} code blocks")
125        
126        # Create chunks from code blocks
127        chunks = []
128        current_chunk_parts = []
129        current_size = len(imports_text)
130        chunk_index = 0
131        
132        for block_type, block_name, block_content, start, end in code_blocks:
133            block_size = len(block_content)
134            
135            # Check if we should start a new chunk
136            if current_size + block_size > self.max_chunk_size and current_chunk_parts:
137                # Create chunk from accumulated blocks
138                chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)
139                
140                chunk = Chunk(
141                    text=chunk_text.strip(),
142                    metadata=metadata.copy(),
143                    start_pos=start - current_size,
144                    end_pos=start,
145                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
146                )
147                chunk.metadata["code_blocks"] = len(current_chunk_parts)
148                chunk.metadata["chunking_strategy"] = "code"
149                chunk.metadata["language"] = self.language
150                
151                chunks.append(chunk)
152                chunk_index += 1
153                
154                # Reset for new chunk
155                current_chunk_parts = []
156                current_size = len(imports_text)
157            
158            # Add block to current chunk
159            current_chunk_parts.append(block_content)
160            current_size += block_size + 2  # +2 for \n\n
161        
162        # Create final chunk
163        if current_chunk_parts:
164            chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)
165            
166            chunk = Chunk(
167                text=chunk_text.strip(),
168                metadata=metadata.copy(),
169                start_pos=len(text) - current_size,
170                end_pos=len(text),
171                chunk_id=self._generate_chunk_id(chunk_index, metadata)
172            )
173            chunk.metadata["code_blocks"] = len(current_chunk_parts)
174            chunk.metadata["chunking_strategy"] = "code"
175            chunk.metadata["language"] = self.language
176            
177            chunks.append(chunk)
178        
179        if self.logger:
180            self.logger.info(
181                f"Created {len(chunks)} code chunks",
182                total_blocks=len(code_blocks),
183                avg_blocks_per_chunk=len(code_blocks) / len(chunks) if chunks else 0
184            )
185        
186        return chunks
187    
188    def _extract_imports(self, text: str) -> str:
189        """Extract import/using statements from the code."""
190        imports = []
191        for match in self.import_pattern.finditer(text):
192            imports.append(match.group().strip())
193       
194        return "\n".join(imports) if imports else ""
195    
196    def _parse_code_blocks(self, text: str) -> List[Tuple[str, str, str, int, int]]:
197        """
198        Parse code into blocks (functions, classes, etc.).
199        
200        Args:
201            text: Code text to parse
202            
203        Returns:
204            List of tuples (block_type, block_name, content, start_pos, end_pos)
205        """
206        blocks = []
207        lines = text.split('\n')
208        
209        # Find all function and class definitions
210        all_matches = []
211        
212        for match in self.function_pattern.finditer(text):
213            all_matches.append(('function', match.start(), match.group()))
214        
215        for match in self.class_pattern.finditer(text):
216            all_matches.append(('class', match.start(), match.group()))
217        
218        # Sort by position
219        all_matches.sort(key=lambda x: x[1])
220        
221        # Extract code blocks with their content
222        for i, (block_type, start, match_text) in enumerate(all_matches):
223            # Extract block name
224            block_name = self._extract_name(match_text)
225            
226            # Find end of block (next block start or end of file)
227            if i + 1 < len(all_matches):
228                end = all_matches[i + 1][1]
229            else:
230                end = len(text)
231            
232            # Extract block content
233            block_content = text[start:end].strip()
234            
235            blocks.append((block_type, block_name, block_content, start, end))
236        
237        # If no blocks found, treat entire text as one block
238        if not blocks:
239            blocks = [('code', 'main', text, 0, len(text))]
240        
241        return blocks
242    
243    def _extract_name(self, definition: str) -> str:
244        """Extract function or class name from definition."""
245        # Try to find the name after 'def', 'function', 'class', etc.
246        name_match = re.search(r'\b(def|function|class|func|fun)\s+(\w+)', definition)
247        if name_match:
248            return name_match.group(2)
249        
250        # Fallback: try to find any word after space
251        words = definition.split()
252        for word in words:
253            if re.match(r'^\w+$', word) and word not in ['def', 'function', 'class', 'public', 'private', 'static', 'async']:
254                return word
255        
256        return 'unknown'

Chunks code while respecting function and class boundaries.

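This chunker identifies code structures (functions, classes, methods) and uses them as natural boundaries for chunking, preserving code context. Supports Python, JavaScript, TypeScript, Java, C#, and similar languages. The recognized `language` values are "python", "javascript", "typescript", "java", "csharp", and "c++". Any other value falls back to generic `def`/`function`/`func`/`fun` and `class` patterns.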

CodeChunker( max_chunk_size: int = 2000, language: str = 'python', include_imports: bool = True, logger=None)
24    def __init__(
25        self,
26        max_chunk_size: int = 2000,
27        language: str = "python",
28        include_imports: bool = True,
29        logger=None
30    ):
31        """
32        Initialize the code chunker.
33        
34        Args:
35            max_chunk_size: Maximum characters per chunk (default: 2000)
36            language: Programming language ("python", "javascript", "java", etc.)
37            include_imports: Whether to include imports/using statements in chunks
38            logger: Optional BasicLogger instance
39        """
40        super().__init__(logger)
41        
42        if max_chunk_size <= 0:
43            raise ValueError("max_chunk_size must be greater than 0")
44        
45        self.max_chunk_size = max_chunk_size
46        self.language = language.lower()
47        self.include_imports = include_imports
48        
49        # Define patterns for different languages
50        self._setup_patterns()
51        
52        if self.logger:
53            self.logger.info(
54                "Initialized CodeChunker",
55                max_chunk_size=max_chunk_size,
56                language=language
57            )

Initialize the code chunker.

Args:
  max_chunk_size: Maximum characters per chunk (default: 2000).
  language: Programming language ("python", "javascript", "java", etc.).
  include_imports: Whether to include imports/using statements in chunks.
  logger: Optional BasicLogger instance.

max_chunk_size
language
include_imports
def chunk( self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 93    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
 94        """
 95        Split code into chunks respecting function and class boundaries.
 96        
 97        Args:
 98            text: The code text to chunk
 99            metadata: Optional metadata to attach to each chunk
100            
101        Returns:
102            List of Chunk objects
103            
104        Raises:
105            ValueError: If text is invalid
106        """
107        self.validate_text(text)
108        
109        if metadata is None:
110            metadata = {}
111        
112        if self.logger:
113            self.logger.debug(f"Chunking code of length {len(text)} characters")
114        
115        # Extract imports if needed
116        imports_text = ""
117        if self.include_imports:
118            imports_text = self._extract_imports(text)
119        
120        # Find all code blocks (functions and classes)
121        code_blocks = self._parse_code_blocks(text)
122        
123        if self.logger:
124            self.logger.debug(f"Found {len(code_blocks)} code blocks")
125        
126        # Create chunks from code blocks
127        chunks = []
128        current_chunk_parts = []
129        current_size = len(imports_text)
130        chunk_index = 0
131        
132        for block_type, block_name, block_content, start, end in code_blocks:
133            block_size = len(block_content)
134            
135            # Check if we should start a new chunk
136            if current_size + block_size > self.max_chunk_size and current_chunk_parts:
137                # Create chunk from accumulated blocks
138                chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)
139                
140                chunk = Chunk(
141                    text=chunk_text.strip(),
142                    metadata=metadata.copy(),
143                    start_pos=start - current_size,
144                    end_pos=start,
145                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
146                )
147                chunk.metadata["code_blocks"] = len(current_chunk_parts)
148                chunk.metadata["chunking_strategy"] = "code"
149                chunk.metadata["language"] = self.language
150                
151                chunks.append(chunk)
152                chunk_index += 1
153                
154                # Reset for new chunk
155                current_chunk_parts = []
156                current_size = len(imports_text)
157            
158            # Add block to current chunk
159            current_chunk_parts.append(block_content)
160            current_size += block_size + 2  # +2 for \n\n
161        
162        # Create final chunk
163        if current_chunk_parts:
164            chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)
165            
166            chunk = Chunk(
167                text=chunk_text.strip(),
168                metadata=metadata.copy(),
169                start_pos=len(text) - current_size,
170                end_pos=len(text),
171                chunk_id=self._generate_chunk_id(chunk_index, metadata)
172            )
173            chunk.metadata["code_blocks"] = len(current_chunk_parts)
174            chunk.metadata["chunking_strategy"] = "code"
175            chunk.metadata["language"] = self.language
176            
177            chunks.append(chunk)
178        
179        if self.logger:
180            self.logger.info(
181                f"Created {len(chunks)} code chunks",
182                total_blocks=len(code_blocks),
183                avg_blocks_per_chunk=len(code_blocks) / len(chunks) if chunks else 0
184            )
185        
186        return chunks

Split code into chunks respecting function and class boundaries.

Args:
  text: The code text to chunk.
  metadata: Optional metadata to attach to each chunk.

Returns: List of Chunk objects

Raises: ValueError: If text is invalid

class BaseVectorStore(abc.ABC):
181class BaseVectorStore(ABC):
182    """
183    Abstract base class for vector store implementations.
184    
185    Vector stores provide storage and retrieval of document embeddings,
186    supporting various search strategies (vector similarity, keyword, hybrid).
187    
188    Implementations must provide:
189    - Document storage with embeddings
190    - Vector similarity search
191    - Optional keyword/hybrid search
192    - Document lifecycle management (add, update, delete)
193    """
194    
195    @abstractmethod
196    def add_documents(
197        self,
198        documents: List[Document],
199        generate_embeddings: bool = True
200    ) -> List[str]:
201        """
202        Add documents to the vector store.
203        
204        Args:
205            documents: List of Document objects to add
206            generate_embeddings: If True and documents lack embeddings,
207                                generate them automatically
208        
209        Returns:
210            List of document IDs that were added
211            
212        Raises:
213            ValueError: If documents are invalid or embeddings cannot be generated
214        """
215        pass
216    
217    @abstractmethod
218    def search(
219        self,
220        query: Optional[str] = None,
221        query_embedding: Optional[List[float]] = None,
222        top_k: int = 5,
223        filters: Optional[Dict[str, Any]] = None,
224        search_type: str = "vector"
225    ) -> List[SearchResult]:
226        """
227        Search the vector store.
228        
229        Args:
230            query: Text query (required for keyword/hybrid search)
231            query_embedding: Vector embedding of the query (required for vector search)
232            top_k: Number of results to return
233            filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
234            search_type: Type of search - "vector", "keyword", or "hybrid"
235        
236        Returns:
237            List of SearchResult objects, ordered by relevance
238            
239        Raises:
240            ValueError: If required parameters are missing for the search type
241        """
242        pass
243    
244    @abstractmethod
245    def delete_documents(self, document_ids: List[str]) -> int:
246        """
247        Delete documents from the vector store.
248        
249        Args:
250            document_ids: List of document IDs to delete
251        
252        Returns:
253            Number of documents successfully deleted
254        """
255        pass
256    
257    @abstractmethod
258    def update_document(
259        self,
260        document_id: str,
261        content: Optional[str] = None,
262        embedding: Optional[List[float]] = None,
263        metadata: Optional[Dict[str, Any]] = None
264    ) -> bool:
265        """
266        Update an existing document.
267        
268        Args:
269            document_id: ID of the document to update
270            content: New content (if provided)
271            embedding: New embedding (if provided)
272            metadata: New metadata (merged with existing)
273        
274        Returns:
275            True if update was successful, False if document not found
276        """
277        pass
278    
279    @abstractmethod
280    def get_document(self, document_id: str) -> Optional[Document]:
281        """
282        Retrieve a document by ID.
283        
284        Args:
285            document_id: ID of the document to retrieve
286        
287        Returns:
288            Document if found, None otherwise
289        """
290        pass
291    
292    @abstractmethod
293    def count(self) -> int:
294        """
295        Get the total number of documents in the vector store.
296        
297        Returns:
298            Total document count
299        """
300        pass
301    
302    @abstractmethod
303    def clear(self) -> None:
304        """
305        Remove all documents from the vector store.
306        
307        Warning: This operation is irreversible.
308        """
309        pass

Abstract base class for vector store implementations.

Vector stores provide storage and retrieval of document embeddings, supporting various search strategies (vector similarity, keyword, hybrid).

Implementations must provide:

  • Document storage with embeddings
  • Vector similarity search
  • Optional keyword/hybrid search
  • Document lifecycle management (add, update, delete)
@abstractmethod
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
195    @abstractmethod
196    def add_documents(
197        self,
198        documents: List[Document],
199        generate_embeddings: bool = True
200    ) -> List[str]:
201        """
202        Add documents to the vector store.
203        
204        Args:
205            documents: List of Document objects to add
206            generate_embeddings: If True and documents lack embeddings,
207                                generate them automatically
208        
209        Returns:
210            List of document IDs that were added
211            
212        Raises:
213            ValueError: If documents are invalid or embeddings cannot be generated
214        """
215        pass

Add documents to the vector store.

Args:
  documents: List of Document objects to add.
  generate_embeddings: If True and documents lack embeddings, generate them automatically.

Returns: List of document IDs that were added

Raises: ValueError: If documents are invalid or embeddings cannot be generated

@abstractmethod
def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
217    @abstractmethod
218    def search(
219        self,
220        query: Optional[str] = None,
221        query_embedding: Optional[List[float]] = None,
222        top_k: int = 5,
223        filters: Optional[Dict[str, Any]] = None,
224        search_type: str = "vector"
225    ) -> List[SearchResult]:
226        """
227        Search the vector store.
228        
229        Args:
230            query: Text query (required for keyword/hybrid search)
231            query_embedding: Vector embedding of the query (required for vector search)
232            top_k: Number of results to return
233            filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
234            search_type: Type of search - "vector", "keyword", or "hybrid"
235        
236        Returns:
237            List of SearchResult objects, ordered by relevance
238            
239        Raises:
240            ValueError: If required parameters are missing for the search type
241        """
242        pass

Search the vector store.

Args:
  query: Text query (required for keyword/hybrid search).
  query_embedding: Vector embedding of the query (required for vector search).
  top_k: Number of results to return.
  filters: Metadata filters to apply (e.g., {"source": "docs.pdf"}).
  search_type: Type of search - "vector", "keyword", or "hybrid".

Returns: List of SearchResult objects, ordered by relevance

Raises: ValueError: If required parameters are missing for the search type

@abstractmethod
def delete_documents(self, document_ids: List[str]) -> int:
244    @abstractmethod
245    def delete_documents(self, document_ids: List[str]) -> int:
246        """
247        Delete documents from the vector store.
248        
249        Args:
250            document_ids: List of document IDs to delete
251        
252        Returns:
253            Number of documents successfully deleted
254        """
255        pass

Delete documents from the vector store.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

@abstractmethod
def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
257    @abstractmethod
258    def update_document(
259        self,
260        document_id: str,
261        content: Optional[str] = None,
262        embedding: Optional[List[float]] = None,
263        metadata: Optional[Dict[str, Any]] = None
264    ) -> bool:
265        """
266        Update an existing document.
267        
268        Args:
269            document_id: ID of the document to update
270            content: New content (if provided)
271            embedding: New embedding (if provided)
272            metadata: New metadata (merged with existing)
273        
274        Returns:
275            True if update was successful, False if document not found
276        """
277        pass

Update an existing document.

Args:
  document_id: ID of the document to update.
  content: New content (if provided).
  embedding: New embedding (if provided).
  metadata: New metadata (merged with existing).

Returns: True if update was successful, False if document not found

@abstractmethod
def get_document( self, document_id: str) -> Optional[Document]:
279    @abstractmethod
280    def get_document(self, document_id: str) -> Optional[Document]:
281        """
282        Retrieve a document by ID.
283        
284        Args:
285            document_id: ID of the document to retrieve
286        
287        Returns:
288            Document if found, None otherwise
289        """
290        pass

Retrieve a document by ID.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise

@abstractmethod
def count(self) -> int:
292    @abstractmethod
293    def count(self) -> int:
294        """
295        Get the total number of documents in the vector store.
296        
297        Returns:
298            Total document count
299        """
300        pass

Get the total number of documents in the vector store.

Returns: Total document count

@abstractmethod
def clear(self) -> None:
302    @abstractmethod
303    def clear(self) -> None:
304        """
305        Remove all documents from the vector store.
306        
307        Warning: This operation is irreversible.
308        """
309        pass

Remove all documents from the vector store.

Warning: This operation is irreversible.

@dataclass
class Document:
 19@dataclass
 20class Document:
 21    """
 22    Represents a document or chunk with its content and metadata.
 23    
 24    This class is designed to be extensible - you can subclass it to add
 25    custom fields specific to your application domain.
 26    
 27    Base Attributes:
 28        id: Unique identifier for the document
 29        content: The text content of the document
 30        embedding: Optional vector embedding (list of floats)
 31        timestamp: Optional timestamp for time-weighted retrieval
 32        metadata: Flexible dictionary for custom properties
 33    
 34    Design Philosophy:
 35        The base Document keeps only truly universal fields. Domain-specific
 36        fields (source, page_number, category, etc.) should be added via
 37        inheritance to match your application's data model.
 38    
 39    Example - Using base Document with metadata:
 40        ```python
 41        doc = Document(
 42            id="doc_1",
 43            content="AI is transforming software development...",
 44            timestamp=datetime.now(),
 45            metadata={
 46                "source": "blog.pdf",
 47                "page_number": 5,
 48                "author": "John Doe",
 49                "category": "AI"
 50            }
 51        )
 52        ```
 53    
 54    Example - Custom Document schema via inheritance:
 55        ```python
 56        from dataclasses import dataclass
 57        from gmf_forge_ai_data.vector_stores import Document
 58        
 59        @dataclass
 60        class LegalDocument(Document):
 61            case_number: str
 62            court: str
 63            decision_date: datetime
 64            jurisdiction: str
 65            source: str = ""  # Add source as typed field
 66            page_number: Optional[int] = None
 67            
 68        legal_doc = LegalDocument(
 69            id="case_123",
 70            content="In the matter of...",
 71            case_number="2024-CV-12345",
 72            court="Supreme Court",
 73            decision_date=datetime(2024, 1, 15),
 74            jurisdiction="Federal",
 75            source="court_records.pdf",
 76            page_number=42
 77        )
 78        ```
 79    
 80    Example - E-commerce Product Document:
 81        ```python
 82        @dataclass
 83        class ProductDocument(Document):
 84            sku: str
 85            category: str
 86            price: float
 87            in_stock: bool
 88            brand: str
 89            # No source/page_number - not applicable to products
 90            
 91        product = ProductDocument(
 92            id="prod_456",
 93            content="High-performance laptop with...",
 94            sku="LAP-001",
 95            category="Electronics",
 96            price=1299.99,
 97            in_stock=True,
 98            brand="TechCorp"
 99        )
100        ```
101    """
102    id: str
103    content: str
104    embedding: Optional[List[float]] = None
105    timestamp: Optional[datetime] = None
106    metadata: Dict[str, Any] = field(default_factory=dict)
107    
108    def to_dict(self) -> Dict[str, Any]:
109        """
110        Convert Document to dictionary for serialization.
111        
112        Returns:
113            Dictionary representation of the document
114        """
115        doc_dict = asdict(self)
116        # Convert datetime to ISO string for serialization
117        if self.timestamp:
118            doc_dict['timestamp'] = self.timestamp.isoformat()
119        return doc_dict
120    
121    @classmethod
122    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
123        """
124        Create Document from dictionary.
125        
126        Args:
127            data: Dictionary containing document fields
128        
129        Returns:
130            Document instance (or subclass instance)
131        """
132        # Handle datetime deserialization
133        if 'timestamp' in data and isinstance(data['timestamp'], str):
134            data['timestamp'] = datetime.fromisoformat(data['timestamp'])
135        
136        # Filter to only fields that exist in the dataclass
137        import inspect
138        valid_fields = set(inspect.signature(cls).parameters.keys())
139        filtered_data = {k: v for k, v in data.items() if k in valid_fields}
140        
141        return cls(**filtered_data)
142    
143    def update_metadata(self, **kwargs) -> None:
144        """
145        Update metadata fields.
146        
147        Args:
148            **kwargs: Key-value pairs to add/update in metadata
149        """
150        self.metadata.update(kwargs)
151    
152    def get_metadata(self, key: str, default: Any = None) -> Any:
153        """
154        Get a metadata value.
155        
156        Args:
157            key: Metadata key to retrieve
158            default: Default value if key not found
159        
160        Returns:
161            Metadata value or default
162        """
163        return self.metadata.get(key, default)

Represents a document or chunk with its content and metadata.

This class is designed to be extensible - you can subclass it to add custom fields specific to your application domain.

Base Attributes:
  id: Unique identifier for the document.
  content: The text content of the document.
  embedding: Optional vector embedding (list of floats).
  timestamp: Optional timestamp for time-weighted retrieval.
  metadata: Flexible dictionary for custom properties.

Design Philosophy: The base Document keeps only truly universal fields. Domain-specific fields (source, page_number, category, etc.) should be added via inheritance to match your application's data model.

Example - Using base Document with metadata:

doc = Document(
    id="doc_1",
    content="AI is transforming software development...",
    timestamp=datetime.now(),
    metadata={
        "source": "blog.pdf",
        "page_number": 5,
        "author": "John Doe",
        "category": "AI"
    }
)

Example - Custom Document schema via inheritance:

from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from gmf_forge_ai_data.vector_stores import Document

@dataclass
class LegalDocument(Document):
    # Document's inherited fields (embedding, timestamp, metadata) have defaults,
    # so every field added here needs a default too; otherwise @dataclass raises
    # a TypeError (non-default argument follows default argument).
    # On Python 3.10+, @dataclass(kw_only=True) is an alternative to defaults.
    case_number: str = ""
    court: str = ""
    decision_date: Optional[datetime] = None
    jurisdiction: str = ""
    source: str = ""  # Add source as typed field
    page_number: Optional[int] = None

legal_doc = LegalDocument(
    id="case_123",
    content="In the matter of...",
    case_number="2024-CV-12345",
    court="Supreme Court",
    decision_date=datetime(2024, 1, 15),
    jurisdiction="Federal",
    source="court_records.pdf",
    page_number=42
)

Example - E-commerce Product Document:

@dataclass
class ProductDocument(Document):
    # Every added field needs a default because the inherited Document fields
    # (embedding, timestamp, metadata) already have defaults; a field without
    # one would make @dataclass raise a TypeError.
    sku: str = ""
    category: str = ""
    price: float = 0.0
    in_stock: bool = False
    brand: str = ""
    # No source/page_number - not applicable to products

product = ProductDocument(
    id="prod_456",
    content="High-performance laptop with...",
    sku="LAP-001",
    category="Electronics",
    price=1299.99,
    in_stock=True,
    brand="TechCorp"
)
Document( id: str, content: str, embedding: Optional[List[float]] = None, timestamp: Optional[datetime.datetime] = None, metadata: Dict[str, Any] = <factory>)
id: str
content: str
embedding: Optional[List[float]] = None
timestamp: Optional[datetime.datetime] = None
metadata: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this document into a plain dictionary.

    Fields are copied recursively via ``dataclasses.asdict``; a set timestamp
    is rendered as an ISO-8601 string so the result is JSON-friendly.

    Returns:
        Dictionary with one entry per dataclass field
    """
    serialized = asdict(self)
    stamp = self.timestamp
    if stamp:
        serialized['timestamp'] = stamp.isoformat()
    return serialized

Convert Document to dictionary for serialization.

Returns: Dictionary representation of the document

@classmethod
def from_dict(cls: Type[~T], data: Dict[str, Any]) -> ~T:
@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
    """
    Create a Document (or subclass instance) from a dictionary.

    An ISO-8601 string under ``'timestamp'`` is parsed into a datetime, and keys
    that are not constructor parameters of ``cls`` are ignored. The input
    dictionary itself is never modified.

    Args:
        data: Dictionary containing document fields

    Returns:
        Document instance (or subclass instance)
    """
    import inspect

    # Work on a shallow copy: converting the timestamp in place would silently
    # rewrite the caller's dictionary (e.g. a parsed JSON record).
    payload = dict(data)
    timestamp = payload.get('timestamp')
    if isinstance(timestamp, str):
        payload['timestamp'] = datetime.fromisoformat(timestamp)

    # Keep only keys accepted by the dataclass constructor (works for subclasses)
    valid_fields = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in payload.items() if k in valid_fields})

Create Document from dictionary.

Args: data: Dictionary containing document fields

Returns: Document instance (or subclass instance)

def update_metadata(self, **kwargs) -> None:
def update_metadata(self, **kwargs) -> None:
    """
    Merge keyword arguments into this document's metadata.

    Keys already present are overwritten; keys not mentioned are left alone.

    Args:
        **kwargs: Metadata entries to add or overwrite
    """
    for name, value in kwargs.items():
        self.metadata[name] = value

Update metadata fields.

Args: **kwargs: Key-value pairs to add/update in metadata

def get_metadata(self, key: str, default: Any = None) -> Any:
def get_metadata(self, key: str, default: Any = None) -> Any:
    """
    Look up a single metadata entry.

    Args:
        key: Metadata key to read
        default: Value returned when ``key`` is absent from the metadata

    Returns:
        The stored value (even if it is None), or ``default`` when the key is missing
    """
    metadata = self.metadata
    return metadata.get(key, default)

Get a metadata value.

Args:

  • key: Metadata key to retrieve
  • default: Default value if key not found

Returns: Metadata value or default

@dataclass
class SearchResult:
@dataclass
class SearchResult:
    """
    One hit returned by a vector store search.

    Attributes:
        document: The retrieved document
        score: Similarity score (0.0 to 1.0, higher is better)
        rank: Position in the result list (0-indexed)
    """
    document: Document  # the matched document
    score: float  # relevance; higher is better
    rank: int  # 0-based position within the returned list

Represents a search result from a vector store.

Attributes:

  • document: The retrieved document
  • score: Similarity score (0.0 to 1.0, higher is better)
  • rank: Position in the result list (0-indexed)

SearchResult( document: Document, score: float, rank: int)
document: Document
score: float
rank: int
class InMemoryVectorStore(gmf_forge_ai_data.BaseVectorStore):
 17class InMemoryVectorStore(BaseVectorStore):
 18    """
 19    In-memory vector store using numpy for vector operations.
 20    
 21    Features:
 22    - Fast cosine similarity search using numpy
 23    - Keyword search using simple text matching
 24    - Hybrid search combining vector and keyword scores
 25    - Metadata filtering
 26    - No external dependencies (Azure, etc.)
 27    
 28    Ideal for:
 29    - Unit testing
 30    - Development and prototyping
 31    - Small datasets (< 10,000 documents)
 32    - CI/CD pipelines
 33    
 34    Note: All data is lost when the process terminates.
 35    """
 36    
 37    def __init__(self, embedding_dimension: int = 1536):
 38        """
 39        Initialize the in-memory vector store.
 40        
 41        Args:
 42            embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
 43        """
 44        self.embedding_dimension = embedding_dimension
 45        self._documents: Dict[str, Document] = {}
 46        self._embeddings: Optional[np.ndarray] = None
 47        self._doc_ids: List[str] = []
 48    
 49    def add_documents(
 50        self,
 51        documents: List[Document],
 52        generate_embeddings: bool = True
 53    ) -> List[str]:
 54        """
 55        Add documents to the in-memory store.
 56        
 57        Args:
 58            documents: List of Document objects to add
 59            generate_embeddings: If True, validates embeddings are present
 60                                (actual generation should be done externally)
 61        
 62        Returns:
 63            List of document IDs that were added
 64        """
 65        if not documents:
 66            return []
 67        
 68        added_ids = []
 69        
 70        for doc in documents:
 71            # Validate document
 72            if not doc.id or not doc.content:
 73                raise ValueError(f"Document must have id and content: {doc}")
 74            
 75            if generate_embeddings and doc.embedding is None:
 76                raise ValueError(
 77                    f"Document {doc.id} lacks embedding. "
 78                    "Generate embeddings externally before adding to store."
 79                )
 80            
 81            if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
 82                raise ValueError(
 83                    f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
 84                    f"expected {self.embedding_dimension}"
 85                )
 86            
 87            # Store document
 88            self._documents[doc.id] = deepcopy(doc)
 89            added_ids.append(doc.id)
 90        
 91        # Rebuild embedding matrix
 92        self._rebuild_embeddings()
 93        
 94        return added_ids
 95    
 96    def search(
 97        self,
 98        query: Optional[str] = None,
 99        query_embedding: Optional[List[float]] = None,
100        top_k: int = 5,
101        filters: Optional[Dict[str, Any]] = None,
102        search_type: str = "vector"
103    ) -> List[SearchResult]:
104        """
105        Search the in-memory vector store.
106        
107        Args:
108            query: Text query (required for keyword/hybrid search)
109            query_embedding: Vector embedding of the query (required for vector/hybrid search)
110            top_k: Number of results to return
111            filters: Metadata filters (e.g., {"source": "doc.pdf"})
112            search_type: "vector", "keyword", or "hybrid"
113        
114        Returns:
115            List of SearchResult objects, ordered by relevance
116        """
117        if search_type not in ["vector", "keyword", "hybrid"]:
118            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")
119        
120        if search_type in ["vector", "hybrid"] and query_embedding is None:
121            raise ValueError(f"{search_type} search requires query_embedding")
122        
123        if search_type in ["keyword", "hybrid"] and query is None:
124            raise ValueError(f"{search_type} search requires query text")
125        
126        if not self._documents:
127            return []
128        
129        # Apply metadata filters
130        filtered_docs = self._apply_filters(filters)
131        if not filtered_docs:
132            return []
133        
134        # Calculate scores based on search type
135        if search_type == "vector":
136            scores = self._vector_search(query_embedding, filtered_docs)
137        elif search_type == "keyword":
138            scores = self._keyword_search(query, filtered_docs)
139        else:  # hybrid
140            vector_scores = self._vector_search(query_embedding, filtered_docs)
141            keyword_scores = self._keyword_search(query, filtered_docs)
142            # Combine scores (50/50 weighting)
143            scores = {
144                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
145                for doc_id in filtered_docs
146            }
147        
148        # Sort by score and take top_k
149        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
150        
151        # Build SearchResult objects
152        results = []
153        for rank, (doc_id, score) in enumerate(sorted_results):
154            results.append(SearchResult(
155                document=deepcopy(self._documents[doc_id]),
156                score=float(score),
157                rank=rank
158            ))
159        
160        return results
161    
162    def delete_documents(self, document_ids: List[str]) -> int:
163        """
164        Delete documents from the store.
165        
166        Args:
167            document_ids: List of document IDs to delete
168        
169        Returns:
170            Number of documents successfully deleted
171        """
172        deleted_count = 0
173        for doc_id in document_ids:
174            if doc_id in self._documents:
175                del self._documents[doc_id]
176                deleted_count += 1
177        
178        # Rebuild embedding matrix
179        self._rebuild_embeddings()
180        
181        return deleted_count
182    
183    def update_document(
184        self,
185        document_id: str,
186        content: Optional[str] = None,
187        embedding: Optional[List[float]] = None,
188        metadata: Optional[Dict[str, Any]] = None
189    ) -> bool:
190        """
191        Update an existing document.
192        
193        Args:
194            document_id: ID of the document to update
195            content: New content (if provided)
196            embedding: New embedding (if provided)
197            metadata: New metadata (merged with existing)
198        
199        Returns:
200            True if update was successful, False if document not found
201        """
202        if document_id not in self._documents:
203            return False
204        
205        doc = self._documents[document_id]
206        
207        if content is not None:
208            doc.content = content
209        
210        if embedding is not None:
211            if len(embedding) != self.embedding_dimension:
212                raise ValueError(
213                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
214                )
215            doc.embedding = embedding
216            self._rebuild_embeddings()
217        
218        if metadata is not None:
219            doc.metadata.update(metadata)
220        
221        return True
222    
223    def get_document(self, document_id: str) -> Optional[Document]:
224        """
225        Retrieve a document by ID.
226        
227        Args:
228            document_id: ID of the document to retrieve
229        
230        Returns:
231            Document if found, None otherwise
232        """
233        doc = self._documents.get(document_id)
234        return deepcopy(doc) if doc else None
235    
236    def count(self) -> int:
237        """Get the total number of documents in the store."""
238        return len(self._documents)
239    
240    def clear(self) -> None:
241        """Remove all documents from the store."""
242        self._documents.clear()
243        self._embeddings = None
244        self._doc_ids = []
245    
246    # Private helper methods
247    
248    def _rebuild_embeddings(self) -> None:
249        """Rebuild the embedding matrix from stored documents."""
250        docs_with_embeddings = [
251            (doc_id, doc) for doc_id, doc in self._documents.items()
252            if doc.embedding is not None
253        ]
254        
255        if not docs_with_embeddings:
256            self._embeddings = None
257            self._doc_ids = []
258            return
259        
260        self._doc_ids = [doc_id for doc_id, _ in docs_with_embeddings]
261        embeddings_list = [doc.embedding for _, doc in docs_with_embeddings]
262        self._embeddings = np.array(embeddings_list, dtype=np.float32)
263    
264    def _apply_filters(self, filters: Optional[Dict[str, Any]]) -> List[str]:
265        """
266        Apply metadata filters and return list of matching document IDs.
267        
268        Supports filtering on: source, page_number, timestamp, and metadata fields
269        Special operators: use tuples for range queries
270        - (">=", value) for greater than or equal
271        - ("<=", value) for less than or equal  
272        - ("range", min_val, max_val) for range queries
273        """
274        if filters is None:
275            return list(self._documents.keys())
276        
277        matching_ids = []
278        for doc_id, doc in self._documents.items():
279            matches = True
280            
281            for key, value in filters.items():
282                # Check document attributes (source, page_number, timestamp)
283                doc_value = getattr(doc, key, None)
284                
285                # If not a document attribute, check metadata
286                if doc_value is None and key in doc.metadata:
287                    doc_value = doc.metadata[key]
288                
289                # Skip if field doesn't exist
290                if doc_value is None:
291                    matches = False
292                    break
293                
294                # Handle tuple operators for range queries
295                if isinstance(value, tuple):
296                    operator = value[0]
297                    if operator == "range" and len(value) == 3:
298                        min_val, max_val = value[1], value[2]
299                        if not (min_val <= doc_value <= max_val):
300                            matches = False
301                            break
302                    elif operator in [">=", "gt"]:
303                        if not (doc_value >= value[1]):
304                            matches = False
305                            break
306                    elif operator in ["<=", "lt"]:
307                        if not (doc_value <= value[1]):
308                            matches = False
309                            break
310                # Simple equality check
311                elif doc_value != value:
312                    matches = False
313                    break
314            
315            if matches:
316                matching_ids.append(doc_id)
317        
318        return matching_ids
319    
320    def _vector_search(
321        self,
322        query_embedding: List[float],
323        candidate_ids: List[str]
324    ) -> Dict[str, float]:
325        """
326        Perform vector similarity search using cosine similarity.
327        
328        Returns:
329            Dict mapping document IDs to similarity scores (0.0 to 1.0)
330        """
331        if self._embeddings is None or len(self._doc_ids) == 0:
332            return {}
333        
334        query_vector = np.array(query_embedding, dtype=np.float32)
335        
336        # Filter embeddings to only candidate documents
337        candidate_indices = [
338            i for i, doc_id in enumerate(self._doc_ids)
339            if doc_id in candidate_ids
340        ]
341        
342        if not candidate_indices:
343            return {}
344        
345        candidate_embeddings = self._embeddings[candidate_indices]
346        candidate_doc_ids = [self._doc_ids[i] for i in candidate_indices]
347        
348        # Compute cosine similarity
349        query_norm = np.linalg.norm(query_vector)
350        if query_norm == 0:
351            return {doc_id: 0.0 for doc_id in candidate_doc_ids}
352        
353        doc_norms = np.linalg.norm(candidate_embeddings, axis=1)
354        dot_products = candidate_embeddings.dot(query_vector)
355        
356        # Avoid division by zero
357        similarities = np.zeros(len(candidate_doc_ids))
358        valid_mask = doc_norms > 0
359        similarities[valid_mask] = dot_products[valid_mask] / (doc_norms[valid_mask] * query_norm)
360        
361        # Convert from [-1, 1] to [0, 1]
362        similarities = (similarities + 1) / 2
363        
364        return {doc_id: float(sim) for doc_id, sim in zip(candidate_doc_ids, similarities)}
365    
366    def _keyword_search(
367        self,
368        query: str,
369        candidate_ids: List[str]
370    ) -> Dict[str, float]:
371        """
372        Perform simple keyword search using term matching.
373        
374        Returns:
375            Dict mapping document IDs to relevance scores (0.0 to 1.0)
376        """
377        query_lower = query.lower()
378        query_terms = set(query_lower.split())
379        
380        scores = {}
381        for doc_id in candidate_ids:
382            doc = self._documents[doc_id]
383            content_lower = doc.content.lower()
384            content_terms = set(content_lower.split())
385            
386            # Calculate Jaccard similarity as relevance score
387            if not query_terms:
388                scores[doc_id] = 0.0
389            else:
390                intersection = len(query_terms & content_terms)
391                union = len(query_terms | content_terms)
392                scores[doc_id] = intersection / union if union > 0 else 0.0
393        
394        return scores

In-memory vector store using numpy for vector operations.

Features:

  • Fast cosine similarity search using numpy
  • Keyword search using simple text matching
  • Hybrid search combining vector and keyword scores
  • Metadata filtering
  • No external dependencies (Azure, etc.)

Ideal for:

  • Unit testing
  • Development and prototyping
  • Small datasets (< 10,000 documents)
  • CI/CD pipelines

Note: All data is lost when the process terminates.

InMemoryVectorStore(embedding_dimension: int = 1536)
def __init__(self, embedding_dimension: int = 1536):
    """
    Initialize the in-memory vector store.

    Args:
        embedding_dimension: Expected dimension of embeddings
            (default 1536 for text-embedding-ada-002)

    Raises:
        ValueError: If embedding_dimension is not positive
    """
    # A non-positive dimension would make every embedded document fail validation
    if embedding_dimension <= 0:
        raise ValueError(
            f"embedding_dimension must be positive, got {embedding_dimension}"
        )
    self.embedding_dimension = embedding_dimension
    self._documents: Dict[str, Document] = {}
    # Embedding matrix aligned with self._doc_ids; None until a document with
    # an embedding is stored
    self._embeddings: Optional[np.ndarray] = None
    self._doc_ids: List[str] = []

Initialize the in-memory vector store.

Args: embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)

embedding_dimension
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True
) -> List[str]:
    """
    Add documents to the in-memory store.

    The whole batch is validated before anything is stored, so one invalid
    document leaves the store unchanged. A document whose id is already present
    replaces the stored copy.

    Args:
        documents: List of Document objects to add
        generate_embeddings: If True, every document must already carry an
            embedding (actual generation should be done externally)

    Returns:
        List of document IDs that were added, in input order

    Raises:
        ValueError: If a document lacks an id or content, lacks a required
            embedding, or has an embedding of the wrong dimension
    """
    if not documents:
        return []
    documents = list(documents)  # iterated twice below

    # Validate up front: storing while validating would leave earlier documents
    # in _documents without a rebuilt embedding matrix whenever a later
    # document failed, hiding them from vector search.
    for doc in documents:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")

        if generate_embeddings and doc.embedding is None:
            raise ValueError(
                f"Document {doc.id} lacks embedding. "
                "Generate embeddings externally before adding to store."
            )

        if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
            raise ValueError(
                f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
                f"expected {self.embedding_dimension}"
            )

    added_ids = []
    for doc in documents:
        # Store a private copy so later caller-side mutation can't leak in
        self._documents[doc.id] = deepcopy(doc)
        added_ids.append(doc.id)

    self._rebuild_embeddings()
    return added_ids

Add documents to the in-memory store.

Args:

  • documents: List of Document objects to add
  • generate_embeddings: If True, validates that embeddings are present (actual generation should be done externally)

Returns: List of document IDs that were added

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
 96    def search(
 97        self,
 98        query: Optional[str] = None,
 99        query_embedding: Optional[List[float]] = None,
100        top_k: int = 5,
101        filters: Optional[Dict[str, Any]] = None,
102        search_type: str = "vector"
103    ) -> List[SearchResult]:
104        """
105        Search the in-memory vector store.
106        
107        Args:
108            query: Text query (required for keyword/hybrid search)
109            query_embedding: Vector embedding of the query (required for vector/hybrid search)
110            top_k: Number of results to return
111            filters: Metadata filters (e.g., {"source": "doc.pdf"})
112            search_type: "vector", "keyword", or "hybrid"
113        
114        Returns:
115            List of SearchResult objects, ordered by relevance
116        """
117        if search_type not in ["vector", "keyword", "hybrid"]:
118            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")
119        
120        if search_type in ["vector", "hybrid"] and query_embedding is None:
121            raise ValueError(f"{search_type} search requires query_embedding")
122        
123        if search_type in ["keyword", "hybrid"] and query is None:
124            raise ValueError(f"{search_type} search requires query text")
125        
126        if not self._documents:
127            return []
128        
129        # Apply metadata filters
130        filtered_docs = self._apply_filters(filters)
131        if not filtered_docs:
132            return []
133        
134        # Calculate scores based on search type
135        if search_type == "vector":
136            scores = self._vector_search(query_embedding, filtered_docs)
137        elif search_type == "keyword":
138            scores = self._keyword_search(query, filtered_docs)
139        else:  # hybrid
140            vector_scores = self._vector_search(query_embedding, filtered_docs)
141            keyword_scores = self._keyword_search(query, filtered_docs)
142            # Combine scores (50/50 weighting)
143            scores = {
144                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
145                for doc_id in filtered_docs
146            }
147        
148        # Sort by score and take top_k
149        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
150        
151        # Build SearchResult objects
152        results = []
153        for rank, (doc_id, score) in enumerate(sorted_results):
154            results.append(SearchResult(
155                document=deepcopy(self._documents[doc_id]),
156                score=float(score),
157                rank=rank
158            ))
159        
160        return results

Search the in-memory vector store.

Args:

  • query: Text query (required for keyword/hybrid search)
  • query_embedding: Vector embedding of the query (required for vector/hybrid search)
  • top_k: Number of results to return
  • filters: Metadata filters (e.g., {"source": "doc.pdf"})
  • search_type: "vector", "keyword", or "hybrid"

Returns: List of SearchResult objects, ordered by relevance

def delete_documents(self, document_ids: List[str]) -> int:
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents from the store.

    Args:
        document_ids: List of document IDs to delete; unknown IDs are ignored

    Returns:
        Number of documents successfully deleted
    """
    deleted_count = 0
    for doc_id in document_ids:
        if doc_id in self._documents:
            del self._documents[doc_id]
            deleted_count += 1

    # Rebuilding the embedding matrix is only needed when something was removed
    if deleted_count:
        self._rebuild_embeddings()

    return deleted_count

Delete documents from the store.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Update an existing document.

    Inputs are validated before anything is changed, so a rejected update
    leaves the document untouched.

    Args:
        document_id: ID of the document to update
        content: New content (if provided)
        embedding: New embedding (if provided)
        metadata: New metadata (merged with existing)

    Returns:
        True if update was successful, False if document not found

    Raises:
        ValueError: If embedding has the wrong dimension
    """
    doc = self._documents.get(document_id)
    if doc is None:
        return False

    # Check before mutating: the content must not change if the embedding is rejected
    if embedding is not None and len(embedding) != self.embedding_dimension:
        raise ValueError(
            f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
        )

    if content is not None:
        doc.content = content

    if metadata is not None:
        # Copy so the caller keeps no handle on stored state
        doc.metadata.update(deepcopy(metadata))

    if embedding is not None:
        doc.embedding = deepcopy(embedding)
        self._rebuild_embeddings()

    return True

Update an existing document.

Args:

  • document_id: ID of the document to update
  • content: New content (if provided)
  • embedding: New embedding (if provided)
  • metadata: New metadata (merged with existing)

Returns: True if update was successful, False if document not found

def get_document( self, document_id: str) -> Optional[Document]:
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a document by ID.

    Args:
        document_id: ID of the document to retrieve

    Returns:
        A deep copy of the stored document if found, None otherwise
    """
    doc = self._documents.get(document_id)
    # Compare to None explicitly: a Document subclass defining __len__ or
    # __bool__ could otherwise be reported as missing.
    return deepcopy(doc) if doc is not None else None

Retrieve a document by ID.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise

def count(self) -> int:
def count(self) -> int:
    """Return how many documents are currently stored."""
    stored = self._documents
    return len(stored)

Get the total number of documents in the store.

def clear(self) -> None:
def clear(self) -> None:
    """Drop every stored document and reset the embedding index."""
    self._doc_ids = []
    self._embeddings = None
    # Empty the existing dict in place rather than rebinding it
    self._documents.clear()

Remove all documents from the store.

class AzureAISearchVectorStore(gmf_forge_ai_data.BaseVectorStore):
 43class AzureAISearchVectorStore(BaseVectorStore):
 44    """
 45    Azure AI Search vector store for production RAG pipelines.
 46    
 47    Features:
 48    - Scalable vector search with HNSW (Hierarchical Navigable Small World)
 49    - Hybrid search (vector + keyword BM25)
 50    - Semantic ranking powered by Microsoft Bing
 51    - Automatic indexing of all document fields (universal + custom)
 52    - Efficient filtering on all indexed fields
 53    - High availability and disaster recovery
 54    - Managed service (no infrastructure management)
 55    
 56    Prerequisites:
 57    - Azure AI Search service (Basic tier or higher for vector search)
 58    - Service endpoint and API key **or** a managed-identity token provider
 59    - Index created with vector field configuration
 60    
 61    Usage:
 62    
 63    **Simple Mode** (base Document):
 64        ```python
 65        # API key auth
 66        store = AzureAISearchVectorStore(
 67            endpoint="https://my-search.search.windows.net",
 68            index_name="documents",
 69            api_key="...",
 70            embedding_dimension=1536
 71        )
 72
 73        # Managed identity / token provider auth
 74        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 75        token_provider = get_bearer_token_provider(
 76            DefaultAzureCredential(),
 77            "https://search.azure.com/.default"  # Azure AI Search scope
 78        )
 79        store = AzureAISearchVectorStore(
 80            endpoint="https://my-search.search.windows.net",
 81            index_name="documents",
 82            token_provider=token_provider,
 83            embedding_dimension=1536
 84        )
 85        # Uses base Document with universal fields only
 86        ```
 87    
 88    **Custom Schema Mode** (domain-specific Document):
 89        ```python
 90        from dataclasses import dataclass
 91        from gmf_forge_ai_data.vector_stores import Document
 92        
 93        @dataclass
 94        class LegalDocument(Document):
 95            case_number: str = ""
 96            court: str = ""
 97            jurisdiction: str = ""
 98            decision_date: Optional[datetime] = None
 99        
100        # Pass document_type - all fields automatically indexed.
101        # If the index was provisioned externally (portal, Bicep, etc.) the
102        # vector and content field names will likely differ from the defaults
103        # ("embedding" / "content"). Always set vector_field_name and
104        # content_field_name to match your actual index schema.
105        store = AzureAISearchVectorStore(
106            endpoint="https://my-search.search.windows.net",
107            index_name="legal_documents",
108            api_key="...",
109            embedding_dimension=1536,
110            document_type=LegalDocument,
111            vector_field_name="contentVector",   # match your index schema
112            content_field_name="chunkContent",   # match your index schema
113        )
114        
115        # All fields indexed - filter on any field
116        results = store.search(
117            query_embedding=vector,
118            filters={
119                "jurisdiction": "Federal",
120                "court": "Supreme Court",
121                "decision_date": (">=", datetime(2024, 1, 1))
122            }
123        )
124        ```
125    """
126    
127    def __init__(
128        self,
129        endpoint: str,
130        index_name: str,
131        api_key: Optional[str] = None,
132        token_provider: Optional[Callable[[], str]] = None,
133        embedding_dimension: int = 1536,
134        document_type: type = Document,
135        vector_field_name: str = "embedding",
136        content_field_name: str = "content",
137    ):
138        """
139        Initialize Azure AI Search vector store.
140
141        The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
142        before first use.  The constructor only establishes client connections
143        — it never creates or modifies indexes.
144
145        Exactly one of ``api_key`` or ``token_provider`` must be supplied.
146
147        Args:
148            endpoint: Azure AI Search service endpoint
149            index_name: Name of the search index
150            api_key: Azure AI Search API key. Use for local development or
151                when managed identity is not available.
152            token_provider: Zero-argument callable that returns a bearer token
153                string. Use for managed identity / workload identity scenarios.
154                The callable must request the **Azure AI Search** scope::
155
156                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
157                    token_provider = get_bearer_token_provider(
158                        DefaultAzureCredential(),
159                        "https://search.azure.com/.default"
160                    )
161
162                Note: this scope is different from Azure OpenAI / Cognitive Services
163                (``https://cognitiveservices.azure.com/.default``) — each service
164                requires its own token_provider.
165            embedding_dimension: Dimension of vector embeddings (default 1536)
166            document_type: Document class (base or custom subclass). All fields will be indexed.
167            vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
168                Override this when the index was built externally with a different field name
169                (e.g. "contentVector", "chunkVector").
170            content_field_name: Name of the text content field in the Azure Search index (default "content").
171                Override when the index uses a different name (e.g. "chunkContent", "text").
172        """
173        if not api_key and not token_provider:
174            raise ValueError(
175                "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
176            )
177        self.endpoint = endpoint
178        self.index_name = index_name
179        self.embedding_dimension = embedding_dimension
180        self.document_type = document_type
181        self.vector_field_name = vector_field_name
182        self.content_field_name = content_field_name
183
184        if token_provider:
185            credential = _TokenProviderCredential(token_provider)
186        else:
187            credential = AzureKeyCredential(api_key)
188
189        # Initialize clients
190        self.search_client = SearchClient(
191            endpoint=endpoint,
192            index_name=index_name,
193            credential=credential
194        )
195    
196    @staticmethod
197    def _format_datetime_for_azure(dt: datetime) -> str:
198        """
199        Convert datetime to Azure Search DateTimeOffset format (ISO 8601 with timezone).
200        
201        Args:
202            dt: datetime object (timezone-aware or naive)
203            
204        Returns:
205            ISO 8601 string with timezone (e.g., '2024-01-15T10:30:00Z')
206        """
207        if dt.tzinfo is None:
208            # Timezone-naive datetime: assume UTC and append 'Z'
209            return dt.isoformat() + 'Z'
210        else:
211            # Timezone-aware datetime: use standard ISO format
212            return dt.isoformat()
213    
214    def add_documents(
215        self,
216        documents: List[Document],
217        generate_embeddings: bool = True
218    ) -> List[str]:
219        """
220        Add documents to Azure AI Search.
221        
222        Args:
223            documents: List of Document objects to add (can be base or custom subclasses)
224            generate_embeddings: If True, validates embeddings are present
225        
226        Returns:
227            List of document IDs that were added
228            
229        Note:
230            Custom document fields are automatically serialized and stored.
231            All fields remain accessible after retrieval.
232        """
233        if not documents:
234            return []
235        
236        # Helper function to format datetime for Azure Search DateTimeOffset
237        format_datetime_for_azure = self._format_datetime_for_azure
238        
239        # Helper function to serialize datetime objects in dictionaries
240        def serialize_for_json(obj):
241            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
242            if isinstance(obj, datetime):
243                return format_datetime_for_azure(obj)
244            elif isinstance(obj, dict):
245                return {k: serialize_for_json(v) for k, v in obj.items()}
246            elif isinstance(obj, (list, tuple)):
247                return [serialize_for_json(item) for item in obj]
248            else:
249                return obj
250        
251        # Prepare documents for upload
252        search_documents = []
253        added_ids = []
254        
255        for doc in documents:
256            if not doc.id or not doc.content:
257                raise ValueError(f"Document must have id and content: {doc}")
258            
259            if generate_embeddings and doc.embedding is None:
260                raise ValueError(
261                    f"Document {doc.id} lacks embedding. "
262                    "Generate embeddings externally before adding."
263                )
264            
265            # Serialize entire document (including custom fields) to JSON
266            import json
267            import dataclasses
268            doc_dict = doc.to_dict()
269            
270            # Convert datetime objects for JSON serialization
271            doc_dict_serializable = serialize_for_json(doc_dict)
272            
273            # Convert to Azure Search document format
274            # Store base fields directly
275            search_doc = {
276                "id": doc.id,
277                "content": doc.content,
278                "embedding": doc.embedding if doc.embedding else [],
279                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
280                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
281            }
282            
283            # Add all custom fields as indexed fields
284            if dataclasses.is_dataclass(doc):
285                for field in dataclasses.fields(doc):
286                    # Skip base Document fields (already added above)
287                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
288                        continue
289                    
290                    field_value = getattr(doc, field.name, None)
291                    
292                    if field_value is not None:
293                        # Serialize datetime fields with Azure DateTimeOffset format
294                        if isinstance(field_value, datetime):
295                            search_doc[field.name] = format_datetime_for_azure(field_value)
296                        # Handle lists/dicts by converting to JSON string
297                        elif isinstance(field_value, (list, dict)):
298                            search_doc[field.name] = json.dumps(field_value)
299                        else:
300                            search_doc[field.name] = field_value
301            
302            search_documents.append(search_doc)
303            added_ids.append(doc.id)
304        
305        # Upload to Azure Search
306        result = self.search_client.upload_documents(documents=search_documents)
307        
308        # Check for failures
309        failed_count = sum(1 for r in result if not r.succeeded)
310        if failed_count > 0:
311            logger.warning(f"{failed_count} documents failed to upload")
312        
313        return added_ids
314    
315    def search(
316        self,
317        query: Optional[str] = None,
318        query_embedding: Optional[List[float]] = None,
319        top_k: int = 5,
320        filters: Optional[Dict[str, Any]] = None,
321        search_type: str = "vector"
322    ) -> List[SearchResult]:
323        """
324        Search Azure AI Search index.
325        
326        Args:
327            query: Text query (required for keyword/hybrid search)
328            query_embedding: Vector embedding (required for vector/hybrid search)
329            top_k: Number of results to return
330            filters: Metadata filters (e.g., {"source": "doc.pdf"})
331            search_type: "vector", "keyword", or "hybrid"
332        
333        Returns:
334            List of SearchResult objects, ordered by relevance
335        """
336        if search_type not in ["vector", "keyword", "hybrid"]:
337            raise ValueError(f"Invalid search_type: {search_type}")
338        
339        if search_type in ["vector", "hybrid"] and query_embedding is None:
340            raise ValueError(f"{search_type} search requires query_embedding")
341        
342        if search_type in ["keyword", "hybrid"] and query is None:
343            raise ValueError(f"{search_type} search requires query text")
344        
345        # Build filter expression
346        filter_expr = self._build_filter_expression(filters)
347        
348        # Execute search based on type
349        if search_type == "vector":
350            # Pure vector search
351            vector_query = VectorizedQuery(
352                vector=query_embedding,
353                k_nearest_neighbors=top_k,
354                fields=self.vector_field_name
355            )
356            
357            results = self.search_client.search(
358                search_text=None,
359                vector_queries=[vector_query],
360                filter=filter_expr,
361                top=top_k
362            )
363        
364        elif search_type == "keyword":
365            # Pure keyword search (BM25)
366            results = self.search_client.search(
367                search_text=query,
368                filter=filter_expr,
369                top=top_k,
370                include_total_count=False
371            )
372        
373        else:  # hybrid
374            # Hybrid search (vector + keyword)
375            vector_query = VectorizedQuery(
376                vector=query_embedding,
377                k_nearest_neighbors=top_k,
378                fields=self.vector_field_name
379            )
380            
381            results = self.search_client.search(
382                search_text=query,
383                vector_queries=[vector_query],
384                filter=filter_expr,
385                top=top_k
386            )
387        
388        # Convert to SearchResult objects
389        search_results = []
390        for rank, result in enumerate(results):
391            import json
392            
393            # Deserialize full document data
394            doc_data = json.loads(result.get("document_data", "{}") or "{}")
395            
396            if not doc_data:
397                # Index was built externally — fields are stored directly on the result.
398                # Collect all non-Azure-internal fields, excluding the vector blob.
399                doc_data = {
400                    k: v for k, v in result.items()
401                    if not k.startswith("@") and k != self.vector_field_name
402                }
403                # Remap content field when the index uses a non-standard name.
404                if self.content_field_name != "content" and self.content_field_name in doc_data:
405                    doc_data["content"] = doc_data.pop(self.content_field_name)
406            
407            # Reconstruct document using the configured document_type
408            doc = self.document_type.from_dict(doc_data)
409            
410            search_results.append(SearchResult(
411                document=doc,
412                score=result.get("@search.score", 0.0),
413                rank=rank
414            ))
415        
416        return search_results
417    
418    def delete_documents(self, document_ids: List[str]) -> int:
419        """
420        Delete documents from Azure AI Search.
421        
422        Args:
423            document_ids: List of document IDs to delete
424        
425        Returns:
426            Number of documents successfully deleted
427        """
428        if not document_ids:
429            return 0
430        
431        # Prepare delete operations
432        delete_docs = [{"id": doc_id} for doc_id in document_ids]
433        
434        # Execute delete
435        result = self.search_client.delete_documents(documents=delete_docs)
436        
437        # Count successful deletions
438        deleted_count = sum(1 for r in result if r.succeeded)
439        return deleted_count
440    
441    def update_document(
442        self,
443        document_id: str,
444        content: Optional[str] = None,
445        embedding: Optional[List[float]] = None,
446        metadata: Optional[Dict[str, Any]] = None,
447        timestamp: Optional[datetime] = None,
448        **custom_fields
449    ) -> bool:
450        """
451        Update an existing document in Azure AI Search.
452        
453        Args:
454            document_id: ID of the document to update
455            content: New content (if provided)
456            embedding: New embedding (if provided)
457            metadata: New metadata (merged with existing)
458            timestamp: New timestamp (if provided)
459            **custom_fields: Any custom fields to update
460        
461        Returns:
462            True if update was successful, False if document not found
463        """
464        try:
465            import json
466            
467            # Retrieve existing document
468            existing_doc = self.search_client.get_document(key=document_id)
469            doc_data = json.loads(existing_doc.get("document_data", "{}"))
470            
471            # Update fields in document data
472            if content is not None:
473                doc_data["content"] = content
474                existing_doc["content"] = content
475            
476            if embedding is not None:
477                doc_data["embedding"] = embedding
478                existing_doc["embedding"] = embedding
479            
480            if timestamp is not None:
481                doc_data["timestamp"] = self._format_datetime_for_azure(timestamp)
482                existing_doc["timestamp"] = self._format_datetime_for_azure(timestamp)
483            
484            if metadata is not None:
485                if "metadata" not in doc_data:
486                    doc_data["metadata"] = {}
487                doc_data["metadata"].update(metadata)
488            
489            # Update custom fields
490            for key, value in custom_fields.items():
491                doc_data[key] = value
492            
493            # Update document_data
494            existing_doc["document_data"] = json.dumps(doc_data)
495            
496            # Merge/upload updated document
497            result = self.search_client.merge_or_upload_documents(documents=[existing_doc])
498            
499            return result[0].succeeded
500        
501        except Exception as e:
502            logger.error(f"Failed to update document {document_id}: {e}")
503            return False
504    
505    def get_document(self, document_id: str) -> Optional[Document]:
506        """
507        Retrieve a document by ID from Azure AI Search.
508        
509        Args:
510            document_id: ID of the document to retrieve
511        
512        Returns:
513            Document if found, None otherwise (returns configured document_type)
514        """
515        try:
516            import json
517            result = self.search_client.get_document(key=document_id)
518            
519            # Deserialize full document data using configured document_type
520            doc_data = json.loads(result.get("document_data", "{}"))
521            return self.document_type.from_dict(doc_data)
522        
523        except Exception as e:
524            logger.debug(f"Document {document_id} not found: {e}")
525            return None
526
527    def get_document_by_parent(self, document_id: str) -> List[Document]:
528        """
529        Retrieve all chunks belonging to a parent document.
530
531        Args:
532            document_id: The value of the ``documentId`` field shared across
533                        all chunks of the same parent document.
534
535        Returns:
536            List of Document objects sorted by ``pageNumber``, or an empty
537            list if no matching chunks are found.
538        """
539        results = self.search(
540            query="*",
541            search_type="keyword",
542            filters={"documentId": document_id},
543            top_k=1000,
544        )
545        chunks = [r.document for r in results]
546        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
547        return chunks
548
549    def count(self) -> int:
550        """
551        Get the total number of documents in the index.
552        
553        Returns:
554            Total document count
555        """
556        # Execute a search to get total count
557        results = self.search_client.search(
558            search_text="*",
559            include_total_count=True,
560            top=0
561        )
562        
563        # Get count from results
564        count = getattr(results, 'get_count', lambda: 0)()
565        return count if count is not None else 0
566    
567    def clear(self) -> None:
568        """
569        Remove all documents from the index.
570        
571        Warning: This operation is irreversible.
572        """
573        # Search for all document IDs
574        results = self.search_client.search(
575            search_text="*",
576            select=["id"],
577            top=1000  # Adjust batch size as needed
578        )
579        
580        # Collect all IDs
581        doc_ids = [r["id"] for r in results]
582        
583        # Delete in batches
584        if doc_ids:
585            self.delete_documents(doc_ids)
586            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")
587    
588    def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> Optional[str]:
589        """
590        Build OData filter expression from filter dictionary.
591        
592        Supports filtering on all indexed fields (universal + custom fields from document_type).
593        
594        Args:
595            filters: Dictionary of field: value pairs
596                    Special operators: use tuples for range queries
597                    - (">=", value) for greater than or equal
598                    - ("<=", value) for less than or equal
599                    - ("range", min_val, max_val) for range queries
600        
601        Returns:
602            OData filter string or None
603            
604        Example:
605            ```python
606            # Filter on universal field
607            filters = {"timestamp": (">=", datetime(2024, 1, 1))}
608            
609            # Filter on custom fields (if document_type has them)
610            filters = {
611                "court": "Supreme Court",
612                "case_number": "2024-001",
613                "decision_date": ("range", start_date, end_date)
614            }
615            ```
616        """
617        if not filters:
618            return None
619        
620        # Get indexed field names from document_type
621        import dataclasses
622        indexed_fields = {'id', 'content', 'timestamp'}  # Universal fields
623        
624        if dataclasses.is_dataclass(self.document_type):
625            for field in dataclasses.fields(self.document_type):
626                if field.name not in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
627                    indexed_fields.add(field.name)
628        
629        filter_parts = []
630        
631        for key, value in filters.items():
632            # Check if field is indexed
633            if key not in indexed_fields:
634                logger.warning(
635                    f"Filtering on '{key}' - field not found in document schema. "
636                    f"Indexed fields: {indexed_fields}"
637                )
638                continue
639            
640            # Handle tuple operators for range queries
641            if isinstance(value, tuple):
642                operator = value[0]
643                if operator == "range" and len(value) == 3:
644                    # Range query: field >= min and field <= max
645                    min_val = value[1]
646                    max_val = value[2]
647                    if isinstance(min_val, datetime):
648                        min_val = self._format_datetime_for_azure(min_val)
649                    if isinstance(max_val, datetime):
650                        max_val = self._format_datetime_for_azure(max_val)
651                    filter_parts.append(f"{key} ge {min_val} and {key} le {max_val}")
652                elif operator in [">=", "gt", ">="]:
653                    op_str = "ge"
654                    val = value[1]
655                    if isinstance(val, datetime):
656                        val = self._format_datetime_for_azure(val)
657                    filter_parts.append(f"{key} {op_str} {val}")
658                elif operator in ["<=", "lt", "<="]:
659                    op_str = "le"
660                    val = value[1]
661                    if isinstance(val, datetime):
662                        val = self._format_datetime_for_azure(val)
663                    filter_parts.append(f"{key} {op_str} {val}")
664            # Handle datetime equality
665            elif isinstance(value, datetime):
666                filter_parts.append(f"{key} eq {self._format_datetime_for_azure(value)}")
667            # Handle string equality (needs quotes in OData)
668            elif isinstance(value, str):
669                # Escape single quotes in the value
670                escaped_value = value.replace("'", "''")
671                filter_parts.append(f"{key} eq '{escaped_value}'")
672            # Handle boolean
673            elif isinstance(value, bool):
674                filter_parts.append(f"{key} eq {str(value).lower()}")
675            # Handle numeric types
676            elif isinstance(value, (int, float)):
677                filter_parts.append(f"{key} eq {value}")
678        
679        return " and ".join(filter_parts) if filter_parts else None

Azure AI Search vector store for production RAG pipelines.

Features:

  • Scalable vector search with HNSW (Hierarchical Navigable Small World)
  • Hybrid search (vector + keyword BM25)
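  • Semantic ranking (a service capability based on models adapted from Microsoft Bing; this class's search() only issues vector, keyword, or hybrid queries and does not request it)
  • All document fields (universal + custom) are uploaded as top-level index fields; the index schema must already define them, since this class never creates or modifies indexes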
  • Efficient filtering on all indexed fields
  • High availability and disaster recovery
  • Managed service (no infrastructure management)

Prerequisites:

  • Azure AI Search service (Basic tier or higher for vector search)
  • Service endpoint and API key or a managed-identity token provider
  • Index created with vector field configuration

Usage:

Simple Mode (base Document):

# API key auth
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    api_key="...",
    embedding_dimension=1536
)

# Managed identity / token provider auth
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://search.azure.com/.default"  # Azure AI Search scope
)
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="documents",
    token_provider=token_provider,
    embedding_dimension=1536
)
# Uses base Document with universal fields only

Custom Schema Mode (domain-specific Document):

from dataclasses import dataclass
from gmf_forge_ai_data.vector_stores import Document

@dataclass
class LegalDocument(Document):
    case_number: str = ""
    court: str = ""
    jurisdiction: str = ""
    decision_date: Optional[datetime] = None

# Pass document_type - all fields automatically indexed.
# If the index was provisioned externally (portal, Bicep, etc.) the
# vector and content field names will likely differ from the defaults
# ("embedding" / "content"). Always set vector_field_name and
# content_field_name to match your actual index schema.
store = AzureAISearchVectorStore(
    endpoint="https://my-search.search.windows.net",
    index_name="legal_documents",
    api_key="...",
    embedding_dimension=1536,
    document_type=LegalDocument,
    vector_field_name="contentVector",   # match your index schema
    content_field_name="chunkContent",   # match your index schema
)

# All fields indexed - filter on any field
results = store.search(
    query_embedding=vector,
    filters={
        "jurisdiction": "Federal",
        "court": "Supreme Court",
        "decision_date": (">=", datetime(2024, 1, 1))
    }
)
AzureAISearchVectorStore( endpoint: str, index_name: str, api_key: Optional[str] = None, token_provider: Optional[Callable[[], str]] = None, embedding_dimension: int = 1536, document_type: type = <class 'Document'>, vector_field_name: str = 'embedding', content_field_name: str = 'content')
def __init__(
    self,
    endpoint: str,
    index_name: str,
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
    embedding_dimension: int = 1536,
    document_type: type = Document,
    vector_field_name: str = "embedding",
    content_field_name: str = "content",
):
    """
    Initialize Azure AI Search vector store.

    The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
    before first use.  The constructor only establishes client connections
    — it never creates or modifies indexes.

    At least one of ``api_key`` or ``token_provider`` must be supplied.
    If both are given, ``token_provider`` is used and ``api_key`` is ignored.

    Args:
        endpoint: Azure AI Search service endpoint
        index_name: Name of the search index
        api_key: Azure AI Search API key. Use for local development or
            when managed identity is not available.
        token_provider: Zero-argument callable that returns a bearer token
            string. Use for managed identity / workload identity scenarios.
            The callable must request the **Azure AI Search** scope::

                from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                token_provider = get_bearer_token_provider(
                    DefaultAzureCredential(),
                    "https://search.azure.com/.default"
                )

            Note: this scope is different from Azure OpenAI / Cognitive Services
            (``https://cognitiveservices.azure.com/.default``) — each service
            requires its own token_provider.
        embedding_dimension: Dimension of vector embeddings (default 1536)
        document_type: Document class (base or custom subclass). All fields will be indexed.
        vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
            Override this when the index was built externally with a different field name
            (e.g. "contentVector", "chunkVector").
        content_field_name: Name of the text content field in the Azure Search index (default "content").
            Override when the index uses a different name (e.g. "chunkContent", "text").

    Raises:
        ValueError: If neither api_key nor token_provider is supplied.
    """
    if not api_key and not token_provider:
        raise ValueError(
            "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
        )
    self.endpoint = endpoint
    self.index_name = index_name
    self.embedding_dimension = embedding_dimension
    self.document_type = document_type
    self.vector_field_name = vector_field_name
    self.content_field_name = content_field_name

    # token_provider takes precedence when both credentials are supplied.
    if token_provider:
        credential = _TokenProviderCredential(token_provider)
    else:
        credential = AzureKeyCredential(api_key)

    # Initialize clients
    self.search_client = SearchClient(
        endpoint=endpoint,
        index_name=index_name,
        credential=credential
    )

Initialize Azure AI Search vector store.

The index must be pre-provisioned using AzureAISearchIndexBuilder before first use. The constructor only establishes client connections — it never creates or modifies indexes.

At least one of api_key or token_provider must be supplied; if both are given, token_provider is used and api_key is ignored.

Args:
    endpoint: Azure AI Search service endpoint
    index_name: Name of the search index
    api_key: Azure AI Search API key. Use for local development or when managed identity is not available.
    token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Azure AI Search scope:

        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://search.azure.com/.default"
        )

        Note: this scope is different from Azure OpenAI / Cognitive Services
        (https://cognitiveservices.azure.com/.default) — each service
        requires its own token_provider.
    embedding_dimension: Dimension of vector embeddings (default 1536)
    document_type: Document class (base or custom subclass). All fields will be indexed.
    vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
        Override this when the index was built externally with a different field name
        (e.g. "contentVector", "chunkVector").
    content_field_name: Name of the text content field in the Azure Search index (default "content").
        Override when the index uses a different name (e.g. "chunkContent", "text").
endpoint
index_name
embedding_dimension
document_type
vector_field_name
content_field_name
search_client
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
214    def add_documents(
215        self,
216        documents: List[Document],
217        generate_embeddings: bool = True
218    ) -> List[str]:
219        """
220        Add documents to Azure AI Search.
221        
222        Args:
223            documents: List of Document objects to add (can be base or custom subclasses)
224            generate_embeddings: If True, validates embeddings are present
225        
226        Returns:
227            List of document IDs that were added
228            
229        Note:
230            Custom document fields are automatically serialized and stored.
231            All fields remain accessible after retrieval.
232        """
233        if not documents:
234            return []
235        
236        # Helper function to format datetime for Azure Search DateTimeOffset
237        format_datetime_for_azure = self._format_datetime_for_azure
238        
239        # Helper function to serialize datetime objects in dictionaries
240        def serialize_for_json(obj):
241            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
242            if isinstance(obj, datetime):
243                return format_datetime_for_azure(obj)
244            elif isinstance(obj, dict):
245                return {k: serialize_for_json(v) for k, v in obj.items()}
246            elif isinstance(obj, (list, tuple)):
247                return [serialize_for_json(item) for item in obj]
248            else:
249                return obj
250        
251        # Prepare documents for upload
252        search_documents = []
253        added_ids = []
254        
255        for doc in documents:
256            if not doc.id or not doc.content:
257                raise ValueError(f"Document must have id and content: {doc}")
258            
259            if generate_embeddings and doc.embedding is None:
260                raise ValueError(
261                    f"Document {doc.id} lacks embedding. "
262                    "Generate embeddings externally before adding."
263                )
264            
265            # Serialize entire document (including custom fields) to JSON
266            import json
267            import dataclasses
268            doc_dict = doc.to_dict()
269            
270            # Convert datetime objects for JSON serialization
271            doc_dict_serializable = serialize_for_json(doc_dict)
272            
273            # Convert to Azure Search document format
274            # Store base fields directly
275            search_doc = {
276                "id": doc.id,
277                "content": doc.content,
278                "embedding": doc.embedding if doc.embedding else [],
279                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
280                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
281            }
282            
283            # Add all custom fields as indexed fields
284            if dataclasses.is_dataclass(doc):
285                for field in dataclasses.fields(doc):
286                    # Skip base Document fields (already added above)
287                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
288                        continue
289                    
290                    field_value = getattr(doc, field.name, None)
291                    
292                    if field_value is not None:
293                        # Serialize datetime fields with Azure DateTimeOffset format
294                        if isinstance(field_value, datetime):
295                            search_doc[field.name] = format_datetime_for_azure(field_value)
296                        # Handle lists/dicts by converting to JSON string
297                        elif isinstance(field_value, (list, dict)):
298                            search_doc[field.name] = json.dumps(field_value)
299                        else:
300                            search_doc[field.name] = field_value
301            
302            search_documents.append(search_doc)
303            added_ids.append(doc.id)
304        
305        # Upload to Azure Search
306        result = self.search_client.upload_documents(documents=search_documents)
307        
308        # Check for failures
309        failed_count = sum(1 for r in result if not r.succeeded)
310        if failed_count > 0:
311            logger.warning(f"{failed_count} documents failed to upload")
312        
313        return added_ids

Add documents to Azure AI Search.

Args:
    documents: List of Document objects to add (base Document or custom subclasses).
    generate_embeddings: If True, validates that every document already has an embedding.

Returns: List of document IDs that were added

Note: Custom document fields are automatically serialized and stored. All fields remain accessible after retrieval.

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
315    def search(
316        self,
317        query: Optional[str] = None,
318        query_embedding: Optional[List[float]] = None,
319        top_k: int = 5,
320        filters: Optional[Dict[str, Any]] = None,
321        search_type: str = "vector"
322    ) -> List[SearchResult]:
323        """
324        Search Azure AI Search index.
325        
326        Args:
327            query: Text query (required for keyword/hybrid search)
328            query_embedding: Vector embedding (required for vector/hybrid search)
329            top_k: Number of results to return
330            filters: Metadata filters (e.g., {"source": "doc.pdf"})
331            search_type: "vector", "keyword", or "hybrid"
332        
333        Returns:
334            List of SearchResult objects, ordered by relevance
335        """
336        if search_type not in ["vector", "keyword", "hybrid"]:
337            raise ValueError(f"Invalid search_type: {search_type}")
338        
339        if search_type in ["vector", "hybrid"] and query_embedding is None:
340            raise ValueError(f"{search_type} search requires query_embedding")
341        
342        if search_type in ["keyword", "hybrid"] and query is None:
343            raise ValueError(f"{search_type} search requires query text")
344        
345        # Build filter expression
346        filter_expr = self._build_filter_expression(filters)
347        
348        # Execute search based on type
349        if search_type == "vector":
350            # Pure vector search
351            vector_query = VectorizedQuery(
352                vector=query_embedding,
353                k_nearest_neighbors=top_k,
354                fields=self.vector_field_name
355            )
356            
357            results = self.search_client.search(
358                search_text=None,
359                vector_queries=[vector_query],
360                filter=filter_expr,
361                top=top_k
362            )
363        
364        elif search_type == "keyword":
365            # Pure keyword search (BM25)
366            results = self.search_client.search(
367                search_text=query,
368                filter=filter_expr,
369                top=top_k,
370                include_total_count=False
371            )
372        
373        else:  # hybrid
374            # Hybrid search (vector + keyword)
375            vector_query = VectorizedQuery(
376                vector=query_embedding,
377                k_nearest_neighbors=top_k,
378                fields=self.vector_field_name
379            )
380            
381            results = self.search_client.search(
382                search_text=query,
383                vector_queries=[vector_query],
384                filter=filter_expr,
385                top=top_k
386            )
387        
388        # Convert to SearchResult objects
389        search_results = []
390        for rank, result in enumerate(results):
391            import json
392            
393            # Deserialize full document data
394            doc_data = json.loads(result.get("document_data", "{}") or "{}")
395            
396            if not doc_data:
397                # Index was built externally — fields are stored directly on the result.
398                # Collect all non-Azure-internal fields, excluding the vector blob.
399                doc_data = {
400                    k: v for k, v in result.items()
401                    if not k.startswith("@") and k != self.vector_field_name
402                }
403                # Remap content field when the index uses a non-standard name.
404                if self.content_field_name != "content" and self.content_field_name in doc_data:
405                    doc_data["content"] = doc_data.pop(self.content_field_name)
406            
407            # Reconstruct document using the configured document_type
408            doc = self.document_type.from_dict(doc_data)
409            
410            search_results.append(SearchResult(
411                document=doc,
412                score=result.get("@search.score", 0.0),
413                rank=rank
414            ))
415        
416        return search_results

Search Azure AI Search index.

Args:
    query: Text query (required for keyword/hybrid search).
    query_embedding: Vector embedding (required for vector/hybrid search).
    top_k: Number of results to return.
    filters: Metadata filters (e.g., {"source": "doc.pdf"}).
    search_type: "vector", "keyword", or "hybrid".

Returns: List of SearchResult objects, ordered by relevance

def delete_documents(self, document_ids: List[str]) -> int:
418    def delete_documents(self, document_ids: List[str]) -> int:
419        """
420        Delete documents from Azure AI Search.
421        
422        Args:
423            document_ids: List of document IDs to delete
424        
425        Returns:
426            Number of documents successfully deleted
427        """
428        if not document_ids:
429            return 0
430        
431        # Prepare delete operations
432        delete_docs = [{"id": doc_id} for doc_id in document_ids]
433        
434        # Execute delete
435        result = self.search_client.delete_documents(documents=delete_docs)
436        
437        # Count successful deletions
438        deleted_count = sum(1 for r in result if r.succeeded)
439        return deleted_count

Delete documents from Azure AI Search.

Args: document_ids: List of document IDs to delete

Returns: Number of documents successfully deleted

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None, timestamp: Optional[datetime.datetime] = None, **custom_fields) -> bool:
441    def update_document(
442        self,
443        document_id: str,
444        content: Optional[str] = None,
445        embedding: Optional[List[float]] = None,
446        metadata: Optional[Dict[str, Any]] = None,
447        timestamp: Optional[datetime] = None,
448        **custom_fields
449    ) -> bool:
450        """
451        Update an existing document in Azure AI Search.
452        
453        Args:
454            document_id: ID of the document to update
455            content: New content (if provided)
456            embedding: New embedding (if provided)
457            metadata: New metadata (merged with existing)
458            timestamp: New timestamp (if provided)
459            **custom_fields: Any custom fields to update
460        
461        Returns:
462            True if update was successful, False if document not found
463        """
464        try:
465            import json
466            
467            # Retrieve existing document
468            existing_doc = self.search_client.get_document(key=document_id)
469            doc_data = json.loads(existing_doc.get("document_data", "{}"))
470            
471            # Update fields in document data
472            if content is not None:
473                doc_data["content"] = content
474                existing_doc["content"] = content
475            
476            if embedding is not None:
477                doc_data["embedding"] = embedding
478                existing_doc["embedding"] = embedding
479            
480            if timestamp is not None:
481                doc_data["timestamp"] = self._format_datetime_for_azure(timestamp)
482                existing_doc["timestamp"] = self._format_datetime_for_azure(timestamp)
483            
484            if metadata is not None:
485                if "metadata" not in doc_data:
486                    doc_data["metadata"] = {}
487                doc_data["metadata"].update(metadata)
488            
489            # Update custom fields
490            for key, value in custom_fields.items():
491                doc_data[key] = value
492            
493            # Update document_data
494            existing_doc["document_data"] = json.dumps(doc_data)
495            
496            # Merge/upload updated document
497            result = self.search_client.merge_or_upload_documents(documents=[existing_doc])
498            
499            return result[0].succeeded
500        
501        except Exception as e:
502            logger.error(f"Failed to update document {document_id}: {e}")
503            return False

Update an existing document in Azure AI Search.

Args:
    document_id: ID of the document to update.
    content: New content (if provided).
    embedding: New embedding (if provided).
    metadata: New metadata (merged with the existing metadata).
    timestamp: New timestamp (if provided).
    **custom_fields: Any custom fields to update.

Returns: True if the update was successful; False if the document was not found or the update failed for any other reason (the error is logged).

def get_document( self, document_id: str) -> Optional[Document]:
505    def get_document(self, document_id: str) -> Optional[Document]:
506        """
507        Retrieve a document by ID from Azure AI Search.
508        
509        Args:
510            document_id: ID of the document to retrieve
511        
512        Returns:
513            Document if found, None otherwise (returns configured document_type)
514        """
515        try:
516            import json
517            result = self.search_client.get_document(key=document_id)
518            
519            # Deserialize full document data using configured document_type
520            doc_data = json.loads(result.get("document_data", "{}"))
521            return self.document_type.from_dict(doc_data)
522        
523        except Exception as e:
524            logger.debug(f"Document {document_id} not found: {e}")
525            return None

Retrieve a document by ID from Azure AI Search.

Args: document_id: ID of the document to retrieve

Returns: Document if found, None otherwise (returns configured document_type)

def get_document_by_parent( self, document_id: str) -> List[Document]:
527    def get_document_by_parent(self, document_id: str) -> List[Document]:
528        """
529        Retrieve all chunks belonging to a parent document.
530
531        Args:
532            document_id: The value of the ``documentId`` field shared across
533                        all chunks of the same parent document.
534
535        Returns:
536            List of Document objects sorted by ``pageNumber``, or an empty
537            list if no matching chunks are found.
538        """
539        results = self.search(
540            query="*",
541            search_type="keyword",
542            filters={"documentId": document_id},
543            top_k=1000,
544        )
545        chunks = [r.document for r in results]
546        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
547        return chunks

Retrieve all chunks belonging to a parent document.

Args: document_id: The value of the documentId field shared across all chunks of the same parent document.

Returns: List of Document objects sorted by pageNumber, or an empty list if no matching chunks are found.

def count(self) -> int:
549    def count(self) -> int:
550        """
551        Get the total number of documents in the index.
552        
553        Returns:
554            Total document count
555        """
556        # Execute a search to get total count
557        results = self.search_client.search(
558            search_text="*",
559            include_total_count=True,
560            top=0
561        )
562        
563        # Get count from results
564        count = getattr(results, 'get_count', lambda: 0)()
565        return count if count is not None else 0

Get the total number of documents in the index.

Returns: Total document count

def clear(self) -> None:
567    def clear(self) -> None:
568        """
569        Remove all documents from the index.
570        
571        Warning: This operation is irreversible.
572        """
573        # Search for all document IDs
574        results = self.search_client.search(
575            search_text="*",
576            select=["id"],
577            top=1000  # Adjust batch size as needed
578        )
579        
580        # Collect all IDs
581        doc_ids = [r["id"] for r in results]
582        
583        # Delete in batches
584        if doc_ids:
585            self.delete_documents(doc_ids)
586            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")

Remove all documents from the index.

Warning: This operation is irreversible.

class AzureCosmosDBVectorStore(gmf_forge_ai_data.BaseVectorStore):
 32class AzureCosmosDBVectorStore(BaseVectorStore):
 33    """
 34    Azure Cosmos DB NoSQL API vector store.
 35
 36    Stores documents in a Cosmos DB NoSQL container and queries them using
 37    Cosmos DB's vector distance functions and SQL-like query language.
 38
 39    Features:
 40    - Vector search using Cosmos DB vector indexing (cosine distance)
 41    - Full-text keyword search via SQL CONTAINS / LIKE queries
 42    - Hybrid search (vector + keyword, equal weighting)
 43    - Metadata filtering via SQL WHERE clauses
 44    - Custom ``document_type`` support (any Document subclass)
 45    - Automatic container creation with vector embedding policy
 46
 47    Usage::
 48
 49        from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore
 50
 51        store = AzureCosmosDBVectorStore(
 52            endpoint="https://your-account.documents.azure.com:443/",
 53            key="your-key",
 54            database_name="rag_db",
 55            container_name="documents",
 56            embedding_dimension=1536,
 57        )
 58        store.add_documents(docs_with_embeddings)
 59        results = store.search(query_embedding=embedding, top_k=5)
 60    """
 61
 62    def __init__(
 63        self,
 64        endpoint: str,
 65        key: str,
 66        database_name: str,
 67        container_name: str,
 68        embedding_dimension: int = 1536,
 69        document_type: type = Document,
 70        ssl_cert_path: Optional[str] = None,
 71    ) -> None:
 72        """
 73        Initialise the Cosmos DB NoSQL vector store.
 74
 75        The database and container must be pre-provisioned using
 76        ``CosmosDBIndexBuilder`` before first use.  The constructor only
 77        establishes client connections — it never creates or modifies
 78        databases or containers.
 79
 80        Args:
 81            endpoint: Cosmos DB account endpoint
 82                (e.g. ``https://your-account.documents.azure.com:443/``).
 83            key: Cosmos DB account key or resource token.
 84            database_name: Name of the Cosmos DB database.
 85            container_name: Name of the container.
 86            embedding_dimension: Dimension of vector embeddings (default 1536).
 87            document_type: Document class to use when deserialising results.
 88                Pass a custom subclass to support additional typed fields.
 89            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 90                verification.  Useful in corporate environments with custom
 91                certificate authorities.
 92        """
 93        try:
 94            from azure.cosmos import CosmosClient  # noqa: F401
 95        except ImportError as exc:
 96            raise ImportError(
 97                "azure-cosmos is required for AzureCosmosDBVectorStore. "
 98                "Install it with:  pip install azure-cosmos"
 99            ) from exc
100
101        self.embedding_dimension = embedding_dimension
102        self.document_type = document_type
103        self._container_name = container_name
104
105        # Build client with optional custom SSL cert
106        import ssl as _ssl
107        connection_kwargs: Dict[str, Any] = {}
108        if ssl_cert_path:
109            ssl_context = _ssl.create_default_context(cafile=ssl_cert_path)
110            connection_kwargs["connection_verify"] = ssl_cert_path
111
112        from azure.cosmos import CosmosClient
113        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)
114
115        # Database and container — pre-provisioned by developer
116        self._database = self._client.get_database_client(database_name)
117        self._container = self._database.get_container_client(container_name)
118
119    # ------------------------------------------------------------------
120    # Serialisation helpers
121    # ------------------------------------------------------------------
122
123    @staticmethod
124    def _serialize_value(obj: Any) -> Any:
125        """Recursively convert datetime objects to ISO strings."""
126        if isinstance(obj, datetime):
127            return obj.isoformat()
128        if isinstance(obj, dict):
129            return {k: AzureCosmosDBVectorStore._serialize_value(v) for k, v in obj.items()}
130        if isinstance(obj, (list, tuple)):
131            return [AzureCosmosDBVectorStore._serialize_value(i) for i in obj]
132        return obj
133
134    def _to_cosmos(self, doc: Document) -> Dict[str, Any]:
135        """Convert a Document to a Cosmos DB item dict."""
136        doc_dict = doc.to_dict()
137        doc_dict_serialisable = self._serialize_value(doc_dict)
138        item = {
139            "id": doc.id,
140            "content": doc.content,
141            "embedding": doc.embedding or [],
142            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
143            "metadata": doc.metadata,
144            "document_data": json.dumps(doc_dict_serialisable),
145        }
146        # Promote custom fields (e.g. category, author, year) to top-level
147        # so they are filterable via SQL WHERE clauses
148        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
149        for key, value in doc_dict_serialisable.items():
150            if key not in base_keys:
151                item[key] = value
152        return item
153
154    def _from_cosmos(self, cosmos_item: Dict[str, Any]) -> Document:
155        """Reconstruct a Document (or subclass) from a Cosmos DB item."""
156        doc_data = json.loads(cosmos_item.get("document_data", "{}"))
157        return self.document_type.from_dict(doc_data)
158
159    # ------------------------------------------------------------------
160    # Filter helpers
161    # ------------------------------------------------------------------
162
163    @staticmethod
164    def _build_where_clause(filters: Optional[Dict[str, Any]]) -> str:
165        """Translate a user-facing filter dict to a SQL WHERE clause."""
166        if not filters:
167            return ""
168        op_map = {">=": ">=", "<=": "<=", ">": ">", "<": "<", "!=": "!="}
169        conditions: List[str] = []
170        for key, value in filters.items():
171            if isinstance(value, tuple) and len(value) == 2:
172                op, val = value
173                sql_op = op_map.get(op, "=")
174                if isinstance(val, str):
175                    conditions.append(f"c.{key} {sql_op} '{val}'")
176                elif isinstance(val, datetime):
177                    conditions.append(f"c.{key} {sql_op} '{val.isoformat()}'")
178                else:
179                    conditions.append(f"c.{key} {sql_op} {val}")
180            else:
181                if isinstance(value, str):
182                    conditions.append(f"c.{key} = '{value}'")
183                elif isinstance(value, bool):
184                    conditions.append(f"c.{key} = {'true' if value else 'false'}")
185                elif isinstance(value, datetime):
186                    conditions.append(f"c.{key} = '{value.isoformat()}'")
187                else:
188                    conditions.append(f"c.{key} = {value}")
189        return " AND ".join(conditions)
190
191    # ------------------------------------------------------------------
192    # BaseVectorStore interface
193    # ------------------------------------------------------------------
194
195    def add_documents(
196        self,
197        documents: List[Document],
198        generate_embeddings: bool = True,
199    ) -> List[str]:
200        """
201        Upsert documents into the Cosmos DB container.
202
203        Args:
204            documents: Documents to add.  Each must have ``id`` and ``content``
205                set.  If *generate_embeddings* is ``True`` (default) every
206                document must also carry a pre-computed ``embedding``.
207            generate_embeddings: When ``True`` the method validates that all
208                documents already have embeddings rather than generating them
209                itself (generation is the caller's responsibility).
210
211        Returns:
212            List of document IDs that were successfully upserted.
213        """
214        if not documents:
215            return []
216
217        added_ids: List[str] = []
218
219        for doc in documents:
220            if not doc.id or not doc.content:
221                raise ValueError(f"Document must have id and content: {doc}")
222
223            if generate_embeddings and doc.embedding is None:
224                raise ValueError(
225                    f"Document '{doc.id}' has no embedding.  "
226                    "Generate embeddings before calling add_documents()."
227                )
228
229            cosmos_item = self._to_cosmos(doc)
230            self._container.upsert_item(cosmos_item)
231            added_ids.append(doc.id)
232
233        logger.info(
234            "Upserted %d document(s) into Cosmos DB container '%s'",
235            len(added_ids),
236            self._container_name,
237        )
238        return added_ids
239
240    def search(
241        self,
242        query: Optional[str] = None,
243        query_embedding: Optional[List[float]] = None,
244        top_k: int = 5,
245        filters: Optional[Dict[str, Any]] = None,
246        search_type: str = "vector",
247    ) -> List[SearchResult]:
248        """
249        Search the Cosmos DB container.
250
251        Args:
252            query: Plain-text query string.  Required for ``keyword`` and
253                ``hybrid`` search types.
254            query_embedding: Pre-computed query vector.  Required for
255                ``vector`` and ``hybrid`` search types.
256            top_k: Maximum number of results to return.
257            filters: Optional filter dict.  Supports equality
258                (``{"key": value}``) and range (``{"key": (">=", value)}``)
259                operators.
260            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.
261
262        Returns:
263            Ranked list of :class:`SearchResult` objects.
264        """
265        if search_type not in ("vector", "keyword", "hybrid"):
266            raise ValueError(
267                f"Invalid search_type '{search_type}'. "
268                "Must be 'vector', 'keyword', or 'hybrid'."
269            )
270        if search_type in ("vector", "hybrid") and query_embedding is None:
271            raise ValueError(f"search_type='{search_type}' requires query_embedding")
272        if search_type in ("keyword", "hybrid") and query is None:
273            raise ValueError(f"search_type='{search_type}' requires query text")
274
275        if search_type == "vector":
276            return self._vector_search(query_embedding, top_k, filters)
277        if search_type == "keyword":
278            return self._keyword_search(query, top_k, filters)
279        return self._hybrid_search(query, query_embedding, top_k, filters)
280
281    # --- search helpers -----------------------------------------------
282
283    def _vector_search(
284        self,
285        query_embedding: List[float],
286        top_k: int,
287        filters: Optional[Dict[str, Any]],
288    ) -> List[SearchResult]:
289        where_clause = self._build_where_clause(filters)
290        where_sql = f"WHERE {where_clause}" if where_clause else ""
291
292        query_text = (
293            f"SELECT TOP @top_k c.id, c.document_data, "
294            f"VectorDistance(c.embedding, @embedding) AS score "
295            f"FROM c {where_sql} "
296            f"ORDER BY VectorDistance(c.embedding, @embedding)"
297        )
298        parameters = [
299            {"name": "@top_k", "value": top_k},
300            {"name": "@embedding", "value": query_embedding},
301        ]
302
303        items = list(self._container.query_items(
304            query=query_text,
305            parameters=parameters,
306            enable_cross_partition_query=True,
307        ))
308
309        results = []
310        for rank, item in enumerate(items):
311            doc = self._from_cosmos(item)
312            # VectorDistance with cosine returns distance (0 = identical);
313            # convert to similarity score (1 = identical)
314            distance = float(item.get("score", 1.0))
315            similarity = 1.0 - distance
316            results.append(SearchResult(document=doc, score=similarity, rank=rank))
317        return results
318
def _keyword_search(
    self,
    query: str,
    top_k: int,
    filters: Optional[Dict[str, Any]],
) -> List[SearchResult]:
    """
    Rank documents by the fraction of query terms found in their content.

    The query is split on whitespace.  Every term is matched
    case-insensitively with ``CONTAINS(LOWER(c.content), @termN)``, and any
    document containing at least one term is fetched.  Terms are passed as
    query *parameters*, so quotes or other SQL-significant characters in the
    user's query cannot change the structure of the SQL statement.

    Args:
        query: Plain-text query string.
        top_k: Maximum number of results to return.
        filters: Optional filter dict, translated by ``_build_where_clause``.

    Returns:
        At most ``top_k`` results, ordered by score (matched terms / total
        terms), with ``rank`` starting at 0.  An empty or whitespace-only
        query returns an empty list.
    """
    terms = [term.lower() for term in query.split()]
    if not terms:
        # Without terms the condition list would be empty and the statement
        # would contain ``WHERE ()``, which is not valid SQL.
        return []

    parameters = [
        {"name": f"@term{i}", "value": term} for i, term in enumerate(terms)
    ]
    keyword_conditions = " OR ".join(
        f"CONTAINS(LOWER(c.content), {param['name']})" for param in parameters
    )
    where_clause = self._build_where_clause(filters)
    filter_sql = f" AND {where_clause}" if where_clause else ""

    # Fetch ALL matching documents — scoring and top_k slicing happen in Python.
    # Applying SELECT TOP before Python scoring would miss high-scoring documents
    # if Cosmos DB returns lower-scoring ones first (order without ORDER BY is undefined).
    query_text = (
        f"SELECT c.id, c.document_data, c.content "
        f"FROM c WHERE ({keyword_conditions}){filter_sql}"
    )

    items = self._container.query_items(
        query=query_text,
        parameters=parameters,
        enable_cross_partition_query=True,
    )

    # Score by the share of query terms present in the content.
    results = []
    for item in items:
        # ``content`` may be missing or null on a stored item.
        content_lower = (item.get("content") or "").lower()
        match_count = sum(1 for term in terms if term in content_lower)
        results.append(SearchResult(
            document=self._from_cosmos(item),
            score=match_count / len(terms),
            rank=0,
        ))

    results.sort(key=lambda r: r.score, reverse=True)
    top = results[:top_k]
    for rank, result in enumerate(top):
        result.rank = rank
    return top
363
def _hybrid_search(
    self,
    query: str,
    query_embedding: List[float],
    top_k: int,
    filters: Optional[Dict[str, Any]],
) -> List[SearchResult]:
    """
    Combine vector and keyword results with equal (50/50) weighting.

    Both retrievers are asked for ``top_k`` results.  Each list is divided by
    its own maximum score so the two signals are on a comparable scale.  The
    halves are then summed per document id.  A document found by only one
    retriever receives only that retriever's half.

    Args:
        query: Plain-text query used for the keyword half.
        query_embedding: Query vector used for the vector half.
        top_k: Maximum number of combined results to return.
        filters: Optional filter dict passed to both retrievers.

    Returns:
        At most ``top_k`` results ordered by combined score, with ``rank``
        starting at 0.
    """
    vector_results = self._vector_search(query_embedding, top_k, filters)
    keyword_results = self._keyword_search(query, top_k, filters)

    def _max(results: List[SearchResult]) -> float:
        # Only a positive maximum is a safe divisor.  Vector scores are
        # computed as ``1 - distance`` and can be negative.  Dividing by a
        # negative maximum flips every sign and inverts the ranking, and a
        # zero maximum would divide by zero.  In both cases, and for an empty
        # list, leave the scores unscaled.
        peak = max((r.score for r in results), default=0.0)
        return peak if peak > 0 else 1.0

    v_max = _max(vector_results)
    k_max = _max(keyword_results)

    combined: Dict[str, float] = {}
    doc_map: Dict[str, Document] = {}

    for r in vector_results:
        combined[r.document.id] = 0.5 * (r.score / v_max)
        doc_map[r.document.id] = r.document

    for r in keyword_results:
        combined[r.document.id] = (
            combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
        )
        # Keep the vector-side Document object when both retrievers found it.
        doc_map.setdefault(r.document.id, r.document)

    sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
    return [
        SearchResult(
            document=doc_map[doc_id],
            score=combined[doc_id],
            rank=rank,
        )
        for rank, doc_id in enumerate(sorted_ids)
    ]
403
404    # ------------------------------------------------------------------
405    # Document lifecycle
406    # ------------------------------------------------------------------
407
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID, skipping any that cannot be deleted.

    Each ID is used both as the item id and as the partition key value.
    An exception raised while deleting one ID is logged at debug level and
    does not stop the remaining deletions.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        The number of delete calls that succeeded.
    """
    container = self._container
    removed = 0
    for target_id in document_ids:
        try:
            container.delete_item(item=target_id, partition_key=target_id)
        except Exception:
            logger.debug("Document '%s' not found for deletion", target_id)
        else:
            removed += 1
    return removed
426
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Apply a partial update to a stored document and write it back.

    The stored item is read (id doubles as partition key) and converted with
    ``_from_cosmos``.  The non-``None`` arguments are applied, and the result
    is upserted via ``_to_cosmos``.

    Args:
        document_id: ID of the document to update.
        content: Replacement text content, if given.
        embedding: Replacement embedding vector, if given.
        metadata: Key-value pairs merged into the existing metadata dict,
            if given.

    Returns:
        ``False`` when the read raises any exception, otherwise ``True``
        after the upsert.
    """
    try:
        stored = self._container.read_item(item=document_id, partition_key=document_id)
    except Exception:
        return False

    document = self._from_cosmos(stored)
    for field_name, new_value in (("content", content), ("embedding", embedding)):
        if new_value is not None:
            setattr(document, field_name, new_value)
    if metadata is not None:
        document.metadata.update(metadata)

    self._container.upsert_item(self._to_cosmos(document))
    return True
463
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Fetch one document by ID.

    The ID is used both as the item id and as the partition key value.  An
    exception from either the read or the ``_from_cosmos`` conversion yields
    ``None``.

    Args:
        document_id: The document's unique identifier.

    Returns:
        The converted :class:`Document` (or configured subclass), or ``None``.
    """
    found: Optional[Document] = None
    try:
        found = self._from_cosmos(
            self._container.read_item(item=document_id, partition_key=document_id)
        )
    except Exception:
        found = None
    return found
479
def count(self) -> int:
    """Return the number of documents in the container (0 if the query yields nothing)."""
    rows = list(
        self._container.query_items(
            query="SELECT VALUE COUNT(1) FROM c",
            enable_cross_partition_query=True,
        )
    )
    if not rows:
        return 0
    return rows[0]
488
def clear(self) -> None:
    """
    Remove **all** documents from the container.

    All item ids are collected first (the query result is fully materialised
    before any delete is issued).  Each item is then deleted using its id as
    the partition key value.  A failed delete does not stop the others.
    Each failure is logged at debug level, and a partial clear is reported
    in the final warning instead of claiming success.

    Warning:
        This operation is irreversible.
    """
    query_text = "SELECT c.id FROM c"
    items = list(self._container.query_items(
        query=query_text,
        enable_cross_partition_query=True,
    ))
    failed = 0
    for item in items:
        doc_id = item["id"]
        try:
            self._container.delete_item(item=doc_id, partition_key=doc_id)
        except Exception as exc:
            # Best-effort: keep going, but never drop the error silently.
            failed += 1
            logger.debug("Failed to delete document '%s' during clear: %s", doc_id, exc)
    if failed:
        logger.warning(
            "Cleared %d of %d document(s) from Cosmos DB container '%s'; "
            "%d deletion(s) failed",
            len(items) - failed,
            len(items),
            self._container_name,
            failed,
        )
    else:
        logger.warning(
            "Cleared all documents from Cosmos DB container '%s'",
            self._container_name,
        )
510
def close(self) -> None:
    """
    Release resources held by the store; currently a no-op.

    Nothing is closed here.  The existing rationale is that the azure-cosmos
    ``CosmosClient`` does not require an explicit close.  The method exists
    so this store exposes the same lifecycle API as the other vector stores.
    """
    return None

Azure Cosmos DB NoSQL API vector store.

Stores documents in a Cosmos DB NoSQL container and queries them using Cosmos DB's vector distance functions and SQL-like query language.

Features:

  • Vector search using Cosmos DB vector indexing (cosine distance)
  • Keyword search via case-insensitive SQL CONTAINS queries (one per query term)
  • Hybrid search (vector + keyword, equal weighting)
  • Metadata filtering via SQL WHERE clauses
  • Custom document_type support (any Document subclass)
  • Works against a pre-provisioned container (create it, including its vector embedding policy, with CosmosDBIndexBuilder; the store never creates containers)

Usage::

from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore

store = AzureCosmosDBVectorStore(
    endpoint="https://your-account.documents.azure.com:443/",
    key="your-key",
    database_name="rag_db",
    container_name="documents",
    embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
AzureCosmosDBVectorStore( endpoint: str, key: str, database_name: str, container_name: str, embedding_dimension: int = 1536, document_type: type = <class 'Document'>, ssl_cert_path: Optional[str] = None)
 62    def __init__(
 63        self,
 64        endpoint: str,
 65        key: str,
 66        database_name: str,
 67        container_name: str,
 68        embedding_dimension: int = 1536,
 69        document_type: type = Document,
 70        ssl_cert_path: Optional[str] = None,
 71    ) -> None:
 72        """
 73        Initialise the Cosmos DB NoSQL vector store.
 74
 75        The database and container must be pre-provisioned using
 76        ``CosmosDBIndexBuilder`` before first use.  The constructor only
 77        establishes client connections — it never creates or modifies
 78        databases or containers.
 79
 80        Args:
 81            endpoint: Cosmos DB account endpoint
 82                (e.g. ``https://your-account.documents.azure.com:443/``).
 83            key: Cosmos DB account key or resource token.
 84            database_name: Name of the Cosmos DB database.
 85            container_name: Name of the container.
 86            embedding_dimension: Dimension of vector embeddings (default 1536).
 87            document_type: Document class to use when deserialising results.
 88                Pass a custom subclass to support additional typed fields.
 89            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 90                verification.  Useful in corporate environments with custom
 91                certificate authorities.
 92        """
 93        try:
 94            from azure.cosmos import CosmosClient  # noqa: F401
 95        except ImportError as exc:
 96            raise ImportError(
 97                "azure-cosmos is required for AzureCosmosDBVectorStore. "
 98                "Install it with:  pip install azure-cosmos"
 99            ) from exc
100
101        self.embedding_dimension = embedding_dimension
102        self.document_type = document_type
103        self._container_name = container_name
104
105        # Build client with optional custom SSL cert
106        import ssl as _ssl
107        connection_kwargs: Dict[str, Any] = {}
108        if ssl_cert_path:
109            ssl_context = _ssl.create_default_context(cafile=ssl_cert_path)
110            connection_kwargs["connection_verify"] = ssl_cert_path
111
112        from azure.cosmos import CosmosClient
113        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)
114
115        # Database and container — pre-provisioned by developer
116        self._database = self._client.get_database_client(database_name)
117        self._container = self._database.get_container_client(container_name)

Initialise the Cosmos DB NoSQL vector store.

The database and container must be pre-provisioned using CosmosDBIndexBuilder before first use. The constructor only establishes client connections — it never creates or modifies databases or containers.

Args: endpoint: Cosmos DB account endpoint (e.g. https://your-account.documents.azure.com:443/). key: Cosmos DB account key or resource token. database_name: Name of the Cosmos DB database. container_name: Name of the container. embedding_dimension: Dimension of vector embeddings (default 1536). document_type: Document class to use when deserialising results. Pass a custom subclass to support additional typed fields. ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS verification. Useful in corporate environments with custom certificate authorities.

embedding_dimension
document_type
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True,
) -> List[str]:
    """
    Upsert documents into the Cosmos DB container, one at a time.

    Despite its name, ``generate_embeddings`` never computes vectors.  When it
    is ``True``, a document with ``embedding is None`` is rejected.  Documents
    are validated and upserted in order, so a ``ValueError`` part-way through
    leaves the earlier documents already written.

    Args:
        documents: Documents to add; each needs a truthy ``id`` and
            ``content``.
        generate_embeddings: Require a pre-computed ``embedding`` on every
            document (default ``True``).

    Returns:
        IDs of the upserted documents, in input order (empty for empty input).

    Raises:
        ValueError: On a missing id/content, or a missing embedding when
            ``generate_embeddings`` is ``True``.
    """
    if not documents:
        return []

    upserted: List[str] = []
    for document in documents:
        if not (document.id and document.content):
            raise ValueError(f"Document must have id and content: {document}")
        if generate_embeddings and document.embedding is None:
            raise ValueError(
                f"Document '{document.id}' has no embedding.  "
                "Generate embeddings before calling add_documents()."
            )
        self._container.upsert_item(self._to_cosmos(document))
        upserted.append(document.id)

    logger.info(
        "Upserted %d document(s) into Cosmos DB container '%s'",
        len(upserted),
        self._container_name,
    )
    return upserted

Upsert documents into the Cosmos DB container.

Args: documents: Documents to add. Each must have id and content set. If generate_embeddings is True (default) every document must also carry a pre-computed embedding. generate_embeddings: When True the method validates that all documents already have embeddings rather than generating them itself (generation is the caller's responsibility).

Returns: List of document IDs that were successfully upserted.

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Run a vector, keyword, or hybrid query against the container.

    Args:
        query: Plain-text query; needed for ``"keyword"`` and ``"hybrid"``.
        query_embedding: Query vector; needed for ``"vector"`` and
            ``"hybrid"``.
        top_k: Maximum number of results to return.
        filters: Optional filter dict.  Supports equality
            (``{"key": value}``) and range (``{"key": (">=", value)}``)
            operators.
        search_type: ``"vector"`` (default), ``"keyword"``, or ``"hybrid"``.

    Returns:
        Ranked list of :class:`SearchResult` objects.

    Raises:
        ValueError: For an unknown ``search_type``, or when an input the
            chosen mode needs is ``None``.
    """
    if search_type not in ("vector", "keyword", "hybrid"):
        raise ValueError(
            f"Invalid search_type '{search_type}'. "
            "Must be 'vector', 'keyword', or 'hybrid'."
        )

    # With search_type validated, "not keyword" means vector/hybrid and
    # "not vector" means keyword/hybrid.
    uses_embedding = search_type != "keyword"
    uses_text = search_type != "vector"
    if uses_embedding and query_embedding is None:
        raise ValueError(f"search_type='{search_type}' requires query_embedding")
    if uses_text and query is None:
        raise ValueError(f"search_type='{search_type}' requires query text")

    if search_type == "keyword":
        return self._keyword_search(query, top_k, filters)
    if search_type == "vector":
        return self._vector_search(query_embedding, top_k, filters)
    return self._hybrid_search(query, query_embedding, top_k, filters)

Search the Cosmos DB container.

Args: query: Plain-text query string. Required for keyword and hybrid search types. query_embedding: Pre-computed query vector. Required for vector and hybrid search types. top_k: Maximum number of results to return. filters: Optional filter dict. Supports equality ({"key": value}) and range ({"key": (">=", value)}) operators. search_type: One of "vector", "keyword", or "hybrid".

Returns: Ranked list of SearchResult objects.

def delete_documents(self, document_ids: List[str]) -> int:
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID, skipping any that cannot be deleted.

    Each ID is used both as the item id and as the partition key value.
    An exception raised while deleting one ID is logged at debug level and
    does not stop the remaining deletions.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        The number of delete calls that succeeded.
    """
    container = self._container
    removed = 0
    for target_id in document_ids:
        try:
            container.delete_item(item=target_id, partition_key=target_id)
        except Exception:
            logger.debug("Document '%s' not found for deletion", target_id)
        else:
            removed += 1
    return removed

Delete documents by ID.

Args: document_ids: IDs of the documents to remove.

Returns: Number of documents actually deleted.

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Apply a partial update to a stored document and write it back.

    The stored item is read (id doubles as partition key) and converted with
    ``_from_cosmos``.  The non-``None`` arguments are applied, and the result
    is upserted via ``_to_cosmos``.

    Args:
        document_id: ID of the document to update.
        content: Replacement text content, if given.
        embedding: Replacement embedding vector, if given.
        metadata: Key-value pairs merged into the existing metadata dict,
            if given.

    Returns:
        ``False`` when the read raises any exception, otherwise ``True``
        after the upsert.
    """
    try:
        stored = self._container.read_item(item=document_id, partition_key=document_id)
    except Exception:
        return False

    document = self._from_cosmos(stored)
    for field_name, new_value in (("content", content), ("embedding", embedding)):
        if new_value is not None:
            setattr(document, field_name, new_value)
    if metadata is not None:
        document.metadata.update(metadata)

    self._container.upsert_item(self._to_cosmos(document))
    return True

Update an existing document.

Args: document_id: ID of the document to update. content: New text content (optional). embedding: New embedding vector (optional). metadata: Metadata key-value pairs to merge into existing metadata (optional).

Returns: True if the document was found and updated, False otherwise.

def get_document( self, document_id: str) -> Optional[Document]:
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Fetch one document by ID.

    The ID is used both as the item id and as the partition key value.  An
    exception from either the read or the ``_from_cosmos`` conversion yields
    ``None``.

    Args:
        document_id: The document's unique identifier.

    Returns:
        The converted :class:`Document` (or configured subclass), or ``None``.
    """
    found: Optional[Document] = None
    try:
        found = self._from_cosmos(
            self._container.read_item(item=document_id, partition_key=document_id)
        )
    except Exception:
        found = None
    return found

Retrieve a single document by ID.

Args: document_id: The document's unique identifier.

Returns: The Document (or subclass) if found, otherwise None.

def count(self) -> int:
def count(self) -> int:
    """Return the number of documents in the container (0 if the query yields nothing)."""
    rows = list(
        self._container.query_items(
            query="SELECT VALUE COUNT(1) FROM c",
            enable_cross_partition_query=True,
        )
    )
    if not rows:
        return 0
    return rows[0]

Return the total number of documents in the container.

def clear(self) -> None:
def clear(self) -> None:
    """
    Remove **all** documents from the container.

    All item ids are collected first (the query result is fully materialised
    before any delete is issued).  Each item is then deleted using its id as
    the partition key value.  A failed delete does not stop the others.
    Each failure is logged at debug level, and a partial clear is reported
    in the final warning instead of claiming success.

    Warning:
        This operation is irreversible.
    """
    query_text = "SELECT c.id FROM c"
    items = list(self._container.query_items(
        query=query_text,
        enable_cross_partition_query=True,
    ))
    failed = 0
    for item in items:
        doc_id = item["id"]
        try:
            self._container.delete_item(item=doc_id, partition_key=doc_id)
        except Exception as exc:
            # Best-effort: keep going, but never drop the error silently.
            failed += 1
            logger.debug("Failed to delete document '%s' during clear: %s", doc_id, exc)
    if failed:
        logger.warning(
            "Cleared %d of %d document(s) from Cosmos DB container '%s'; "
            "%d deletion(s) failed",
            len(items) - failed,
            len(items),
            self._container_name,
            failed,
        )
    else:
        logger.warning(
            "Cleared all documents from Cosmos DB container '%s'",
            self._container_name,
        )

Remove all documents from the container.

Warning: This operation is irreversible.

def close(self) -> None:
def close(self) -> None:
    """
    Release resources held by the store; currently a no-op.

    Nothing is closed here.  The existing rationale is that the azure-cosmos
    ``CosmosClient`` does not require an explicit close.  The method exists
    so this store exposes the same lifecycle API as the other vector stores.
    """
    return None

Close the underlying Cosmos DB client.

class MongoDBVectorStore(gmf_forge_ai_data.BaseVectorStore):
 35class MongoDBVectorStore(BaseVectorStore):
 36    """
 37    MongoDB Atlas Vector Search vector store.
 38
 39    Stores documents in a MongoDB collection and queries them using the
 40    ``$vectorSearch`` aggregation stage.  Hybrid search is implemented by
 41    running vector and text searches separately and combining their normalised
 42    scores with a 50 / 50 weight.
 43
 44    Features:
 45    - Atlas Vector Search (approximate KNN, configurable candidates)
 46    - Full-text keyword search via ``$text`` index
 47    - Hybrid search (vector + keyword, equal weighting)
 48    - Metadata pre-filtering on vector search
 49    - Custom ``document_type`` support (any Document subclass)
 50
 51    Note:
 52        The Atlas Vector Search index (``$vectorSearch``) must be created
 53        outside of this class — through the MongoDB Atlas UI, Atlas CLI, or
 54        Atlas Admin API.  The ``vector_index_name`` parameter must match the
 55        name you gave that index.
 56
 57    Usage::
 58
 59        from gmf_forge_ai_data.vector_stores import MongoDBVectorStore
 60
 61        store = MongoDBVectorStore(
 62            connection_string="mongodb+srv://...",
 63            database_name="rag_db",
 64            collection_name="documents",
 65            vector_index_name="vector_index",
 66            embedding_dimension=1536,
 67        )
 68        store.add_documents(docs_with_embeddings)
 69        results = store.search(query_embedding=embedding, top_k=5)
 70    """
 71
 72    def __init__(
 73        self,
 74        connection_string: str,
 75        database_name: str,
 76        collection_name: str,
 77        vector_index_name: str = "vector_index",
 78        embedding_dimension: int = 1536,
 79        document_type: type = Document,
 80        ssl_cert_path: Optional[str] = None,
 81    ) -> None:
 82        """
 83        Initialise the MongoDB Atlas vector store.
 84
 85        Args:
 86            connection_string: MongoDB Atlas connection string
 87                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
 88            database_name: Name of the MongoDB database.
 89            collection_name: Name of the collection.
 90            vector_index_name: Name of the Atlas Vector Search index defined
 91                on this collection.  Defaults to ``"vector_index"``.
 92            embedding_dimension: Dimension of vector embeddings.  Must match
 93                the dimension configured in the Atlas Vector Search index
 94                (default 1536).
 95            document_type: Document class to use when deserialising results.
 96                Pass a custom subclass to support additional typed fields.
 97            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 98                verification.  Useful in corporate environments with custom
 99                certificate authorities.
100        """
101        try:
102            import pymongo  # noqa: F401 — checked at construction time
103        except ImportError as exc:
104            raise ImportError(
105                "pymongo is required for MongoDBVectorStore. "
106                "Install it with:  pip install pymongo"
107            ) from exc
108
109        self.embedding_dimension = embedding_dimension
110        self.document_type = document_type
111        self.vector_index_name = vector_index_name
112        self._collection_name = collection_name
113
114        client_kwargs: Dict[str, Any] = {}
115        if ssl_cert_path:
116            client_kwargs["tlsCAFile"] = ssl_cert_path
117
118        import pymongo
119
120        self._client: pymongo.MongoClient = pymongo.MongoClient(
121            connection_string, **client_kwargs
122        )
123        self._db = self._client[database_name]
124        self._collection = self._db[collection_name]
125
126    # ------------------------------------------------------------------
127    # Serialisation helpers
128    # ------------------------------------------------------------------
129
130    @staticmethod
131    def _serialize_value(obj: Any) -> Any:
132        """Recursively convert datetime objects to ISO strings."""
133        if isinstance(obj, datetime):
134            return obj.isoformat()
135        if isinstance(obj, dict):
136            return {k: MongoDBVectorStore._serialize_value(v) for k, v in obj.items()}
137        if isinstance(obj, (list, tuple)):
138            return [MongoDBVectorStore._serialize_value(i) for i in obj]
139        return obj
140
def _to_mongo(self, doc: Document) -> Dict[str, Any]:
    """
    Build the MongoDB record stored for *doc*.

    The record holds ``_id``/``id``, ``content``, ``embedding`` (``[]`` when
    unset), an ISO ``timestamp`` (or ``None``), the raw ``metadata``, and a
    JSON snapshot of ``doc.to_dict()`` in ``document_data``.  Any extra
    fields from ``to_dict()`` beyond the base ones are also copied to the top
    level so they can be used in ``$vectorSearch`` filters and ``$match``.
    """
    snapshot = self._serialize_value(doc.to_dict())
    record: Dict[str, Any] = {
        "_id": doc.id,
        "id": doc.id,
        "content": doc.content,
        "embedding": doc.embedding or [],
        "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
        "metadata": doc.metadata,
        "document_data": json.dumps(snapshot),
    }
    reserved = {"id", "content", "embedding", "timestamp", "metadata"}
    record.update((name, val) for name, val in snapshot.items() if name not in reserved)
    return record
161
def _from_mongo(self, mongo_doc: Dict[str, Any]) -> Document:
    """Rebuild a ``document_type`` instance from the JSON in ``document_data`` (``{}`` if absent)."""
    payload = mongo_doc.get("document_data", "{}")
    data = json.loads(payload)
    return self.document_type.from_dict(data)
166
167    # ------------------------------------------------------------------
168    # Filter helpers
169    # ------------------------------------------------------------------
170
171    @staticmethod
172    def _build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
173        """Translate a user-facing filter dict to a MongoDB query expression."""
174        if not filters:
175            return None
176        op_map = {">=": "$gte", "<=": "$lte", ">": "$gt", "<": "$lt", "!=": "$ne"}
177        match: Dict[str, Any] = {}
178        for key, value in filters.items():
179            if isinstance(value, tuple) and len(value) == 2:
180                op, val = value
181                match[key] = {op_map[op]: val} if op in op_map else val
182            else:
183                match[key] = value
184        return match
185
186    # ------------------------------------------------------------------
187    # BaseVectorStore interface
188    # ------------------------------------------------------------------
189
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True,
) -> List[str]:
    """
    Upsert documents into the MongoDB collection, one at a time.

    Each document is written with ``replace_one(..., upsert=True)`` keyed on
    ``_id``.  Despite its name, ``generate_embeddings`` never computes
    vectors.  When it is ``True``, a document with ``embedding is None`` is
    rejected.  Validation happens per document in order, so an error
    part-way through leaves earlier documents already written.

    Args:
        documents: Documents to add; each needs a truthy ``id`` and
            ``content``.
        generate_embeddings: Require a pre-computed ``embedding`` on every
            document (default ``True``).

    Returns:
        IDs of the upserted documents, in input order (empty for empty input).

    Raises:
        ValueError: On a missing id/content, or a missing embedding when
            ``generate_embeddings`` is ``True``.
    """
    if not documents:
        return []

    upserted: List[str] = []
    for document in documents:
        if not (document.id and document.content):
            raise ValueError(f"Document must have id and content: {document}")
        if generate_embeddings and document.embedding is None:
            raise ValueError(
                f"Document '{document.id}' has no embedding.  "
                "Generate embeddings before calling add_documents()."
            )
        record = self._to_mongo(document)
        self._collection.replace_one({"_id": record["_id"]}, record, upsert=True)
        upserted.append(document.id)

    logger.info(
        "Upserted %d document(s) into MongoDB collection '%s'",
        len(upserted),
        self._collection_name,
    )
    return upserted
236
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Run a vector, keyword, or hybrid query against the collection.

    Args:
        query: Plain-text query; needed for ``"keyword"`` and ``"hybrid"``.
        query_embedding: Query vector; needed for ``"vector"`` and
            ``"hybrid"``.
        top_k: Maximum number of results to return.
        filters: Optional metadata filter dict.  Supports equality
            (``{"key": value}``) and range (``{"key": (">=", value)}``)
            operators.
        search_type: ``"vector"`` (default), ``"keyword"``, or ``"hybrid"``.

    Returns:
        Ranked list of :class:`SearchResult` objects.

    Raises:
        ValueError: For an unknown ``search_type``, or when an input the
            chosen mode needs is ``None``.
    """
    if search_type not in ("vector", "keyword", "hybrid"):
        raise ValueError(
            f"Invalid search_type '{search_type}'. "
            "Must be 'vector', 'keyword', or 'hybrid'."
        )

    # With search_type validated, "not keyword" means vector/hybrid and
    # "not vector" means keyword/hybrid.
    uses_embedding = search_type != "keyword"
    uses_text = search_type != "vector"
    if uses_embedding and query_embedding is None:
        raise ValueError(f"search_type='{search_type}' requires query_embedding")
    if uses_text and query is None:
        raise ValueError(f"search_type='{search_type}' requires query text")

    if search_type == "keyword":
        return self._keyword_search(query, top_k, filters)
    if search_type == "vector":
        return self._vector_search(query_embedding, top_k, filters)
    return self._hybrid_search(query, query_embedding, top_k, filters)
277
278    # --- search helpers -----------------------------------------------
279
def _vector_search(
    self,
    query_embedding: List[float],
    top_k: int,
    filters: Optional[Dict[str, Any]],
) -> List[SearchResult]:
    """
    Run an Atlas ``$vectorSearch`` aggregation and wrap the hits.

    ``numCandidates`` is ``max(top_k * 10, 100)``.  Filters from
    ``_build_filter`` are applied as the stage's pre-filter.  When the error
    text contains one of the markers below, it is re-raised as a
    ``RuntimeError`` explaining how to create the index.  Any other
    exception propagates unchanged.

    Args:
        query_embedding: Query vector matched against the ``embedding`` field.
        top_k: Maximum number of results (the stage ``limit``).
        filters: Optional filter dict.

    Returns:
        Results in pipeline order, with ``score`` taken from
        ``vectorSearchScore`` and ``rank`` starting at 0.
    """
    stage: Dict[str, Any] = dict(
        index=self.vector_index_name,
        path="embedding",
        queryVector=query_embedding,
        # Recall vs. latency trade-off; top_k * 10 with a floor of 100.
        numCandidates=max(top_k * 10, 100),
        limit=top_k,
    )
    pre_filter = self._build_filter(filters)
    if pre_filter:
        stage["filter"] = pre_filter

    projection = {
        "_id": 1,
        "document_data": 1,
        "score": {"$meta": "vectorSearchScore"},
    }
    pipeline = [{"$vectorSearch": stage}, {"$project": projection}]

    # Error-text fragments that this code treats as "index not available".
    unavailable_markers = ("localhost:28000", "HostUnreachable", "PlanExecutor")
    try:
        hits: List[SearchResult] = []
        for position, raw in enumerate(self._collection.aggregate(pipeline)):
            hits.append(
                SearchResult(
                    document=self._from_mongo(raw),
                    score=float(raw.get("score", 0.0)),
                    rank=position,
                )
            )
        return hits
    except Exception as exc:
        message = str(exc)
        if any(marker in message for marker in unavailable_markers):
            raise RuntimeError(
                f"Atlas Vector Search index '{self.vector_index_name}' is not available. "
                "Create it in the Atlas UI (collection → Search Indexes → Create Search Index → "
                "Atlas Vector Search) with field='embedding', "
                f"numDimensions={self.embedding_dimension}, similarity=cosine, "
                f"indexName='{self.vector_index_name}'."
            ) from exc
        raise
332
333    def _keyword_search(
334        self,
335        query: str,
336        top_k: int,
337        filters: Optional[Dict[str, Any]],
338    ) -> List[SearchResult]:
339        """Execute a full-text search via the MongoDB ``$text`` operator.
340
341        Requires a text index on the ``content`` field to be created by the
342        developer before use::
343
344            db.collection.createIndex({ "content": "text" })
345
346        or via the Atlas UI / CLI / Admin API.
347        """
348        mongo_filter: Dict[str, Any] = {"$text": {"$search": query}}
349        pre_filter = self._build_filter(filters)
350        if pre_filter:
351            mongo_filter.update(pre_filter)
352
353        cursor = (
354            self._collection.find(
355                mongo_filter,
356                {"document_data": 1, "textScore": {"$meta": "textScore"}},
357            )
358            .sort([("textScore", {"$meta": "textScore"})])
359            .limit(top_k)
360        )
361
362        try:
363            return [
364                SearchResult(
365                    document=self._from_mongo(raw),
366                    score=float(raw.get("textScore", 0.0)),
367                    rank=rank,
368                )
369                for rank, raw in enumerate(cursor)
370            ]
371        except Exception as exc:
372            msg = str(exc)
373            if "text index required" in msg.lower() or "$text" in msg:
374                raise RuntimeError(
375                    "No text index found on the collection. "
376                    "Create one before using keyword or hybrid search: "
377                    f"db.{self._collection.name}.createIndex({{ 'content': 'text' }})"
378                ) from exc
379            raise
380
381    def _hybrid_search(
382        self,
383        query: str,
384        query_embedding: List[float],
385        top_k: int,
386        filters: Optional[Dict[str, Any]],
387    ) -> List[SearchResult]:
388        """Combine vector and keyword results with equal (50/50) weighting."""
389        vector_results = self._vector_search(query_embedding, top_k, filters)
390        keyword_results = self._keyword_search(query, top_k, filters)
391
392        def _max(results: List[SearchResult]) -> float:
393            return max((r.score for r in results), default=1.0) or 1.0
394
395        v_max = _max(vector_results)
396        k_max = _max(keyword_results)
397
398        combined: Dict[str, float] = {}
399        doc_map: Dict[str, Document] = {}
400
401        for r in vector_results:
402            combined[r.document.id] = 0.5 * (r.score / v_max)
403            doc_map[r.document.id] = r.document
404
405        for r in keyword_results:
406            combined[r.document.id] = (
407                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
408            )
409            doc_map.setdefault(r.document.id, r.document)
410
411        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
412        return [
413            SearchResult(
414                document=doc_map[doc_id],
415                score=combined[doc_id],
416                rank=rank,
417            )
418            for rank, doc_id in enumerate(sorted_ids)
419        ]
420
421    # ------------------------------------------------------------------
422    # Document lifecycle
423    # ------------------------------------------------------------------
424
425    def delete_documents(self, document_ids: List[str]) -> int:
426        """
427        Delete documents by ID.
428
429        Args:
430            document_ids: IDs of the documents to remove.
431
432        Returns:
433            Number of documents actually deleted.
434        """
435        result = self._collection.delete_many({"_id": {"$in": document_ids}})
436        return result.deleted_count
437
438    def update_document(
439        self,
440        document_id: str,
441        content: Optional[str] = None,
442        embedding: Optional[List[float]] = None,
443        metadata: Optional[Dict[str, Any]] = None,
444    ) -> bool:
445        """
446        Update an existing document.
447
448        Args:
449            document_id: ID of the document to update.
450            content: New text content (optional).
451            embedding: New embedding vector (optional).
452            metadata: Metadata key-value pairs to *merge* into existing metadata
453                (optional).
454
455        Returns:
456            ``True`` if the document was found and updated, ``False`` otherwise.
457        """
458        raw = self._collection.find_one({"_id": document_id})
459        if raw is None:
460            return False
461
462        doc = self._from_mongo(raw)
463
464        if content is not None:
465            doc.content = content
466        if embedding is not None:
467            doc.embedding = embedding
468        if metadata is not None:
469            doc.metadata.update(metadata)
470
471        self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
472        return True
473
474    def get_document(self, document_id: str) -> Optional[Document]:
475        """
476        Retrieve a single document by ID.
477
478        Args:
479            document_id: The document's unique identifier.
480
481        Returns:
482            The :class:`Document` (or subclass) if found, otherwise ``None``.
483        """
484        raw = self._collection.find_one({"_id": document_id})
485        return None if raw is None else self._from_mongo(raw)
486
487    def count(self) -> int:
488        """Return the total number of documents in the collection."""
489        return self._collection.count_documents({})
490
491    def clear(self) -> None:
492        """
493        Remove **all** documents from the collection.
494
495        Warning:
496            This operation is irreversible.
497        """
498        self._collection.delete_many({})
499        logger.warning(
500            "Cleared all documents from MongoDB collection '%s'",
501            self._collection_name,
502        )
503
504    def close(self) -> None:
505        """Close the underlying MongoDB client connection."""
506        self._client.close()

MongoDB Atlas Vector Search vector store.

Stores documents in a MongoDB collection and queries them using the $vectorSearch aggregation stage. Hybrid search runs the vector and text searches separately, divides each result list's scores by that list's best score, and combines them with a 50/50 weight.

Features:

  • Atlas Vector Search (approximate KNN, configurable candidates)
  • Full-text keyword search via $text index
  • Hybrid search (vector + keyword, equal weighting)
  • Metadata filtering (pre-filtering on vector search; also applied to keyword and hybrid search)
  • Custom document_type support (any Document subclass)

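Note: The Atlas Vector Search index ($vectorSearch) must be created outside of this class — through the MongoDB Atlas UI, Atlas CLI, or Atlas Admin API. The vector_index_name parameter must match the name you gave that index. Keyword and hybrid search additionally require a text index on the content field (e.g. db.collection.createIndex({ "content": "text" })), which must also be created beforehand.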

Usage::

from gmf_forge_ai_data.vector_stores import MongoDBVectorStore

store = MongoDBVectorStore(
    connection_string="mongodb+srv://...",
    database_name="rag_db",
    collection_name="documents",
    vector_index_name="vector_index",
    embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
MongoDBVectorStore( connection_string: str, database_name: str, collection_name: str, vector_index_name: str = 'vector_index', embedding_dimension: int = 1536, document_type: type = <class 'Document'>, ssl_cert_path: Optional[str] = None)
 72    def __init__(
 73        self,
 74        connection_string: str,
 75        database_name: str,
 76        collection_name: str,
 77        vector_index_name: str = "vector_index",
 78        embedding_dimension: int = 1536,
 79        document_type: type = Document,
 80        ssl_cert_path: Optional[str] = None,
 81    ) -> None:
 82        """
 83        Initialise the MongoDB Atlas vector store.
 84
 85        Args:
 86            connection_string: MongoDB Atlas connection string
 87                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
 88            database_name: Name of the MongoDB database.
 89            collection_name: Name of the collection.
 90            vector_index_name: Name of the Atlas Vector Search index defined
 91                on this collection.  Defaults to ``"vector_index"``.
 92            embedding_dimension: Dimension of vector embeddings.  Must match
 93                the dimension configured in the Atlas Vector Search index
 94                (default 1536).
 95            document_type: Document class to use when deserialising results.
 96                Pass a custom subclass to support additional typed fields.
 97            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
 98                verification.  Useful in corporate environments with custom
 99                certificate authorities.
100        """
101        try:
102            import pymongo  # noqa: F401 — checked at construction time
103        except ImportError as exc:
104            raise ImportError(
105                "pymongo is required for MongoDBVectorStore. "
106                "Install it with:  pip install pymongo"
107            ) from exc
108
109        self.embedding_dimension = embedding_dimension
110        self.document_type = document_type
111        self.vector_index_name = vector_index_name
112        self._collection_name = collection_name
113
114        client_kwargs: Dict[str, Any] = {}
115        if ssl_cert_path:
116            client_kwargs["tlsCAFile"] = ssl_cert_path
117
118        import pymongo
119
120        self._client: pymongo.MongoClient = pymongo.MongoClient(
121            connection_string, **client_kwargs
122        )
123        self._db = self._client[database_name]
124        self._collection = self._db[collection_name]

Initialise the MongoDB Atlas vector store.

Args: connection_string: MongoDB Atlas connection string (e.g. mongodb+srv://user:pass@cluster.mongodb.net/). database_name: Name of the MongoDB database. collection_name: Name of the collection. vector_index_name: Name of the Atlas Vector Search index defined on this collection. Defaults to "vector_index". embedding_dimension: Dimension of vector embeddings. Must match the dimension configured in the Atlas Vector Search index (default 1536). document_type: Document class to use when deserialising results. Pass a custom subclass to support additional typed fields. ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS verification. Useful in corporate environments with custom certificate authorities.

embedding_dimension
document_type
vector_index_name
def add_documents( self, documents: List[Document], generate_embeddings: bool = True) -> List[str]:
190    def add_documents(
191        self,
192        documents: List[Document],
193        generate_embeddings: bool = True,
194    ) -> List[str]:
195        """
196        Upsert documents into the MongoDB collection.
197
198        Args:
199            documents: Documents to add.  Each must have ``id`` and ``content``
200                set.  If *generate_embeddings* is ``True`` (default) every
201                document must also carry a pre-computed ``embedding``.
202            generate_embeddings: When ``True`` the method validates that all
203                documents already have embeddings rather than generating them
204                itself (generation is the caller's responsibility).
205
206        Returns:
207            List of document IDs that were successfully upserted.
208        """
209        if not documents:
210            return []
211
212        added_ids: List[str] = []
213
214        for doc in documents:
215            if not doc.id or not doc.content:
216                raise ValueError(f"Document must have id and content: {doc}")
217
218            if generate_embeddings and doc.embedding is None:
219                raise ValueError(
220                    f"Document '{doc.id}' has no embedding.  "
221                    "Generate embeddings before calling add_documents()."
222                )
223
224            mongo_doc = self._to_mongo(doc)
225            self._collection.replace_one(
226                {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
227            )
228            added_ids.append(doc.id)
229
230        logger.info(
231            "Upserted %d document(s) into MongoDB collection '%s'",
232            len(added_ids),
233            self._collection_name,
234        )
235        return added_ids

Upsert documents into the MongoDB collection.

Args: documents: Documents to add. Each must have id and content set. If generate_embeddings is True (default) every document must also carry a pre-computed embedding. generate_embeddings: Despite its name, this flag does not generate embeddings; when True the method only validates that all documents already have embeddings (generation is the caller's responsibility).

Returns: List of document IDs that were successfully upserted.

def search( self, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, search_type: str = 'vector') -> List[SearchResult]:
237    def search(
238        self,
239        query: Optional[str] = None,
240        query_embedding: Optional[List[float]] = None,
241        top_k: int = 5,
242        filters: Optional[Dict[str, Any]] = None,
243        search_type: str = "vector",
244    ) -> List[SearchResult]:
245        """
246        Search the MongoDB collection.
247
248        Args:
249            query: Plain-text query string.  Required for ``keyword`` and
250                ``hybrid`` search types.
251            query_embedding: Pre-computed query vector.  Required for
252                ``vector`` and ``hybrid`` search types.
253            top_k: Maximum number of results to return.
254            filters: Optional metadata filter dict.  Supports equality
255                (``{"key": value}``) and range (``{"key": (">=", value)}``)
256                operators.
257            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.
258
259        Returns:
260            Ranked list of :class:`SearchResult` objects.
261        """
262        if search_type not in ("vector", "keyword", "hybrid"):
263            raise ValueError(
264                f"Invalid search_type '{search_type}'. "
265                "Must be 'vector', 'keyword', or 'hybrid'."
266            )
267        if search_type in ("vector", "hybrid") and query_embedding is None:
268            raise ValueError(f"search_type='{search_type}' requires query_embedding")
269        if search_type in ("keyword", "hybrid") and query is None:
270            raise ValueError(f"search_type='{search_type}' requires query text")
271
272        if search_type == "vector":
273            return self._vector_search(query_embedding, top_k, filters)
274        if search_type == "keyword":
275            return self._keyword_search(query, top_k, filters)
276        return self._hybrid_search(query, query_embedding, top_k, filters)

Search the MongoDB collection.

Args: query: Plain-text query string. Required for keyword and hybrid search types. query_embedding: Pre-computed query vector. Required for vector and hybrid search types. top_k: Maximum number of results to return. filters: Optional metadata filter dict. Supports equality ({"key": value}) and range ({"key": (">=", value)}) operators. search_type: One of "vector", "keyword", or "hybrid".

Returns: Ranked list of SearchResult objects.

def delete_documents(self, document_ids: List[str]) -> int:
425    def delete_documents(self, document_ids: List[str]) -> int:
426        """
427        Delete documents by ID.
428
429        Args:
430            document_ids: IDs of the documents to remove.
431
432        Returns:
433            Number of documents actually deleted.
434        """
435        result = self._collection.delete_many({"_id": {"$in": document_ids}})
436        return result.deleted_count

Delete documents by ID.

Args: document_ids: IDs of the documents to remove.

Returns: Number of documents actually deleted.

def update_document( self, document_id: str, content: Optional[str] = None, embedding: Optional[List[float]] = None, metadata: Optional[Dict[str, Any]] = None) -> bool:
438    def update_document(
439        self,
440        document_id: str,
441        content: Optional[str] = None,
442        embedding: Optional[List[float]] = None,
443        metadata: Optional[Dict[str, Any]] = None,
444    ) -> bool:
445        """
446        Update an existing document.
447
448        Args:
449            document_id: ID of the document to update.
450            content: New text content (optional).
451            embedding: New embedding vector (optional).
452            metadata: Metadata key-value pairs to *merge* into existing metadata
453                (optional).
454
455        Returns:
456            ``True`` if the document was found and updated, ``False`` otherwise.
457        """
458        raw = self._collection.find_one({"_id": document_id})
459        if raw is None:
460            return False
461
462        doc = self._from_mongo(raw)
463
464        if content is not None:
465            doc.content = content
466        if embedding is not None:
467            doc.embedding = embedding
468        if metadata is not None:
469            doc.metadata.update(metadata)
470
471        self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
472        return True

Update an existing document.

Args: document_id: ID of the document to update. content: New text content (optional). embedding: New embedding vector (optional). metadata: Metadata key-value pairs to merge into existing metadata (optional).

Returns: True if the document was found and updated, False otherwise.

def get_document( self, document_id: str) -> Optional[Document]:
474    def get_document(self, document_id: str) -> Optional[Document]:
475        """
476        Retrieve a single document by ID.
477
478        Args:
479            document_id: The document's unique identifier.
480
481        Returns:
482            The :class:`Document` (or subclass) if found, otherwise ``None``.
483        """
484        raw = self._collection.find_one({"_id": document_id})
485        return None if raw is None else self._from_mongo(raw)

Retrieve a single document by ID.

Args: document_id: The document's unique identifier.

Returns: The Document (or subclass) if found, otherwise None.

def count(self) -> int:
487    def count(self) -> int:
488        """Return the total number of documents in the collection."""
489        return self._collection.count_documents({})

Return the total number of documents in the collection.

def clear(self) -> None:
491    def clear(self) -> None:
492        """
493        Remove **all** documents from the collection.
494
495        Warning:
496            This operation is irreversible.
497        """
498        self._collection.delete_many({})
499        logger.warning(
500            "Cleared all documents from MongoDB collection '%s'",
501            self._collection_name,
502        )

Remove all documents from the collection.

Warning: This operation is irreversible.

def close(self) -> None:
504    def close(self) -> None:
505        """Close the underlying MongoDB client connection."""
506        self._client.close()

Close the underlying MongoDB client connection.

class BaseRetriever(abc.ABC):
 35class BaseRetriever(ABC):
 36    """
 37    Abstract base class for all retriever implementations.
 38    
 39    Retrievers provide various strategies for finding relevant documents
 40    from a vector store or other data source.
 41    
 42    All implementations must provide:
 43    - retrieve(): Main retrieval method
 44    """
 45    
 46    @abstractmethod
 47    def retrieve(
 48        self,
 49        query: RetrievalQuery
 50    ) -> List[SearchResult]:
 51        """
 52        Retrieve relevant documents based on the query.
 53        
 54        Args:
 55            query: RetrievalQuery containing query parameters
 56        
 57        Returns:
 58            List of SearchResult objects, ordered by relevance
 59        
 60        Raises:
 61            ValueError: If required query parameters are missing
 62        """
 63        pass
 64    
 65    def retrieve_text(
 66        self,
 67        text: str,
 68        top_k: int = 5,
 69        filters: Optional[Dict[str, Any]] = None
 70    ) -> List[SearchResult]:
 71        """
 72        Convenience method for text-based retrieval.
 73        
 74        Args:
 75            text: Query text
 76            top_k: Number of results to retrieve
 77            filters: Optional metadata filters
 78        
 79        Returns:
 80            List of SearchResult objects
 81        """
 82        query = RetrievalQuery(
 83            text=text,
 84            top_k=top_k,
 85            filters=filters
 86        )
 87        return self.retrieve(query)
 88    
 89    def retrieve_embedding(
 90        self,
 91        embedding: List[float],
 92        top_k: int = 5,
 93        filters: Optional[Dict[str, Any]] = None
 94    ) -> List[SearchResult]:
 95        """
 96        Convenience method for embedding-based retrieval.
 97        
 98        Args:
 99            embedding: Query embedding vector
100            top_k: Number of results to retrieve
101            filters: Optional metadata filters
102        
103        Returns:
104            List of SearchResult objects
105        """
106        query = RetrievalQuery(
107            embedding=embedding,
108            top_k=top_k,
109            filters=filters
110        )
111        return self.retrieve(query)

Abstract base class for all retriever implementations.

Retrievers provide various strategies for finding relevant documents from a vector store or other data source.

All implementations must provide:

  • retrieve(): Main retrieval method
@abstractmethod
def retrieve( self, query: RetrievalQuery) -> List[SearchResult]:
46    @abstractmethod
47    def retrieve(
48        self,
49        query: RetrievalQuery
50    ) -> List[SearchResult]:
51        """
52        Retrieve relevant documents based on the query.
53        
54        Args:
55            query: RetrievalQuery containing query parameters
56        
57        Returns:
58            List of SearchResult objects, ordered by relevance
59        
60        Raises:
61            ValueError: If required query parameters are missing
62        """
63        pass

Retrieve relevant documents based on the query.

Args: query: RetrievalQuery containing query parameters

Returns: List of SearchResult objects, ordered by relevance

Raises: ValueError: If required query parameters are missing

def retrieve_text( self, text: str, top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[SearchResult]:
65    def retrieve_text(
66        self,
67        text: str,
68        top_k: int = 5,
69        filters: Optional[Dict[str, Any]] = None
70    ) -> List[SearchResult]:
71        """
72        Convenience method for text-based retrieval.
73        
74        Args:
75            text: Query text
76            top_k: Number of results to retrieve
77            filters: Optional metadata filters
78        
79        Returns:
80            List of SearchResult objects
81        """
82        query = RetrievalQuery(
83            text=text,
84            top_k=top_k,
85            filters=filters
86        )
87        return self.retrieve(query)

Convenience method for text-based retrieval.

Args: text: Query text top_k: Number of results to retrieve filters: Optional metadata filters

Returns: List of SearchResult objects

def retrieve_embedding( self, embedding: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[SearchResult]:
 89    def retrieve_embedding(
 90        self,
 91        embedding: List[float],
 92        top_k: int = 5,
 93        filters: Optional[Dict[str, Any]] = None
 94    ) -> List[SearchResult]:
 95        """
 96        Convenience method for embedding-based retrieval.
 97        
 98        Args:
 99            embedding: Query embedding vector
100            top_k: Number of results to retrieve
101            filters: Optional metadata filters
102        
103        Returns:
104            List of SearchResult objects
105        """
106        query = RetrievalQuery(
107            embedding=embedding,
108            top_k=top_k,
109            filters=filters
110        )
111        return self.retrieve(query)

Convenience method for embedding-based retrieval.

Args: embedding: Query embedding vector top_k: Number of results to retrieve filters: Optional metadata filters

Returns: List of SearchResult objects

@dataclass
class RetrievalQuery:
@dataclass
class RetrievalQuery:
    """
    Represents a retrieval query with optional query text and embedding.

    Every field is optional at construction time. Whether ``text`` or
    ``embedding`` must be present is enforced by the retriever that consumes
    the query (e.g. VectorRetriever raises ValueError when ``embedding`` is
    None; KeywordRetriever does so when ``text`` is None).

    Attributes:
        text: Query text (required for keyword/hybrid search)
        embedding: Query embedding vector (required for vector search)
        top_k: Number of results to retrieve
        filters: Metadata filters to apply
        metadata: Additional query metadata
    """
    # Plain-text query; used by keyword-based retrieval.
    text: Optional[str] = None
    # Pre-computed query vector; used by vector-based retrieval.
    embedding: Optional[List[float]] = None
    # Maximum number of results; forwarded to the vector store's search().
    top_k: int = 5
    # Metadata filters; forwarded unchanged to the vector store's search().
    filters: Optional[Dict[str, Any]] = None
    # Free-form extra information about the query (not read by the retrievers shown here).
    metadata: Optional[Dict[str, Any]] = None

Represents a retrieval query with optional query text and embedding.

Attributes: text: Query text (required for keyword/hybrid search) embedding: Query embedding vector (required for vector search) top_k: Number of results to retrieve filters: Metadata filters to apply metadata: Additional query metadata

RetrievalQuery( text: Optional[str] = None, embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None)
text: Optional[str] = None
embedding: Optional[List[float]] = None
top_k: int = 5
filters: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None
class VectorRetriever(gmf_forge_ai_data.BaseRetriever):
class VectorRetriever(BaseRetriever):
    """
    Vector similarity retriever using cosine similarity.

    Runs a pure vector search against the wrapped store. The query embedding
    is not computed here; obtain it beforehand (e.g. with an embeddings
    provider).

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        # Setup
        embedder = AzureOpenAIEmbeddings(...)
        retriever = VectorRetriever(vector_store)

        # Retrieve
        query_text = "What is machine learning?"
        query_embedding = embedder.embed_text(query_text)

        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Wrap a vector store for embedding-based retrieval.

        Args:
            vector_store: Store whose ``search`` method is queried
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Find documents whose embeddings are closest to the query embedding.

        Args:
            query: RetrievalQuery carrying the embedding, top_k and filters

        Returns:
            SearchResult objects in the order the store returns them

        Raises:
            ValueError: When query.embedding is None
        """
        vector = query.embedding
        if vector is None:
            raise ValueError("VectorRetriever requires query.embedding")

        search_args = dict(
            query_embedding=vector,
            top_k=query.top_k,
            filters=query.filters,
            search_type="vector",
        )
        return self.vector_store.search(**search_args)

Vector similarity retriever using cosine similarity.

Performs pure vector search using embeddings. Requires query embeddings to be provided externally (e.g., using an embeddings provider).

Example:

from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

# Setup
embedder = AzureOpenAIEmbeddings(...)
retriever = VectorRetriever(vector_store)

# Retrieve
query_text = "What is machine learning?"
query_embedding = embedder.embed_text(query_text)

results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
VectorRetriever( vector_store: BaseVectorStore)
    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize vector retriever.

        Args:
            vector_store: Vector store to search; retrieve() calls its
                ``search`` method with ``search_type="vector"``.
        """
        # Kept as-is; no validation or connection happens here.
        self.vector_store = vector_store

Initialize vector retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve( self, query: RetrievalQuery) -> List[SearchResult]:
53    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
54        """
55        Retrieve documents using vector similarity search.
56        
57        Args:
58            query: RetrievalQuery with embedding and parameters
59        
60        Returns:
61            List of SearchResult objects ordered by similarity
62        
63        Raises:
64            ValueError: If query.embedding is None
65        """
66        if query.embedding is None:
67            raise ValueError("VectorRetriever requires query.embedding")
68        
69        return self.vector_store.search(
70            query_embedding=query.embedding,
71            top_k=query.top_k,
72            filters=query.filters,
73            search_type="vector"
74        )

Retrieve documents using vector similarity search.

Args: query: RetrievalQuery with embedding and parameters

Returns: List of SearchResult objects ordered by similarity

Raises: ValueError: If query.embedding is None

class KeywordRetriever(gmf_forge_ai_data.BaseRetriever):
 77class KeywordRetriever(BaseRetriever):
 78    """
 79    Keyword/BM25 retriever using text matching.
 80    
 81    Performs traditional keyword-based search without using embeddings.
 82    Uses BM25 for Azure AI Search, Jaccard similarity for in-memory.
 83    
 84    Example:
 85        ```python
 86        from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery
 87        
 88        retriever = KeywordRetriever(vector_store)
 89        
 90        results = retriever.retrieve_text(
 91            text="machine learning algorithms",
 92            top_k=5
 93        )
 94        ```
 95    """
 96    
 97    def __init__(self, vector_store: BaseVectorStore):
 98        """
 99        Initialize keyword retriever.
100        
101        Args:
102            vector_store: Vector store to search
103        """
104        self.vector_store = vector_store
105    
106    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
107        """
108        Retrieve documents using keyword search.
109        
110        Args:
111            query: RetrievalQuery with text and parameters
112        
113        Returns:
114            List of SearchResult objects ordered by keyword relevance
115        
116        Raises:
117            ValueError: If query.text is None
118        """
119        if query.text is None:
120            raise ValueError("KeywordRetriever requires query.text")
121        
122        return self.vector_store.search(
123            query=query.text,
124            top_k=query.top_k,
125            filters=query.filters,
126            search_type="keyword"
127        )

Keyword/BM25 retriever using text matching.

Performs traditional keyword-based search without using embeddings. Uses BM25 for Azure AI Search and Jaccard similarity for the in-memory store.

Example:

from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery

retriever = KeywordRetriever(vector_store)

results = retriever.retrieve_text(
    text="machine learning algorithms",
    top_k=5
)
KeywordRetriever(vector_store: BaseVectorStore)
 97    def __init__(self, vector_store: BaseVectorStore):
 98        """
 99        Initialize keyword retriever.
100        
101        Args:
102            vector_store: Vector store to search
103        """
104        self.vector_store = vector_store

Initialize keyword retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
106    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
107        """
108        Retrieve documents using keyword search.
109        
110        Args:
111            query: RetrievalQuery with text and parameters
112        
113        Returns:
114            List of SearchResult objects ordered by keyword relevance
115        
116        Raises:
117            ValueError: If query.text is None
118        """
119        if query.text is None:
120            raise ValueError("KeywordRetriever requires query.text")
121        
122        return self.vector_store.search(
123            query=query.text,
124            top_k=query.top_k,
125            filters=query.filters,
126            search_type="keyword"
127        )

Retrieve documents using keyword search.

Args: query: RetrievalQuery with text and parameters

Returns: List of SearchResult objects ordered by keyword relevance

Raises: ValueError: If query.text is None

class HybridRetriever(gmf_forge_ai_data.BaseRetriever):
130class HybridRetriever(BaseRetriever):
131    """
132    Hybrid retriever combining vector and keyword search.
133    
134    Combines vector similarity with keyword matching for comprehensive retrieval.
135    Scores are normalized and combined (typically 50/50 weighting).
136    
137    Example:
138        ```python
139        from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
140        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
141        
142        embedder = AzureOpenAIEmbeddings(...)
143        retriever = HybridRetriever(vector_store)
144        
145        query_text = "machine learning"
146        query_embedding = embedder.embed_text(query_text)
147        
148        # Hybrid search using both text and embedding
149        query = RetrievalQuery(
150            text=query_text,
151            embedding=query_embedding,
152            top_k=5
153        )
154        results = retriever.retrieve(query)
155        ```
156    """
157    
158    def __init__(self, vector_store: BaseVectorStore):
159        """
160        Initialize hybrid retriever.
161        
162        Args:
163            vector_store: Vector store to search
164        """
165        self.vector_store = vector_store
166    
167    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
168        """
169        Retrieve documents using hybrid search (vector + keyword).
170        
171        Args:
172            query: RetrievalQuery with both text and embedding
173        
174        Returns:
175            List of SearchResult objects ordered by combined relevance
176        
177        Raises:
178            ValueError: If query.text or query.embedding is None
179        """
180        if query.text is None:
181            raise ValueError("HybridRetriever requires query.text")
182        if query.embedding is None:
183            raise ValueError("HybridRetriever requires query.embedding")
184        
185        return self.vector_store.search(
186            query=query.text,
187            query_embedding=query.embedding,
188            top_k=query.top_k,
189            filters=query.filters,
190            search_type="hybrid"
191        )

Hybrid retriever combining vector and keyword search.

Combines vector similarity with keyword matching for comprehensive retrieval. Scores are normalized and combined (typically 50/50 weighting).

Example:

from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)
retriever = HybridRetriever(vector_store)

query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)

# Hybrid search using both text and embedding
query = RetrievalQuery(
    text=query_text,
    embedding=query_embedding,
    top_k=5
)
results = retriever.retrieve(query)
HybridRetriever(vector_store: BaseVectorStore)
158    def __init__(self, vector_store: BaseVectorStore):
159        """
160        Initialize hybrid retriever.
161        
162        Args:
163            vector_store: Vector store to search
164        """
165        self.vector_store = vector_store

Initialize hybrid retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
167    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
168        """
169        Retrieve documents using hybrid search (vector + keyword).
170        
171        Args:
172            query: RetrievalQuery with both text and embedding
173        
174        Returns:
175            List of SearchResult objects ordered by combined relevance
176        
177        Raises:
178            ValueError: If query.text or query.embedding is None
179        """
180        if query.text is None:
181            raise ValueError("HybridRetriever requires query.text")
182        if query.embedding is None:
183            raise ValueError("HybridRetriever requires query.embedding")
184        
185        return self.vector_store.search(
186            query=query.text,
187            query_embedding=query.embedding,
188            top_k=query.top_k,
189            filters=query.filters,
190            search_type="hybrid"
191        )

Retrieve documents using hybrid search (vector + keyword).

Args: query: RetrievalQuery with both text and embedding

Returns: List of SearchResult objects ordered by combined relevance

Raises: ValueError: If query.text or query.embedding is None

class MMRRetriever(gmf_forge_ai_data.BaseRetriever):
 16class MMRRetriever(BaseRetriever):
 17    """
 18    Maximal Marginal Relevance (MMR) retriever for diverse results.
 19    
 20    MMR reranks initial retrieval results to balance relevance with diversity:
 21    - High lambda (λ → 1.0): Prioritize relevance
 22    - Low lambda (λ → 0.0): Prioritize diversity
 23    - λ = 0.5: Balanced (default)
 24    
 25    Algorithm:
 26    1. Retrieve initial candidate documents (fetch_k results)
 27    2. Select most relevant document first
 28    3. For each remaining selection:
 29       - Score = λ * relevance - (1-λ) * max_similarity_to_selected
 30       - Choose document with highest score
 31    4. Return top_k diverse results
 32    
 33    Benefits:
 34    - Reduces redundancy in results
 35    - Improves coverage of different aspects
 36    - Useful for exploration and summarization
 37    
 38    Example:
 39        ```python
 40        from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
 41        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
 42        
 43        embedder = AzureOpenAIEmbeddings(...)
 44        retriever = MMRRetriever(
 45            vector_store=vector_store,
 46            lambda_param=0.5,  # Balanced relevance/diversity
 47            fetch_k=20         # Retrieve 20 candidates, return top_k diverse
 48        )
 49        
 50        query_text = "machine learning algorithms"
 51        query_embedding = embedder.embed_text(query_text)
 52        
 53        # Returns 5 diverse results from 20 candidates
 54        results = retriever.retrieve_embedding(
 55            embedding=query_embedding,
 56            top_k=5
 57        )
 58        ```
 59    
 60    References:
 61        Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based
 62        reranking for reordering documents and producing summaries.
 63    """
 64    
 65    def __init__(
 66        self,
 67        vector_store: BaseVectorStore,
 68        lambda_param: float = 0.5,
 69        fetch_k: int = 20
 70    ):
 71        """
 72        Initialize MMR retriever.
 73        
 74        Args:
 75            vector_store: Vector store to search
 76            lambda_param: Balance between relevance (1.0) and diversity (0.0)
 77            fetch_k: Number of initial candidates to fetch (should be >= top_k)
 78        
 79        Raises:
 80            ValueError: If lambda_param not in [0, 1] or fetch_k < 1
 81        """
 82        if not 0 <= lambda_param <= 1:
 83            raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
 84        if fetch_k < 1:
 85            raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")
 86        
 87        self.vector_store = vector_store
 88        self.lambda_param = lambda_param
 89        self.fetch_k = fetch_k
 90    
 91    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
 92        """
 93        Retrieve diverse documents using MMR reranking.
 94        
 95        Args:
 96            query: RetrievalQuery with embedding (required)
 97        
 98        Returns:
 99            List of SearchResult objects with diverse, relevant results
100        
101        Raises:
102            ValueError: If query.embedding is None
103        """
104        if query.embedding is None:
105            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")
106        
107        # Fetch initial candidates (more than top_k)
108        fetch_k = max(self.fetch_k, query.top_k)
109        candidates = self.vector_store.search(
110            query_embedding=query.embedding,
111            top_k=fetch_k,
112            filters=query.filters,
113            search_type="vector"
114        )
115        
116        if len(candidates) == 0:
117            return []
118        
119        if len(candidates) <= query.top_k:
120            # Not enough candidates for MMR, return as-is
121            return candidates[:query.top_k]
122        
123        # Extract embeddings and relevance scores
124        query_embedding = np.array(query.embedding, dtype=np.float32)
125        candidate_embeddings = []
126        candidate_scores = []
127        
128        for result in candidates:
129            if result.document.embedding is None:
130                raise ValueError(
131                    f"Document {result.document.id} has no embedding. "
132                    "MMR requires all documents to have embeddings."
133                )
134            candidate_embeddings.append(result.document.embedding)
135            candidate_scores.append(result.score)
136        
137        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
138        candidate_scores = np.array(candidate_scores, dtype=np.float32)
139        
140        # Normalize embeddings for cosine similarity
141        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
142        candidate_norms = candidate_embeddings / (
143            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
144        )
145        
146        # MMR selection
147        selected_indices = []
148        selected_embeddings = []
149        
150        for _ in range(min(query.top_k, len(candidates))):
151            if len(selected_indices) == 0:
152                # First selection: most relevant
153                best_idx = int(np.argmax(candidate_scores))
154            else:
155                # Subsequent selections: balance relevance and diversity
156                mmr_scores = []
157                
158                for idx in range(len(candidates)):
159                    if idx in selected_indices:
160                        mmr_scores.append(-np.inf)
161                        continue
162                    
163                    # Relevance score (already normalized 0-1 from vector store)
164                    relevance = candidate_scores[idx]
165                    
166                    # Diversity score: max similarity to selected documents
167                    doc_embedding = candidate_norms[idx]
168                    similarities = [
169                        np.dot(doc_embedding, selected_embeddings[i])
170                        for i in range(len(selected_embeddings))
171                    ]
172                    max_similarity = max(similarities)
173                    
174                    # MMR score: λ * relevance - (1-λ) * max_similarity
175                    mmr_score = (
176                        self.lambda_param * relevance -
177                        (1 - self.lambda_param) * max_similarity
178                    )
179                    mmr_scores.append(mmr_score)
180                
181                best_idx = int(np.argmax(mmr_scores))
182            
183            selected_indices.append(best_idx)
184            selected_embeddings.append(candidate_norms[best_idx])
185        
186        # Build results with updated ranks
187        mmr_results = []
188        for rank, idx in enumerate(selected_indices):
189            result = candidates[idx]
190            mmr_results.append(SearchResult(
191                document=result.document,
192                score=result.score,  # Keep original relevance score
193                rank=rank
194            ))
195        
196        return mmr_results

Maximal Marginal Relevance (MMR) retriever for diverse results.

MMR reranks initial retrieval results to balance relevance with diversity:

  • High lambda (λ → 1.0): Prioritize relevance
  • Low lambda (λ → 0.0): Prioritize diversity
  • λ = 0.5: Balanced (default)

Algorithm:

  1. Retrieve initial candidate documents (fetch_k results)
  2. Select most relevant document first
  3. For each remaining selection:
    • Score = λ * relevance - (1-λ) * max_similarity_to_selected
    • Choose document with highest score
  4. Return top_k diverse results

Benefits:

  • Reduces redundancy in results
  • Improves coverage of different aspects
  • Useful for exploration and summarization

Example:

from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)
retriever = MMRRetriever(
    vector_store=vector_store,
    lambda_param=0.5,  # Balanced relevance/diversity
    fetch_k=20         # Retrieve 20 candidates, return top_k diverse
)

query_text = "machine learning algorithms"
query_embedding = embedder.embed_text(query_text)

# Returns 5 diverse results from 20 candidates
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)

References: Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based reranking for reordering documents and producing summaries.

MMRRetriever(vector_store: BaseVectorStore, lambda_param: float = 0.5, fetch_k: int = 20)
65    def __init__(
66        self,
67        vector_store: BaseVectorStore,
68        lambda_param: float = 0.5,
69        fetch_k: int = 20
70    ):
71        """
72        Initialize MMR retriever.
73        
74        Args:
75            vector_store: Vector store to search
76            lambda_param: Balance between relevance (1.0) and diversity (0.0)
77            fetch_k: Number of initial candidates to fetch (should be >= top_k)
78        
79        Raises:
80            ValueError: If lambda_param not in [0, 1] or fetch_k < 1
81        """
82        if not 0 <= lambda_param <= 1:
83            raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
84        if fetch_k < 1:
85            raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")
86        
87        self.vector_store = vector_store
88        self.lambda_param = lambda_param
89        self.fetch_k = fetch_k

Initialize MMR retriever.

Args:

  • vector_store: Vector store to search
  • lambda_param: Balance between relevance (1.0) and diversity (0.0)
  • fetch_k: Number of initial candidates to fetch (should be >= top_k)

Raises: ValueError: If lambda_param not in [0, 1] or fetch_k < 1

vector_store
lambda_param
fetch_k
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
 91    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
 92        """
 93        Retrieve diverse documents using MMR reranking.
 94        
 95        Args:
 96            query: RetrievalQuery with embedding (required)
 97        
 98        Returns:
 99            List of SearchResult objects with diverse, relevant results
100        
101        Raises:
102            ValueError: If query.embedding is None
103        """
104        if query.embedding is None:
105            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")
106        
107        # Fetch initial candidates (more than top_k)
108        fetch_k = max(self.fetch_k, query.top_k)
109        candidates = self.vector_store.search(
110            query_embedding=query.embedding,
111            top_k=fetch_k,
112            filters=query.filters,
113            search_type="vector"
114        )
115        
116        if len(candidates) == 0:
117            return []
118        
119        if len(candidates) <= query.top_k:
120            # Not enough candidates for MMR, return as-is
121            return candidates[:query.top_k]
122        
123        # Extract embeddings and relevance scores
124        query_embedding = np.array(query.embedding, dtype=np.float32)
125        candidate_embeddings = []
126        candidate_scores = []
127        
128        for result in candidates:
129            if result.document.embedding is None:
130                raise ValueError(
131                    f"Document {result.document.id} has no embedding. "
132                    "MMR requires all documents to have embeddings."
133                )
134            candidate_embeddings.append(result.document.embedding)
135            candidate_scores.append(result.score)
136        
137        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
138        candidate_scores = np.array(candidate_scores, dtype=np.float32)
139        
140        # Normalize embeddings for cosine similarity
141        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
142        candidate_norms = candidate_embeddings / (
143            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
144        )
145        
146        # MMR selection
147        selected_indices = []
148        selected_embeddings = []
149        
150        for _ in range(min(query.top_k, len(candidates))):
151            if len(selected_indices) == 0:
152                # First selection: most relevant
153                best_idx = int(np.argmax(candidate_scores))
154            else:
155                # Subsequent selections: balance relevance and diversity
156                mmr_scores = []
157                
158                for idx in range(len(candidates)):
159                    if idx in selected_indices:
160                        mmr_scores.append(-np.inf)
161                        continue
162                    
163                    # Relevance score (already normalized 0-1 from vector store)
164                    relevance = candidate_scores[idx]
165                    
166                    # Diversity score: max similarity to selected documents
167                    doc_embedding = candidate_norms[idx]
168                    similarities = [
169                        np.dot(doc_embedding, selected_embeddings[i])
170                        for i in range(len(selected_embeddings))
171                    ]
172                    max_similarity = max(similarities)
173                    
174                    # MMR score: λ * relevance - (1-λ) * max_similarity
175                    mmr_score = (
176                        self.lambda_param * relevance -
177                        (1 - self.lambda_param) * max_similarity
178                    )
179                    mmr_scores.append(mmr_score)
180                
181                best_idx = int(np.argmax(mmr_scores))
182            
183            selected_indices.append(best_idx)
184            selected_embeddings.append(candidate_norms[best_idx])
185        
186        # Build results with updated ranks
187        mmr_results = []
188        for rank, idx in enumerate(selected_indices):
189            result = candidates[idx]
190            mmr_results.append(SearchResult(
191                document=result.document,
192                score=result.score,  # Keep original relevance score
193                rank=rank
194            ))
195        
196        return mmr_results

Retrieve diverse documents using MMR reranking.

Args: query: RetrievalQuery with embedding (required)

Returns: List of SearchResult objects with diverse, relevant results

Raises: ValueError: If query.embedding is None

class ParentDocumentRetriever(gmf_forge_ai_data.BaseRetriever):
 16class ParentDocumentRetriever(BaseRetriever):
 17    """
 18    Retriever that searches child chunks but returns parent documents.
 19    
 20    This pattern is useful when:
 21    - Documents are chunked into small pieces for precise embedding
 22    - But you want to return full parent documents for better context
 23    - Multiple chunks from the same parent might match the query
 24    
 25    Architecture:
 26    - Child store: Contains small chunks with embeddings (searchable)
 27    - Parent store: Contains full parent documents (for retrieval)
 28    - Mapping: Children have metadata["parent_id"] pointing to parent
 29    
 30    Workflow:
 31    1. Search child store for relevant chunks
 32    2. Extract parent_ids from child metadata
 33    3. Retrieve parent documents from parent store
 34    4. Deduplicate and rank by best child score
 35    
 36    Example:
 37        ```python
 38        from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
 39        from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
 40        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
 41        
 42        embedder = AzureOpenAIEmbeddings(...)
 43        
 44        # Setup stores
 45        child_store = InMemoryVectorStore()
 46        parent_store = InMemoryVectorStore()
 47        
 48        # Create parent document
 49        parent = Document(
 50            id="doc_1",
 51            content="Full document with multiple sections...",
 52            embedding=embedder.embed_text("Full document...")
 53        )
 54        parent_store.add_documents([parent])
 55        
 56        # Create child chunks
 57        child1 = Document(
 58            id="doc_1_chunk_0",
 59            content="Introduction section...",
 60            embedding=embedder.embed_text("Introduction..."),
 61            metadata={"parent_id": "doc_1"}  # Link to parent
 62        )
 63        child2 = Document(
 64            id="doc_1_chunk_1",
 65            content="Methods section...",
 66            embedding=embedder.embed_text("Methods..."),
 67            metadata={"parent_id": "doc_1"}
 68        )
 69        child_store.add_documents([child1, child2])
 70        
 71        # Setup retriever
 72        retriever = ParentDocumentRetriever(
 73            child_store=child_store,
 74            parent_store=parent_store,
 75            parent_id_key="parent_id"
 76        )
 77        
 78        # Search chunks, return parent
 79        query_embedding = embedder.embed_text("introduction")
 80        results = retriever.retrieve_embedding(
 81            embedding=query_embedding,
 82            top_k=5
 83        )
 84        # Returns full parent documents, not chunks
 85        ```
 86    
 87    Benefits:
 88    - Precise search (small chunks)
 89    - Rich context (full parents)
 90    - Automatic deduplication
 91    """
 92    
 93    def __init__(
 94        self,
 95        child_store: BaseVectorStore,
 96        parent_store: BaseVectorStore,
 97        parent_id_key: str = "parent_id",
 98        search_type: str = "vector"
 99    ):
100        """
101        Initialize parent document retriever.
102        
103        Args:
104            child_store: Vector store containing child chunks with embeddings
105            parent_store: Vector store containing parent documents
106            parent_id_key: Metadata key in children pointing to parent ID
107            search_type: Search type for child store ("vector", "keyword", "hybrid")
108        
109        Raises:
110            ValueError: If search_type is invalid
111        """
112        if search_type not in ["vector", "keyword", "hybrid"]:
113            raise ValueError(f"Invalid search_type: {search_type}")
114        
115        self.child_store = child_store
116        self.parent_store = parent_store
117        self.parent_id_key = parent_id_key
118        self.search_type = search_type
119    
120    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
121        """
122        Retrieve parent documents by searching child chunks.
123        
124        Args:
125            query: RetrievalQuery with appropriate parameters for search_type
126        
127        Returns:
128            List of SearchResult with parent documents, deduplicated and ranked
129        
130        Raises:
131            ValueError: If required query parameters are missing
132        """
133        # Validate query based on search type
134        if self.search_type in ["vector", "hybrid"] and query.embedding is None:
135            raise ValueError(f"{self.search_type} search requires query.embedding")
136        if self.search_type in ["keyword", "hybrid"] and query.text is None:
137            raise ValueError(f"{self.search_type} search requires query.text")
138        
139        # Search child store (fetch more candidates to account for deduplication)
140        fetch_k = query.top_k * 3  # Heuristic: fetch 3x to handle duplicates
141        child_results = self.child_store.search(
142            query=query.text,
143            query_embedding=query.embedding,
144            top_k=fetch_k,
145            filters=query.filters,
146            search_type=self.search_type
147        )
148        
149        if len(child_results) == 0:
150            return []
151        
152        # Extract parent IDs and track best scores
153        parent_scores: Dict[str, float] = {}  # parent_id -> best_score
154        parent_order: OrderedDict[str, None] = OrderedDict()  # Track first occurrence
155        
156        for child_result in child_results:
157            parent_id = child_result.document.metadata.get(self.parent_id_key)
158            
159            if parent_id is None:
160                # Skip children without parent link
161                continue
162            
163            # Track best score for each parent (in case multiple children match)
164            if parent_id not in parent_scores:
165                parent_scores[parent_id] = child_result.score
166                parent_order[parent_id] = None  # Track insertion order
167            else:
168                # Update if this child has better score
169                parent_scores[parent_id] = max(
170                    parent_scores[parent_id],
171                    child_result.score
172                )
173        
174        if len(parent_scores) == 0:
175            return []
176        
177        # Retrieve parent documents
178        parent_results = []
179        for rank, parent_id in enumerate(parent_order.keys()):
180            if rank >= query.top_k:
181                break
182            
183            parent_doc = self.parent_store.get_document(parent_id)
184            if parent_doc is None:
185                # Parent not found in store, skip
186                continue
187            
188            parent_results.append(SearchResult(
189                document=parent_doc,
190                score=parent_scores[parent_id],  # Best child score
191                rank=rank
192            ))
193        
194        return parent_results

Retriever that searches child chunks but returns parent documents.

This pattern is useful when:

  • Documents are chunked into small pieces for precise embedding
  • But you want to return full parent documents for better context
  • Multiple chunks from the same parent might match the query

Architecture:

  • Child store: Contains small chunks with embeddings (searchable)
  • Parent store: Contains full parent documents (for retrieval)
  • Mapping: Children have metadata["parent_id"] pointing to parent

Workflow:

  1. Search child store for relevant chunks
  2. Extract parent_ids from child metadata
  3. Retrieve parent documents from parent store
  4. Deduplicate and rank by best child score

Example:

from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Setup stores
child_store = InMemoryVectorStore()
parent_store = InMemoryVectorStore()

# Create parent document
parent = Document(
    id="doc_1",
    content="Full document with multiple sections...",
    embedding=embedder.embed_text("Full document...")
)
parent_store.add_documents([parent])

# Create child chunks
child1 = Document(
    id="doc_1_chunk_0",
    content="Introduction section...",
    embedding=embedder.embed_text("Introduction..."),
    metadata={"parent_id": "doc_1"}  # Link to parent
)
child2 = Document(
    id="doc_1_chunk_1",
    content="Methods section...",
    embedding=embedder.embed_text("Methods..."),
    metadata={"parent_id": "doc_1"}
)
child_store.add_documents([child1, child2])

# Setup retriever
retriever = ParentDocumentRetriever(
    child_store=child_store,
    parent_store=parent_store,
    parent_id_key="parent_id"
)

# Search chunks, return parent
query_embedding = embedder.embed_text("introduction")
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
# Returns full parent documents, not chunks

Benefits:

  • Precise search (small chunks)
  • Rich context (full parents)
  • Automatic deduplication
ParentDocumentRetriever( child_store: BaseVectorStore, parent_store: BaseVectorStore, parent_id_key: str = 'parent_id', search_type: str = 'vector')
 93    def __init__(
 94        self,
 95        child_store: BaseVectorStore,
 96        parent_store: BaseVectorStore,
 97        parent_id_key: str = "parent_id",
 98        search_type: str = "vector"
 99    ):
100        """
101        Initialize parent document retriever.
102        
103        Args:
104            child_store: Vector store containing child chunks with embeddings
105            parent_store: Vector store containing parent documents
106            parent_id_key: Metadata key in children pointing to parent ID
107            search_type: Search type for child store ("vector", "keyword", "hybrid")
108        
109        Raises:
110            ValueError: If search_type is invalid
111        """
112        if search_type not in ["vector", "keyword", "hybrid"]:
113            raise ValueError(f"Invalid search_type: {search_type}")
114        
115        self.child_store = child_store
116        self.parent_store = parent_store
117        self.parent_id_key = parent_id_key
118        self.search_type = search_type

Initialize parent document retriever.

Args:

  • child_store: Vector store containing child chunks with embeddings
  • parent_store: Vector store containing parent documents
  • parent_id_key: Metadata key in children pointing to parent ID
  • search_type: Search type for child store ("vector", "keyword", "hybrid")

Raises:

  • ValueError: If search_type is invalid

child_store
parent_store
parent_id_key
search_type
def retrieve( self, query: RetrievalQuery) -> List[SearchResult]:
120    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
121        """
122        Retrieve parent documents by searching child chunks.
123        
124        Args:
125            query: RetrievalQuery with appropriate parameters for search_type
126        
127        Returns:
128            List of SearchResult with parent documents, deduplicated and ranked
129        
130        Raises:
131            ValueError: If required query parameters are missing
132        """
133        # Validate query based on search type
134        if self.search_type in ["vector", "hybrid"] and query.embedding is None:
135            raise ValueError(f"{self.search_type} search requires query.embedding")
136        if self.search_type in ["keyword", "hybrid"] and query.text is None:
137            raise ValueError(f"{self.search_type} search requires query.text")
138        
139        # Search child store (fetch more candidates to account for deduplication)
140        fetch_k = query.top_k * 3  # Heuristic: fetch 3x to handle duplicates
141        child_results = self.child_store.search(
142            query=query.text,
143            query_embedding=query.embedding,
144            top_k=fetch_k,
145            filters=query.filters,
146            search_type=self.search_type
147        )
148        
149        if len(child_results) == 0:
150            return []
151        
152        # Extract parent IDs and track best scores
153        parent_scores: Dict[str, float] = {}  # parent_id -> best_score
154        parent_order: OrderedDict[str, None] = OrderedDict()  # Track first occurrence
155        
156        for child_result in child_results:
157            parent_id = child_result.document.metadata.get(self.parent_id_key)
158            
159            if parent_id is None:
160                # Skip children without parent link
161                continue
162            
163            # Track best score for each parent (in case multiple children match)
164            if parent_id not in parent_scores:
165                parent_scores[parent_id] = child_result.score
166                parent_order[parent_id] = None  # Track insertion order
167            else:
168                # Update if this child has better score
169                parent_scores[parent_id] = max(
170                    parent_scores[parent_id],
171                    child_result.score
172                )
173        
174        if len(parent_scores) == 0:
175            return []
176        
177        # Retrieve parent documents
178        parent_results = []
179        for rank, parent_id in enumerate(parent_order.keys()):
180            if rank >= query.top_k:
181                break
182            
183            parent_doc = self.parent_store.get_document(parent_id)
184            if parent_doc is None:
185                # Parent not found in store, skip
186                continue
187            
188            parent_results.append(SearchResult(
189                document=parent_doc,
190                score=parent_scores[parent_id],  # Best child score
191                rank=rank
192            ))
193        
194        return parent_results

Retrieve parent documents by searching child chunks.

Args:

  • query: RetrievalQuery with appropriate parameters for search_type

Returns:

  • List of SearchResult with parent documents, deduplicated and ranked

Raises:

  • ValueError: If required query parameters are missing

class EnsembleRetriever(gmf_forge_ai_data.BaseRetriever):
 17class EnsembleRetriever(BaseRetriever):
 18    """
 19    Ensemble retriever combining multiple retrieval strategies.
 20    
 21    Combines results from multiple retrievers using score fusion techniques:
 22    - Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
 23    - Weighted Average: Score-based fusion with configurable weights
 24    - Max Score: Conservative, consensus-based fusion
 25    
 26    Benefits:
 27    - Improved recall (combine different strategies)
 28    - Robustness (less sensitive to individual retriever failures)
 29    - Flexibility (combine vector, keyword, hybrid, MMR, etc.)
 30    
 31    Example:
 32        ```python
 33        from gmf_forge_ai_data.retrieval import (
 34            EnsembleRetriever,
 35            VectorRetriever,
 36            KeywordRetriever,
 37            MMRRetriever
 38        )
 39        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
 40        
 41        embedder = AzureOpenAIEmbeddings(...)
 42        
 43        # Create multiple retrievers
 44        vector_retriever = VectorRetriever(vector_store)
 45        keyword_retriever = KeywordRetriever(vector_store)
 46        mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)
 47        
 48        # Combine with RRF fusion
 49        ensemble = EnsembleRetriever(
 50            retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
 51            weights=[0.5, 0.3, 0.2],  # Optional weights
 52            fusion_strategy="rrf"  # Reciprocal Rank Fusion
 53        )
 54        
 55        # Query requires parameters for all retrievers
 56        query_text = "machine learning"
 57        query_embedding = embedder.embed_text(query_text)
 58        
 59        query = RetrievalQuery(
 60            text=query_text,  # For keyword retriever
 61            embedding=query_embedding,  # For vector/MMR retrievers
 62            top_k=5
 63        )
 64        
 65        results = ensemble.retrieve(query)
 66        # Returns fused results from all retrievers
 67        ```
 68    
 69    Fusion Strategies:
 70    
 71    1. **rrf** (Reciprocal Rank Fusion):
 72       - score(doc) = Σ weights[i] / (k + rank_i)
 73       - k = 60 (default)
 74       - Robust to different score scales
 75       - Rank-based, not score-based
 76    
 77    2. **weighted_avg** (Weighted Average):
 78       - Normalize scores to [0, 1]
 79       - score(doc) = Σ weights[i] * normalized_score_i
 80       - Score-based fusion
 81    
 82    3. **max_score** (Maximum Score):
 83       - score(doc) = max(normalized_scores)
 84       - Conservative, requires consensus
 85    
 86    References:
 87        Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009).
 88        Reciprocal rank fusion outperforms condorcet and individual
 89        ranklists. ACM SIGIR.
 90    """
 91    
 92    def __init__(
 93        self,
 94        retrievers: List[BaseRetriever],
 95        weights: List[float] = None,
 96        fusion_strategy: str = "rrf",
 97        rrf_k: int = 60
 98    ):
 99        """
100        Initialize ensemble retriever.
101        
102        Args:
103            retrievers: List of retrievers to combine
104            weights: Optional weights for each retriever (default: equal weights)
105            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
106            rrf_k: k parameter for RRF (default: 60)
107        
108        Raises:
109            ValueError: If retrievers list is empty or params are invalid
110        """
111        if not retrievers:
112            raise ValueError("Must provide at least one retriever")
113        
114        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
115            raise ValueError(
116                f"Invalid fusion_strategy: {fusion_strategy}. "
117                "Must be 'rrf', 'weighted_avg', or 'max_score'"
118            )
119        
120        if weights is None:
121            # Equal weights
122            weights = [1.0 / len(retrievers)] * len(retrievers)
123        else:
124            if len(weights) != len(retrievers):
125                raise ValueError(
126                    f"Number of weights ({len(weights)}) must match "
127                    f"number of retrievers ({len(retrievers)})"
128                )
129            # Normalize weights to sum to 1.0
130            total = sum(weights)
131            weights = [w / total for w in weights]
132        
133        self.retrievers = retrievers
134        self.weights = weights
135        self.fusion_strategy = fusion_strategy
136        self.rrf_k = rrf_k
137    
138    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
139        """
140        Retrieve documents using ensemble fusion.
141        
142        Args:
143            query: RetrievalQuery with parameters for all retrievers
144        
145        Returns:
146            List of SearchResult with fused scores and ranks
147        
148        Note:
149            All retrievers are called with the same query. Ensure the query
150            contains all required parameters (text, embedding) for your retrievers.
151        """
152        # Retrieve from all retrievers
153        all_results: List[List[SearchResult]] = []
154        
155        for retriever in self.retrievers:
156            try:
157                results = retriever.retrieve(query)
158                all_results.append(results)
159            except ValueError as e:
160                # Retriever might not support this query type, skip it
161                all_results.append([])
162        
163        if not any(all_results):
164            return []
165        
166        # Apply fusion strategy
167        if self.fusion_strategy == "rrf":
168            fused_results = self._reciprocal_rank_fusion(all_results)
169        elif self.fusion_strategy == "weighted_avg":
170            fused_results = self._weighted_average_fusion(all_results)
171        else:  # max_score
172            fused_results = self._max_score_fusion(all_results)
173        
174        # Sort by fused score and take top_k
175        fused_results.sort(key=lambda x: x.score, reverse=True)
176        top_results = fused_results[:query.top_k]
177        
178        # Update ranks
179        for rank, result in enumerate(top_results):
180            result.rank = rank
181        
182        return top_results
183    
184    def _reciprocal_rank_fusion(
185        self,
186        all_results: List[List[SearchResult]]
187    ) -> List[SearchResult]:
188        """
189        Fuse results using Reciprocal Rank Fusion (RRF).
190        
191        RRF score for document d:
192            score(d) = Σ weight_i / (k + rank_i(d))
193        
194        Where rank_i(d) is the rank of d in retriever i's results.
195        """
196        doc_scores: Dict[str, float] = defaultdict(float)
197        doc_objects: Dict[str, Document] = {}
198        
199        for retriever_idx, results in enumerate(all_results):
200            weight = self.weights[retriever_idx]
201            
202            for result in results:
203                doc_id = result.document.id
204                
205                # RRF score contribution
206                rrf_score = weight / (self.rrf_k + result.rank)
207                doc_scores[doc_id] += rrf_score
208                
209                # Keep document object (use first occurrence)
210                if doc_id not in doc_objects:
211                    doc_objects[doc_id] = result.document
212        
213        # Build SearchResult objects
214        fused_results = [
215            SearchResult(
216                document=doc_objects[doc_id],
217                score=score,
218                rank=0  # Will be set later
219            )
220            for doc_id, score in doc_scores.items()
221        ]
222        
223        return fused_results
224    
225    def _weighted_average_fusion(
226        self,
227        all_results: List[List[SearchResult]]
228    ) -> List[SearchResult]:
229        """
230        Fuse results using weighted average of normalized scores.
231        
232        Scores are normalized to [0, 1] within each retriever, then averaged.
233        """
234        doc_scores: Dict[str, float] = defaultdict(float)
235        doc_objects: Dict[str, Document] = {}
236        
237        for retriever_idx, results in enumerate(all_results):
238            if not results:
239                continue
240            
241            weight = self.weights[retriever_idx]
242            
243            # Normalize scores to [0, 1]
244            scores = np.array([r.score for r in results])
245            min_score = scores.min()
246            max_score = scores.max()
247            
248            if max_score > min_score:
249                normalized_scores = (scores - min_score) / (max_score - min_score)
250            else:
251                normalized_scores = np.ones_like(scores)
252            
253            for result, norm_score in zip(results, normalized_scores):
254                doc_id = result.document.id
255                doc_scores[doc_id] += weight * norm_score
256                
257                if doc_id not in doc_objects:
258                    doc_objects[doc_id] = result.document
259        
260        # Build SearchResult objects
261        fused_results = [
262            SearchResult(
263                document=doc_objects[doc_id],
264                score=score,
265                rank=0
266            )
267            for doc_id, score in doc_scores.items()
268        ]
269        
270        return fused_results
271    
272    def _max_score_fusion(
273        self,
274        all_results: List[List[SearchResult]]
275    ) -> List[SearchResult]:
276        """
277        Fuse results using maximum normalized score across retrievers.
278        
279        Conservative strategy: document needs high score from at least one retriever.
280        """
281        doc_scores: Dict[str, float] = defaultdict(float)
282        doc_objects: Dict[str, Document] = {}
283        
284        for retriever_idx, results in enumerate(all_results):
285            if not results:
286                continue
287            
288            # Normalize scores to [0, 1]
289            scores = np.array([r.score for r in results])
290            min_score = scores.min()
291            max_score = scores.max()
292            
293            if max_score > min_score:
294                normalized_scores = (scores - min_score) / (max_score - min_score)
295            else:
296                normalized_scores = np.ones_like(scores)
297            
298            for result, norm_score in zip(results, normalized_scores):
299                doc_id = result.document.id
300                
301                # Take maximum normalized score
302                doc_scores[doc_id] = max(doc_scores[doc_id], norm_score)
303                
304                if doc_id not in doc_objects:
305                    doc_objects[doc_id] = result.document
306        
307        # Build SearchResult objects
308        fused_results = [
309            SearchResult(
310                document=doc_objects[doc_id],
311                score=score,
312                rank=0
313            )
314            for doc_id, score in doc_scores.items()
315        ]
316        
317        return fused_results

Ensemble retriever combining multiple retrieval strategies.

Combines results from multiple retrievers using score fusion techniques:

  • Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
  • Weighted Average: Score-based fusion with configurable weights
  • Max Score: Uses each document's highest normalized score from any single retriever (no consensus required)

Benefits:

  • Improved recall (combine different strategies)
  • Robustness (less sensitive to individual retriever failures)
  • Flexibility (combine vector, keyword, hybrid, MMR, etc.)

Example:

from gmf_forge_ai_data.retrieval import (
    EnsembleRetriever,
    VectorRetriever,
    KeywordRetriever,
    MMRRetriever
)
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Create multiple retrievers
vector_retriever = VectorRetriever(vector_store)
keyword_retriever = KeywordRetriever(vector_store)
mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

# Combine with RRF fusion
ensemble = EnsembleRetriever(
    retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
    weights=[0.5, 0.3, 0.2],  # Optional weights
    fusion_strategy="rrf"  # Reciprocal Rank Fusion
)

# Query requires parameters for all retrievers
query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)

query = RetrievalQuery(
    text=query_text,  # For keyword retriever
    embedding=query_embedding,  # For vector/MMR retrievers
    top_k=5
)

results = ensemble.retrieve(query)
# Returns fused results from all retrievers

Fusion Strategies:

  1. rrf (Reciprocal Rank Fusion):

    • score(doc) = Σ weights[i] / (k + rank_i)
    • k = 60 (default)
    • Robust to different score scales
    • Rank-based, not score-based
  2. weighted_avg (Weighted Average):

    • Normalize scores to [0, 1]
    • score(doc) = Σ weights[i] * normalized_score_i
    • Score-based fusion
  3. max_score (Maximum Score):

    • score(doc) = max(normalized_scores)
    • A document only needs a high score from one retriever; agreement across retrievers is not required

References: Cormack, G. V., Clarke, C. L. A., & Buettcher, S. (2009). Reciprocal rank fusion outperforms Condorcet and individual rank learning methods. In Proceedings of ACM SIGIR 2009.

EnsembleRetriever( retrievers: List[BaseRetriever], weights: List[float] = None, fusion_strategy: str = 'rrf', rrf_k: int = 60)
 92    def __init__(
 93        self,
 94        retrievers: List[BaseRetriever],
 95        weights: List[float] = None,
 96        fusion_strategy: str = "rrf",
 97        rrf_k: int = 60
 98    ):
 99        """
100        Initialize ensemble retriever.
101        
102        Args:
103            retrievers: List of retrievers to combine
104            weights: Optional weights for each retriever (default: equal weights)
105            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
106            rrf_k: k parameter for RRF (default: 60)
107        
108        Raises:
109            ValueError: If retrievers list is empty or params are invalid
110        """
111        if not retrievers:
112            raise ValueError("Must provide at least one retriever")
113        
114        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
115            raise ValueError(
116                f"Invalid fusion_strategy: {fusion_strategy}. "
117                "Must be 'rrf', 'weighted_avg', or 'max_score'"
118            )
119        
120        if weights is None:
121            # Equal weights
122            weights = [1.0 / len(retrievers)] * len(retrievers)
123        else:
124            if len(weights) != len(retrievers):
125                raise ValueError(
126                    f"Number of weights ({len(weights)}) must match "
127                    f"number of retrievers ({len(retrievers)})"
128                )
129            # Normalize weights to sum to 1.0
130            total = sum(weights)
131            weights = [w / total for w in weights]
132        
133        self.retrievers = retrievers
134        self.weights = weights
135        self.fusion_strategy = fusion_strategy
136        self.rrf_k = rrf_k

Initialize ensemble retriever.

Args:

  • retrievers: List of retrievers to combine
  • weights: Optional weights for each retriever (default: equal weights)
  • fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
  • rrf_k: k parameter for RRF (default: 60)

Raises:

  • ValueError: If retrievers list is empty or params are invalid

retrievers
weights
fusion_strategy
rrf_k
def retrieve( self, query: RetrievalQuery) -> List[SearchResult]:
138    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
139        """
140        Retrieve documents using ensemble fusion.
141        
142        Args:
143            query: RetrievalQuery with parameters for all retrievers
144        
145        Returns:
146            List of SearchResult with fused scores and ranks
147        
148        Note:
149            All retrievers are called with the same query. Ensure the query
150            contains all required parameters (text, embedding) for your retrievers.
151        """
152        # Retrieve from all retrievers
153        all_results: List[List[SearchResult]] = []
154        
155        for retriever in self.retrievers:
156            try:
157                results = retriever.retrieve(query)
158                all_results.append(results)
159            except ValueError as e:
160                # Retriever might not support this query type, skip it
161                all_results.append([])
162        
163        if not any(all_results):
164            return []
165        
166        # Apply fusion strategy
167        if self.fusion_strategy == "rrf":
168            fused_results = self._reciprocal_rank_fusion(all_results)
169        elif self.fusion_strategy == "weighted_avg":
170            fused_results = self._weighted_average_fusion(all_results)
171        else:  # max_score
172            fused_results = self._max_score_fusion(all_results)
173        
174        # Sort by fused score and take top_k
175        fused_results.sort(key=lambda x: x.score, reverse=True)
176        top_results = fused_results[:query.top_k]
177        
178        # Update ranks
179        for rank, result in enumerate(top_results):
180            result.rank = rank
181        
182        return top_results

Retrieve documents using ensemble fusion.

Args:

  • query: RetrievalQuery with parameters for all retrievers

Returns:

  • List of SearchResult with fused scores and ranks

Note: All retrievers are called with the same query. Ensure the query contains all required parameters (text, embedding) for your retrievers.

class HierarchicalRetriever(gmf_forge_ai_data.BaseRetriever):
 24class HierarchicalRetriever(BaseRetriever):
 25    """
 26    Two-stage hierarchical retrieval for efficient search in large collections.
 27    
 28    Stage 1: Retrieve document summaries to identify relevant documents
 29    Stage 2: Retrieve detailed chunks from top documents only
 30    
 31    This reduces computational cost by focusing detailed retrieval on relevant documents.
 32    
 33    Example:
 34        ```python
 35        # Create summary and chunk retrievers
 36        summary_retriever = VectorRetriever(summary_store, embedder)
 37        chunk_retriever = VectorRetriever(chunk_store, embedder)
 38        
 39        # Create hierarchical retriever
 40        retriever = HierarchicalRetriever(
 41            summary_retriever=summary_retriever,
 42            chunk_retriever=chunk_retriever,
 43            stage1_top_k=10,  # Get top 10 documents
 44            stage2_top_k=5,   # Get 5 chunks per document
 45            document_id_field="document_id"
 46        )
 47        
 48        # Retrieve
 49        query = RetrievalQuery(text="machine learning", top_k=20)
 50        results = retriever.retrieve(query)
 51        ```
 52    """
 53    
 54    def __init__(
 55        self,
 56        summary_retriever: BaseRetriever,
 57        chunk_retriever: BaseRetriever,
 58        stage1_top_k: int = 10,
 59        stage2_top_k: int = 5,
 60        document_id_field: str = "document_id",
 61        combine_scores: bool = True,
 62        stage1_weight: float = 0.3,
 63        stage2_weight: float = 0.7
 64    ):
 65        """
 66        Initialize hierarchical retriever.
 67        
 68        Args:
 69            summary_retriever: Retriever for document summaries (stage 1)
 70            chunk_retriever: Retriever for detailed chunks (stage 2)
 71            stage1_top_k: Number of documents to retrieve in stage 1
 72            stage2_top_k: Number of chunks to retrieve per document in stage 2
 73            document_id_field: Metadata field linking chunks to documents
 74            combine_scores: If True, combine stage1 and stage2 scores
 75            stage1_weight: Weight for stage 1 score (document relevance)
 76            stage2_weight: Weight for stage 2 score (chunk relevance)
 77        """
 78        self.summary_retriever = summary_retriever
 79        self.chunk_retriever = chunk_retriever
 80        self.stage1_top_k = stage1_top_k
 81        self.stage2_top_k = stage2_top_k
 82        self.document_id_field = document_id_field
 83        self.combine_scores = combine_scores
 84        self.stage1_weight = stage1_weight
 85        self.stage2_weight = stage2_weight
 86        
 87        # Normalize weights
 88        total_weight = stage1_weight + stage2_weight
 89        if total_weight > 0:
 90            self.stage1_weight = stage1_weight / total_weight
 91            self.stage2_weight = stage2_weight / total_weight
 92    
 93    def retrieve(
 94        self,
 95        query: RetrievalQuery,
 96        **kwargs
 97    ) -> List[SearchResult]:
 98        """
 99        Perform two-stage hierarchical retrieval.
100        
101        Args:
102            query: The retrieval query
103            **kwargs: Additional arguments
104            
105        Returns:
106            List of SearchResult objects from stage 2, optionally with combined scores
107        """
108        # Stage 1: Retrieve document summaries
109        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")
110        
111        stage1_query = RetrievalQuery(
112            text=query.text,
113            embedding=query.embedding,
114            top_k=self.stage1_top_k,
115            filters=query.filters,
116            metadata=query.metadata
117        )
118        
119        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)
120        
121        if not summary_results:
122            logger.warning("Stage 1 returned no results")
123            return []
124        
125        # Extract document IDs from summaries
126        document_ids = []
127        document_scores = {}
128        
129        for result in summary_results:
130            doc_id = result.document.metadata.get(self.document_id_field)
131            if doc_id:
132                document_ids.append(doc_id)
133                document_scores[doc_id] = result.score
134            else:
135                # If no document_id field, use the document id itself
136                document_ids.append(result.document.id)
137                document_scores[result.document.id] = result.score
138        
139        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")
140        
141        # Stage 2: Retrieve detailed chunks from top documents
142        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")
143        
144        all_chunks: List[SearchResult] = []
145        
146        for doc_id in document_ids:
147            # Create filter for this document
148            doc_filter = {self.document_id_field: doc_id}
149            
150            # Merge with existing filters if any
151            if query.filters:
152                doc_filter.update(query.filters)
153            
154            # Retrieve chunks for this document
155            stage2_query = RetrievalQuery(
156                text=query.text,
157                embedding=query.embedding,
158                top_k=self.stage2_top_k,
159                filters=doc_filter,
160                metadata=query.metadata
161            )
162            
163            try:
164                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)
165                
166                if self.combine_scores and doc_id in document_scores:
167                    # Combine stage 1 and stage 2 scores
168                    for result in chunk_results:
169                        combined_score = (
170                            self.stage1_weight * document_scores[doc_id] +
171                            self.stage2_weight * result.score
172                        )
173                        # Create new result with combined score
174                        result.score = combined_score
175                
176                all_chunks.extend(chunk_results)
177                
178            except ValueError as e:
179                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
180                continue
181        
182        if not all_chunks:
183            logger.warning("Stage 2 returned no chunks")
184            return []
185        
186        # Sort by score and limit to top_k
187        all_chunks.sort(key=lambda x: x.score, reverse=True)
188        final_results = all_chunks[:query.top_k]
189        
190        # Update ranks
191        for i, result in enumerate(final_results):
192            result.rank = i + 1
193        
194        logger.info(
195            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
196            f"{len(document_ids)} documents"
197        )
198        
199        return final_results

Two-stage hierarchical retrieval for efficient search in large collections.

Stage 1: Retrieve document summaries to identify relevant documents.
Stage 2: Retrieve detailed chunks from the top documents only.

This reduces computational cost by focusing detailed retrieval on relevant documents.

Example:

# Create summary and chunk retrievers
summary_retriever = VectorRetriever(summary_store, embedder)
chunk_retriever = VectorRetriever(chunk_store, embedder)

# Create hierarchical retriever
retriever = HierarchicalRetriever(
    summary_retriever=summary_retriever,
    chunk_retriever=chunk_retriever,
    stage1_top_k=10,  # Get top 10 documents
    stage2_top_k=5,   # Get 5 chunks per document
    document_id_field="document_id"
)

# Retrieve
query = RetrievalQuery(text="machine learning", top_k=20)
results = retriever.retrieve(query)
HierarchicalRetriever( summary_retriever: BaseRetriever, chunk_retriever: BaseRetriever, stage1_top_k: int = 10, stage2_top_k: int = 5, document_id_field: str = 'document_id', combine_scores: bool = True, stage1_weight: float = 0.3, stage2_weight: float = 0.7)
54    def __init__(
55        self,
56        summary_retriever: BaseRetriever,
57        chunk_retriever: BaseRetriever,
58        stage1_top_k: int = 10,
59        stage2_top_k: int = 5,
60        document_id_field: str = "document_id",
61        combine_scores: bool = True,
62        stage1_weight: float = 0.3,
63        stage2_weight: float = 0.7
64    ):
65        """
66        Initialize hierarchical retriever.
67        
68        Args:
69            summary_retriever: Retriever for document summaries (stage 1)
70            chunk_retriever: Retriever for detailed chunks (stage 2)
71            stage1_top_k: Number of documents to retrieve in stage 1
72            stage2_top_k: Number of chunks to retrieve per document in stage 2
73            document_id_field: Metadata field linking chunks to documents
74            combine_scores: If True, combine stage1 and stage2 scores
75            stage1_weight: Weight for stage 1 score (document relevance)
76            stage2_weight: Weight for stage 2 score (chunk relevance)
77        """
78        self.summary_retriever = summary_retriever
79        self.chunk_retriever = chunk_retriever
80        self.stage1_top_k = stage1_top_k
81        self.stage2_top_k = stage2_top_k
82        self.document_id_field = document_id_field
83        self.combine_scores = combine_scores
84        self.stage1_weight = stage1_weight
85        self.stage2_weight = stage2_weight
86        
87        # Normalize weights
88        total_weight = stage1_weight + stage2_weight
89        if total_weight > 0:
90            self.stage1_weight = stage1_weight / total_weight
91            self.stage2_weight = stage2_weight / total_weight

Initialize hierarchical retriever.

Args:

  • summary_retriever: Retriever for document summaries (stage 1)
  • chunk_retriever: Retriever for detailed chunks (stage 2)
  • stage1_top_k: Number of documents to retrieve in stage 1
  • stage2_top_k: Number of chunks to retrieve per document in stage 2
  • document_id_field: Metadata field linking chunks to documents
  • combine_scores: If True, combine stage1 and stage2 scores
  • stage1_weight: Weight for stage 1 score (document relevance)
  • stage2_weight: Weight for stage 2 score (chunk relevance)

summary_retriever
chunk_retriever
stage1_top_k
stage2_top_k
document_id_field
combine_scores
stage1_weight
stage2_weight
def retrieve( self, query: RetrievalQuery, **kwargs) -> List[SearchResult]:
 93    def retrieve(
 94        self,
 95        query: RetrievalQuery,
 96        **kwargs
 97    ) -> List[SearchResult]:
 98        """
 99        Perform two-stage hierarchical retrieval.
100        
101        Args:
102            query: The retrieval query
103            **kwargs: Additional arguments
104            
105        Returns:
106            List of SearchResult objects from stage 2, optionally with combined scores
107        """
108        # Stage 1: Retrieve document summaries
109        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")
110        
111        stage1_query = RetrievalQuery(
112            text=query.text,
113            embedding=query.embedding,
114            top_k=self.stage1_top_k,
115            filters=query.filters,
116            metadata=query.metadata
117        )
118        
119        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)
120        
121        if not summary_results:
122            logger.warning("Stage 1 returned no results")
123            return []
124        
125        # Extract document IDs from summaries
126        document_ids = []
127        document_scores = {}
128        
129        for result in summary_results:
130            doc_id = result.document.metadata.get(self.document_id_field)
131            if doc_id:
132                document_ids.append(doc_id)
133                document_scores[doc_id] = result.score
134            else:
135                # If no document_id field, use the document id itself
136                document_ids.append(result.document.id)
137                document_scores[result.document.id] = result.score
138        
139        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")
140        
141        # Stage 2: Retrieve detailed chunks from top documents
142        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")
143        
144        all_chunks: List[SearchResult] = []
145        
146        for doc_id in document_ids:
147            # Create filter for this document
148            doc_filter = {self.document_id_field: doc_id}
149            
150            # Merge with existing filters if any
151            if query.filters:
152                doc_filter.update(query.filters)
153            
154            # Retrieve chunks for this document
155            stage2_query = RetrievalQuery(
156                text=query.text,
157                embedding=query.embedding,
158                top_k=self.stage2_top_k,
159                filters=doc_filter,
160                metadata=query.metadata
161            )
162            
163            try:
164                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)
165                
166                if self.combine_scores and doc_id in document_scores:
167                    # Combine stage 1 and stage 2 scores
168                    for result in chunk_results:
169                        combined_score = (
170                            self.stage1_weight * document_scores[doc_id] +
171                            self.stage2_weight * result.score
172                        )
173                        # Create new result with combined score
174                        result.score = combined_score
175                
176                all_chunks.extend(chunk_results)
177                
178            except ValueError as e:
179                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
180                continue
181        
182        if not all_chunks:
183            logger.warning("Stage 2 returned no chunks")
184            return []
185        
186        # Sort by score and limit to top_k
187        all_chunks.sort(key=lambda x: x.score, reverse=True)
188        final_results = all_chunks[:query.top_k]
189        
190        # Update ranks
191        for i, result in enumerate(final_results):
192            result.rank = i + 1
193        
194        logger.info(
195            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
196            f"{len(document_ids)} documents"
197        )
198        
199        return final_results

Perform two-stage hierarchical retrieval.

Args:
  - query: The retrieval query
  - **kwargs: Additional arguments

Returns: List of SearchResult objects from stage 2, optionally with combined scores

class GraphRetriever(gmf_forge_ai_data.BaseRetriever):
 45class GraphRetriever(BaseRetriever):
 46    """
 47    Graph-based retrieval using entity relationships.
 48    
 49    This retriever:
 50    1. Extracts entities from the query
 51    2. Finds matching entities in the knowledge graph
 52    3. Traverses the graph to find related entities
 53    4. Retrieves documents associated with relevant entities
 54    
 55    Example:
 56        ```python
 57        # Create knowledge graph
 58        graph = nx.DiGraph()
 59        graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
 60        graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)
 61        
 62        # Create entity-document mapping
 63        entity_docs = {
 64            "Python": ["doc1", "doc2"],
 65            "Machine Learning": ["doc3", "doc4"],
 66            "Neural Networks": ["doc5"]
 67        }
 68        
 69        # Create retriever
 70        retriever = GraphRetriever(
 71            vector_store=store,
 72            knowledge_graph=graph,
 73            entity_document_mapping=entity_docs,
 74            embedder=embedder,
 75            max_hops=2
 76        )
 77        
 78        # Retrieve
 79        query = RetrievalQuery(text="Python for deep learning", top_k=5)
 80        results = retriever.retrieve(query)
 81        ```
 82    """
 83    
 84    def __init__(
 85        self,
 86        vector_store: BaseVectorStore,
 87        knowledge_graph: Optional['nx.DiGraph'] = None,
 88        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
 89        embedder: Optional[Any] = None,
 90        max_hops: int = 2,
 91        min_relation_weight: float = 0.5,
 92        combine_vector_scores: bool = True,
 93        graph_weight: float = 0.4,
 94        vector_weight: float = 0.6
 95    ):
 96        """
 97        Initialize graph retriever.
 98        
 99        Args:
100            vector_store: Vector store for document retrieval
101            knowledge_graph: NetworkX DiGraph with entity relationships
102            entity_document_mapping: Dict mapping entity IDs to document IDs
103            embedder: Embedding provider for query encoding
104            max_hops: Maximum number of hops in graph traversal
105            min_relation_weight: Minimum weight for relationship edges
106            combine_vector_scores: If True, combine graph and vector scores
107            graph_weight: Weight for graph-based relevance
108            vector_weight: Weight for vector similarity
109        """
110        if not NETWORKX_AVAILABLE:
111            raise ImportError(
112                "NetworkX is required for GraphRetriever. "
113                "Install with: pip install networkx"
114            )
115        
116        self.vector_store = vector_store
117        self.knowledge_graph = knowledge_graph or nx.DiGraph()
118        self.entity_document_mapping = entity_document_mapping or {}
119        self.embedder = embedder
120        self.max_hops = max_hops
121        self.min_relation_weight = min_relation_weight
122        self.combine_vector_scores = combine_vector_scores
123        self.graph_weight = graph_weight
124        self.vector_weight = vector_weight
125        
126        # Normalize weights
127        total_weight = graph_weight + vector_weight
128        if total_weight > 0:
129            self.graph_weight = graph_weight / total_weight
130            self.vector_weight = vector_weight / total_weight
131    
132    def extract_entities(self, query_text: str) -> List[str]:
133        """
134        Extract entities from query text.
135        
136        Simple implementation: looks for entity names in the knowledge graph.
137        For production, use NER models (spaCy, Hugging Face, etc.).
138        
139        Args:
140            query_text: Query text
141            
142        Returns:
143            List of entity IDs found in the query
144        """
145        query_lower = query_text.lower()
146        entities = []
147        
148        # Simple matching: check if entity names appear in query
149        for entity_id in self.knowledge_graph.nodes():
150            entity_name = str(entity_id).lower()
151            if entity_name in query_lower:
152                entities.append(entity_id)
153        
154        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
155        return entities
156    
157    def traverse_graph(
158        self,
159        seed_entities: List[str],
160        max_hops: int
161    ) -> Dict[str, float]:
162        """
163        Traverse knowledge graph from seed entities.
164        
165        Args:
166            seed_entities: Starting entities
167            max_hops: Maximum number of hops
168            
169        Returns:
170            Dict mapping entity IDs to relevance scores (0-1)
171        """
172        entity_scores: Dict[str, float] = {}
173        
174        # Initialize seed entities with score 1.0
175        for entity in seed_entities:
176            if entity in self.knowledge_graph:
177                entity_scores[entity] = 1.0
178        
179        if not entity_scores:
180            return entity_scores
181        
182        # BFS traversal
183        visited: Set[str] = set()
184        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)
185        
186        while current_level:
187            next_level = []
188            
189            for entity_id, score, hop in current_level:
190                if entity_id in visited or hop >= max_hops:
191                    continue
192                
193                visited.add(entity_id)
194                
195                # Get neighbors
196                if entity_id not in self.knowledge_graph:
197                    continue
198                
199                for neighbor in self.knowledge_graph.neighbors(entity_id):
200                    # Get edge weight
201                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
202                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0
203                    
204                    # Skip weak relationships
205                    if edge_weight < self.min_relation_weight:
206                        continue
207                    
208                    # Calculate neighbor score (decay with hops)
209                    decay_factor = 0.7 ** (hop + 1)
210                    neighbor_score = score * edge_weight * decay_factor
211                    
212                    # Update score if better
213                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
214                        entity_scores[neighbor] = neighbor_score
215                        next_level.append((neighbor, neighbor_score, hop + 1))
216            
217            current_level = next_level
218        
219        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
220        return entity_scores
221    
222    def retrieve(
223        self,
224        query: RetrievalQuery,
225        **kwargs
226    ) -> List[SearchResult]:
227        """
228        Perform graph-based retrieval.
229        
230        Args:
231            query: The retrieval query
232            **kwargs: Additional arguments
233            
234        Returns:
235            List of SearchResult objects ranked by combined graph + vector scores
236        """
237        if not query.text:
238            raise ValueError("GraphRetriever requires query.text")
239        
240        # Step 1: Extract entities from query
241        seed_entities = self.extract_entities(query.text)
242        
243        if not seed_entities:
244            logger.warning("No entities found in query, falling back to vector search")
245            # Fallback to pure vector search
246            if self.embedder and not query.embedding:
247                query.embedding = self.embedder.embed_text(query.text)
248            
249            return self.vector_store.search(
250                query_embedding=query.embedding,
251                top_k=query.top_k,
252                filters=query.filters
253            )
254        
255        # Step 2: Traverse graph to find related entities
256        entity_scores = self.traverse_graph(seed_entities, self.max_hops)
257        
258        # Step 3: Collect documents associated with relevant entities
259        doc_graph_scores: Dict[str, float] = {}
260        
261        for entity_id, entity_score in entity_scores.items():
262            doc_ids = self.entity_document_mapping.get(entity_id, [])
263            for doc_id in doc_ids:
264                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
265                    doc_graph_scores[doc_id] = entity_score
266        
267        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")
268        
269        if not doc_graph_scores:
270            logger.warning("No documents found via graph, falling back to vector search")
271            if self.embedder and not query.embedding:
272                query.embedding = self.embedder.embed_text(query.text)
273            
274            return self.vector_store.search(
275                query_embedding=query.embedding,
276                top_k=query.top_k,
277                filters=query.filters
278            )
279        
280        # Step 4: Get vector scores if combining
281        if self.combine_vector_scores:
282            # Get embeddings for vector search
283            if self.embedder and not query.embedding:
284                query.embedding = self.embedder.embed_text(query.text)
285            
286            # Retrieve more documents for scoring
287            vector_results = self.vector_store.search(
288                query_embedding=query.embedding,
289                top_k=query.top_k * 3,  # Get more for better coverage
290                filters=query.filters
291            )
292            
293            # Create document ID to vector score mapping
294            doc_vector_scores: Dict[str, float] = {}
295            for result in vector_results:
296                doc_vector_scores[result.document.id] = result.score
297            
298            # Combine scores
299            combined_results: List[SearchResult] = []
300            
301            for doc_id, graph_score in doc_graph_scores.items():
302                # Get document from vector store
303                document = self.vector_store.get_document(doc_id)
304                if not document:
305                    continue
306                
307                # Get vector score (0 if not found)
308                vector_score = doc_vector_scores.get(doc_id, 0.0)
309                
310                # Combine scores
311                combined_score = (
312                    self.graph_weight * graph_score +
313                    self.vector_weight * vector_score
314                )
315                
316                combined_results.append(
317                    SearchResult(
318                        document=document,
319                        score=combined_score,
320                        rank=0  # Will be set later
321                    )
322                )
323        else:
324            # Use only graph scores
325            combined_results: List[SearchResult] = []
326            
327            for doc_id, graph_score in doc_graph_scores.items():
328                document = self.vector_store.get_document(doc_id)
329                if not document:
330                    continue
331                
332                combined_results.append(
333                    SearchResult(
334                        document=document,
335                        score=graph_score,
336                        rank=0
337                    )
338                )
339        
340        # Sort by score and limit to top_k
341        combined_results.sort(key=lambda x: x.score, reverse=True)
342        final_results = combined_results[:query.top_k]
343        
344        # Update ranks
345        for i, result in enumerate(final_results):
346            result.rank = i + 1
347        
348        logger.info(f"Graph retrieval complete: {len(final_results)} results")
349        
350        return final_results

Graph-based retrieval using entity relationships.

This retriever:

  1. Extracts entities from the query
  2. Finds matching entities in the knowledge graph
  3. Traverses the graph to find related entities
  4. Retrieves documents associated with relevant entities

Example:

# Create knowledge graph
graph = nx.DiGraph()
graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

# Create entity-document mapping
entity_docs = {
    "Python": ["doc1", "doc2"],
    "Machine Learning": ["doc3", "doc4"],
    "Neural Networks": ["doc5"]
}

# Create retriever
retriever = GraphRetriever(
    vector_store=store,
    knowledge_graph=graph,
    entity_document_mapping=entity_docs,
    embedder=embedder,
    max_hops=2
)

# Retrieve
query = RetrievalQuery(text="Python for deep learning", top_k=5)
results = retriever.retrieve(query)
GraphRetriever( vector_store: BaseVectorStore, knowledge_graph: Optional[networkx.classes.digraph.DiGraph] = None, entity_document_mapping: Optional[Dict[str, List[str]]] = None, embedder: Optional[Any] = None, max_hops: int = 2, min_relation_weight: float = 0.5, combine_vector_scores: bool = True, graph_weight: float = 0.4, vector_weight: float = 0.6)
 84    def __init__(
 85        self,
 86        vector_store: BaseVectorStore,
 87        knowledge_graph: Optional['nx.DiGraph'] = None,
 88        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
 89        embedder: Optional[Any] = None,
 90        max_hops: int = 2,
 91        min_relation_weight: float = 0.5,
 92        combine_vector_scores: bool = True,
 93        graph_weight: float = 0.4,
 94        vector_weight: float = 0.6
 95    ):
 96        """
 97        Initialize graph retriever.
 98        
 99        Args:
100            vector_store: Vector store for document retrieval
101            knowledge_graph: NetworkX DiGraph with entity relationships
102            entity_document_mapping: Dict mapping entity IDs to document IDs
103            embedder: Embedding provider for query encoding
104            max_hops: Maximum number of hops in graph traversal
105            min_relation_weight: Minimum weight for relationship edges
106            combine_vector_scores: If True, combine graph and vector scores
107            graph_weight: Weight for graph-based relevance
108            vector_weight: Weight for vector similarity
109        """
110        if not NETWORKX_AVAILABLE:
111            raise ImportError(
112                "NetworkX is required for GraphRetriever. "
113                "Install with: pip install networkx"
114            )
115        
116        self.vector_store = vector_store
117        self.knowledge_graph = knowledge_graph or nx.DiGraph()
118        self.entity_document_mapping = entity_document_mapping or {}
119        self.embedder = embedder
120        self.max_hops = max_hops
121        self.min_relation_weight = min_relation_weight
122        self.combine_vector_scores = combine_vector_scores
123        self.graph_weight = graph_weight
124        self.vector_weight = vector_weight
125        
126        # Normalize weights
127        total_weight = graph_weight + vector_weight
128        if total_weight > 0:
129            self.graph_weight = graph_weight / total_weight
130            self.vector_weight = vector_weight / total_weight

Initialize graph retriever.

Args:
  - vector_store: Vector store for document retrieval
  - knowledge_graph: NetworkX DiGraph with entity relationships
  - entity_document_mapping: Dict mapping entity IDs to document IDs
  - embedder: Embedding provider for query encoding
  - max_hops: Maximum number of hops in graph traversal
  - min_relation_weight: Minimum weight for relationship edges
  - combine_vector_scores: If True, combine graph and vector scores
  - graph_weight: Weight for graph-based relevance
  - vector_weight: Weight for vector similarity

vector_store
knowledge_graph
entity_document_mapping
embedder
max_hops
min_relation_weight
combine_vector_scores
graph_weight
vector_weight
def extract_entities(self, query_text: str) -> List[str]:
132    def extract_entities(self, query_text: str) -> List[str]:
133        """
134        Extract entities from query text.
135        
136        Simple implementation: looks for entity names in the knowledge graph.
137        For production, use NER models (spaCy, Hugging Face, etc.).
138        
139        Args:
140            query_text: Query text
141            
142        Returns:
143            List of entity IDs found in the query
144        """
145        query_lower = query_text.lower()
146        entities = []
147        
148        # Simple matching: check if entity names appear in query
149        for entity_id in self.knowledge_graph.nodes():
150            entity_name = str(entity_id).lower()
151            if entity_name in query_lower:
152                entities.append(entity_id)
153        
154        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
155        return entities

Extract entities from query text.

Simple implementation: looks for entity names in the knowledge graph. For production, use NER models (spaCy, Hugging Face, etc.).

Args: query_text: Query text

Returns: List of entity IDs found in the query

def traverse_graph(self, seed_entities: List[str], max_hops: int) -> Dict[str, float]:
157    def traverse_graph(
158        self,
159        seed_entities: List[str],
160        max_hops: int
161    ) -> Dict[str, float]:
162        """
163        Traverse knowledge graph from seed entities.
164        
165        Args:
166            seed_entities: Starting entities
167            max_hops: Maximum number of hops
168            
169        Returns:
170            Dict mapping entity IDs to relevance scores (0-1)
171        """
172        entity_scores: Dict[str, float] = {}
173        
174        # Initialize seed entities with score 1.0
175        for entity in seed_entities:
176            if entity in self.knowledge_graph:
177                entity_scores[entity] = 1.0
178        
179        if not entity_scores:
180            return entity_scores
181        
182        # BFS traversal
183        visited: Set[str] = set()
184        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)
185        
186        while current_level:
187            next_level = []
188            
189            for entity_id, score, hop in current_level:
190                if entity_id in visited or hop >= max_hops:
191                    continue
192                
193                visited.add(entity_id)
194                
195                # Get neighbors
196                if entity_id not in self.knowledge_graph:
197                    continue
198                
199                for neighbor in self.knowledge_graph.neighbors(entity_id):
200                    # Get edge weight
201                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
202                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0
203                    
204                    # Skip weak relationships
205                    if edge_weight < self.min_relation_weight:
206                        continue
207                    
208                    # Calculate neighbor score (decay with hops)
209                    decay_factor = 0.7 ** (hop + 1)
210                    neighbor_score = score * edge_weight * decay_factor
211                    
212                    # Update score if better
213                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
214                        entity_scores[neighbor] = neighbor_score
215                        next_level.append((neighbor, neighbor_score, hop + 1))
216            
217            current_level = next_level
218        
219        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
220        return entity_scores

Traverse knowledge graph from seed entities.

Args:
  - seed_entities: Starting entities
  - max_hops: Maximum number of hops

Returns: Dict mapping entity IDs to relevance scores (0-1)

def retrieve( self, query: RetrievalQuery, **kwargs) -> List[SearchResult]:
222    def retrieve(
223        self,
224        query: RetrievalQuery,
225        **kwargs
226    ) -> List[SearchResult]:
227        """
228        Perform graph-based retrieval.
229        
230        Args:
231            query: The retrieval query
232            **kwargs: Additional arguments
233            
234        Returns:
235            List of SearchResult objects ranked by combined graph + vector scores
236        """
237        if not query.text:
238            raise ValueError("GraphRetriever requires query.text")
239        
240        # Step 1: Extract entities from query
241        seed_entities = self.extract_entities(query.text)
242        
243        if not seed_entities:
244            logger.warning("No entities found in query, falling back to vector search")
245            # Fallback to pure vector search
246            if self.embedder and not query.embedding:
247                query.embedding = self.embedder.embed_text(query.text)
248            
249            return self.vector_store.search(
250                query_embedding=query.embedding,
251                top_k=query.top_k,
252                filters=query.filters
253            )
254        
255        # Step 2: Traverse graph to find related entities
256        entity_scores = self.traverse_graph(seed_entities, self.max_hops)
257        
258        # Step 3: Collect documents associated with relevant entities
259        doc_graph_scores: Dict[str, float] = {}
260        
261        for entity_id, entity_score in entity_scores.items():
262            doc_ids = self.entity_document_mapping.get(entity_id, [])
263            for doc_id in doc_ids:
264                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
265                    doc_graph_scores[doc_id] = entity_score
266        
267        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")
268        
269        if not doc_graph_scores:
270            logger.warning("No documents found via graph, falling back to vector search")
271            if self.embedder and not query.embedding:
272                query.embedding = self.embedder.embed_text(query.text)
273            
274            return self.vector_store.search(
275                query_embedding=query.embedding,
276                top_k=query.top_k,
277                filters=query.filters
278            )
279        
280        # Step 4: Get vector scores if combining
281        if self.combine_vector_scores:
282            # Get embeddings for vector search
283            if self.embedder and not query.embedding:
284                query.embedding = self.embedder.embed_text(query.text)
285            
286            # Retrieve more documents for scoring
287            vector_results = self.vector_store.search(
288                query_embedding=query.embedding,
289                top_k=query.top_k * 3,  # Get more for better coverage
290                filters=query.filters
291            )
292            
293            # Create document ID to vector score mapping
294            doc_vector_scores: Dict[str, float] = {}
295            for result in vector_results:
296                doc_vector_scores[result.document.id] = result.score
297            
298            # Combine scores
299            combined_results: List[SearchResult] = []
300            
301            for doc_id, graph_score in doc_graph_scores.items():
302                # Get document from vector store
303                document = self.vector_store.get_document(doc_id)
304                if not document:
305                    continue
306                
307                # Get vector score (0 if not found)
308                vector_score = doc_vector_scores.get(doc_id, 0.0)
309                
310                # Combine scores
311                combined_score = (
312                    self.graph_weight * graph_score +
313                    self.vector_weight * vector_score
314                )
315                
316                combined_results.append(
317                    SearchResult(
318                        document=document,
319                        score=combined_score,
320                        rank=0  # Will be set later
321                    )
322                )
323        else:
324            # Use only graph scores
325            combined_results: List[SearchResult] = []
326            
327            for doc_id, graph_score in doc_graph_scores.items():
328                document = self.vector_store.get_document(doc_id)
329                if not document:
330                    continue
331                
332                combined_results.append(
333                    SearchResult(
334                        document=document,
335                        score=graph_score,
336                        rank=0
337                    )
338                )
339        
340        # Sort by score and limit to top_k
341        combined_results.sort(key=lambda x: x.score, reverse=True)
342        final_results = combined_results[:query.top_k]
343        
344        # Update ranks
345        for i, result in enumerate(final_results):
346            result.rank = i + 1
347        
348        logger.info(f"Graph retrieval complete: {len(final_results)} results")
349        
350        return final_results

Perform graph-based retrieval.

Args:
  - query: The retrieval query
  - **kwargs: Additional arguments

Returns: List of SearchResult objects ranked by combined graph + vector scores

class SQLRetriever(gmf_forge_ai_data.BaseRetriever):
 38class SQLRetriever(BaseRetriever):
 39    """
 40    Retrieval from structured databases using SQL queries.
 41    
 42    This retriever:
 43    1. Converts natural language queries to SQL (via LLM or rule-based)
 44    2. Executes SQL against a database
 45    3. Converts results to Document objects
 46    4. Returns SearchResult objects with relevance scores
 47    
 48    Example:
 49        ```python
 50        import sqlite3
 51        
 52        # Create database connection
 53        conn = sqlite3.connect("products.db")
 54        
 55        # Define schema
 56        schema = SQLSchema(
 57            table_name="products",
 58            columns=[
 59                {"name": "id", "type": "INTEGER", "description": "Product ID"},
 60                {"name": "name", "type": "TEXT", "description": "Product name"},
 61                {"name": "price", "type": "REAL", "description": "Price in USD"},
 62                {"name": "category", "type": "TEXT", "description": "Product category"}
 63            ],
 64            primary_key="id",
 65            description="E-commerce product catalog"
 66        )
 67        
 68        # Create retriever
 69        retriever = SQLRetriever(
 70            db_connection=conn,
 71            schema=schema,
 72            text_to_sql_fn=my_text_to_sql_function,
 73            db_type="sqlite"
 74        )
 75        
 76        # Retrieve
 77        query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
 78        results = retriever.retrieve(query)
 79        ```
 80    """
 81    
 82    def __init__(
 83        self,
 84        db_connection: Any,
 85        schema: SQLSchema,
 86        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
 87        db_type: str = "sqlite",
 88        content_columns: Optional[List[str]] = None,
 89        score_column: Optional[str] = None,
 90        default_score: float = 1.0
 91    ):
 92        """
 93        Initialize SQL retriever.
 94        
 95        Args:
 96            db_connection: Database connection object (e.g., sqlite3.Connection)
 97            schema: Database schema information
 98            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
 99            db_type: Database type ("sqlite", "postgresql", "mysql")
100            content_columns: Columns to use for document content (None = all columns)
101            score_column: Column to use for relevance score (None = use default_score)
102            default_score: Default score when no score_column specified
103        """
104        self.db_connection = db_connection
105        self.schema = schema
106        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
107        self.db_type = db_type
108        self.content_columns = content_columns
109        self.score_column = score_column
110        self.default_score = default_score
111    
112    def _simple_text_to_sql(self, query_text: str, schema: SQLSchema) -> SQLQuery:
113        """
114        Simple rule-based text-to-SQL conversion.
115        
116        For production, use LLM-based conversion or specialized libraries.
117        
118        Args:
119            query_text: Natural language query
120            schema: Database schema
121            
122        Returns:
123            SQLQuery object
124        """
125        # Extract keywords for filtering
126        query_lower = query_text.lower()
127        
128        # Build SELECT clause
129        if self.content_columns:
130            columns = ", ".join(self.content_columns)
131        else:
132            columns = "*"
133        
134        # Add score column if specified
135        if self.score_column:
136            columns = f"{columns}, {self.score_column}"
137        
138        # Build basic SELECT
139        sql = f"SELECT {columns} FROM {schema.table_name}"
140        
141        # Simple WHERE clause based on keywords
142        conditions = []
143        
144        # Look for numeric comparisons
145        if "under" in query_lower or "less than" in query_lower or "<" in query_lower:
146            # Try to find price column
147            price_cols = [c["name"] for c in schema.columns if "price" in c["name"].lower() or "cost" in c["name"].lower()]
148            if price_cols:
149                # Extract number
150                import re
151                numbers = re.findall(r'\d+', query_text)
152                if numbers:
153                    conditions.append(f"{price_cols[0]} < {numbers[0]}")
154        
155        if "over" in query_lower or "more than" in query_lower or ">" in query_lower:
156            price_cols = [c["name"] for c in schema.columns if "price" in c["name"].lower() or "cost" in c["name"].lower()]
157            if price_cols:
158                import re
159                numbers = re.findall(r'\d+', query_text)
160                if numbers:
161                    conditions.append(f"{price_cols[0]} > {numbers[0]}")
162        
163        # Look for text columns to search
164        text_cols = [c["name"] for c in schema.columns if c["type"].upper() in ["TEXT", "VARCHAR", "STRING"]]
165        
166        # Build LIKE conditions for text search
167        search_terms = []
168        # Remove common words
169        stop_words = {"in", "the", "a", "an", "under", "over", "less", "more", "than", "products", "items"}
170        words = [w for w in query_lower.split() if w not in stop_words and not w.isdigit()]
171        
172        for word in words:
173            if len(word) > 2:  # Skip very short words
174                search_terms.append(word)
175        
176        if search_terms and text_cols:
177            like_conditions = []
178            for col in text_cols:
179                for term in search_terms:
180                    like_conditions.append(f"{col} LIKE '%{term}%'")
181            
182            if like_conditions:
183                conditions.append(f"({' OR '.join(like_conditions)})")
184        
185        # Add WHERE clause if conditions exist
186        if conditions:
187            sql += " WHERE " + " AND ".join(conditions)
188        
189        # Add LIMIT
190        sql += " LIMIT 100"  # Safety limit
191        
192        return SQLQuery(
193            sql=sql,
194            explanation=f"Simple keyword-based SQL generation"
195        )
196    
197    def _execute_sql(self, sql_query: SQLQuery) -> List[Dict[str, Any]]:
198        """
199        Execute SQL query and return results.
200        
201        Args:
202            sql_query: SQL query to execute
203            
204        Returns:
205            List of result rows as dictionaries
206        """
207        cursor = self.db_connection.cursor()
208        
209        try:
210            logger.info(f"Executing SQL: {sql_query.sql}")
211            
212            if sql_query.parameters:
213                cursor.execute(sql_query.sql, sql_query.parameters)
214            else:
215                cursor.execute(sql_query.sql)
216            
217            # Get column names
218            if cursor.description:
219                columns = [desc[0] for desc in cursor.description]
220                
221                # Fetch results
222                rows = cursor.fetchall()
223                
224                # Convert to dictionaries
225                results = []
226                for row in rows:
227                    result_dict = dict(zip(columns, row))
228                    results.append(result_dict)
229                
230                logger.info(f"SQL returned {len(results)} rows")
231                return results
232            else:
233                return []
234                
235        except Exception as e:
236            logger.error(f"SQL execution failed: {e}")
237            raise
238        finally:
239            cursor.close()
240    
241    def _row_to_document(self, row: Dict[str, Any], index: int) -> Document:
242        """
243        Convert database row to Document object.
244        
245        Args:
246            row: Database row as dictionary
247            index: Row index for ID generation
248            
249        Returns:
250            Document object
251        """
252        # Extract score if available
253        score = self.default_score
254        if self.score_column and self.score_column in row:
255            score = float(row[self.score_column])
256        
257        # Build content from specified columns or all columns
258        if self.content_columns:
259            content_parts = []
260            for col in self.content_columns:
261                if col in row:
262                    content_parts.append(f"{col}: {row[col]}")
263            content = "\n".join(content_parts)
264        else:
265            # Use JSON representation
266            content = json.dumps(row, indent=2, default=str)
267        
268        # Use primary key if available, otherwise use index
269        doc_id = str(row.get(self.schema.primary_key, f"sql_result_{index}"))
270        
271        # Store all row data in metadata
272        metadata = {
273            "source": "sql_database",
274            "table": self.schema.table_name,
275            "row_data": row,
276            "sql_score": score
277        }
278        
279        return Document(
280            id=doc_id,
281            content=content,
282            metadata=metadata
283        )
284    
285    def retrieve(
286        self,
287        query: RetrievalQuery,
288        **kwargs
289    ) -> List[SearchResult]:
290        """
291        Perform SQL-based retrieval.
292        
293        Args:
294            query: The retrieval query
295            **kwargs: Additional arguments
296            
297        Returns:
298            List of SearchResult objects from database query
299        """
300        if not query.text:
301            raise ValueError("SQLRetriever requires query.text")
302        
303        # Convert text to SQL
304        sql_query = self.text_to_sql_fn(query.text, self.schema)
305        
306        logger.info(f"Generated SQL: {sql_query.sql}")
307        if sql_query.explanation:
308            logger.info(f"Explanation: {sql_query.explanation}")
309        
310        # Execute SQL
311        rows = self._execute_sql(sql_query)
312        
313        if not rows:
314            logger.warning("SQL query returned no results")
315            return []
316        
317        # Convert rows to Documents
318        documents = [
319            self._row_to_document(row, i)
320            for i, row in enumerate(rows)
321        ]
322        
323        # Create SearchResult objects
324        results = [
325            SearchResult(
326                document=doc,
327                score=doc.metadata.get("sql_score", self.default_score),
328                rank=i + 1
329            )
330            for i, doc in enumerate(documents)
331        ]
332        
333        # Limit to top_k
334        results = results[:query.top_k]
335        
336        # Update ranks
337        for i, result in enumerate(results):
338            result.rank = i + 1
339        
340        logger.info(f"SQL retrieval complete: {len(results)} results")
341        
342        return results

Retrieval from structured databases using SQL queries.

This retriever:

  1. Converts natural language queries to SQL (via LLM or rule-based)
  2. Executes SQL against a database
  3. Converts results to Document objects
  4. Returns SearchResult objects with relevance scores

Example:

import sqlite3

# Create database connection
conn = sqlite3.connect("products.db")

# Define schema
schema = SQLSchema(
    table_name="products",
    columns=[
        {"name": "id", "type": "INTEGER", "description": "Product ID"},
        {"name": "name", "type": "TEXT", "description": "Product name"},
        {"name": "price", "type": "REAL", "description": "Price in USD"},
        {"name": "category", "type": "TEXT", "description": "Product category"}
    ],
    primary_key="id",
    description="E-commerce product catalog"
)

# Create retriever
retriever = SQLRetriever(
    db_connection=conn,
    schema=schema,
    text_to_sql_fn=my_text_to_sql_function,
    db_type="sqlite"
)

# Retrieve
query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
results = retriever.retrieve(query)
SQLRetriever( db_connection: Any, schema: SQLSchema, text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None, db_type: str = 'sqlite', content_columns: Optional[List[str]] = None, score_column: Optional[str] = None, default_score: float = 1.0)
 82    def __init__(
 83        self,
 84        db_connection: Any,
 85        schema: SQLSchema,
 86        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
 87        db_type: str = "sqlite",
 88        content_columns: Optional[List[str]] = None,
 89        score_column: Optional[str] = None,
 90        default_score: float = 1.0
 91    ):
 92        """
 93        Initialize SQL retriever.
 94        
 95        Args:
 96            db_connection: Database connection object (e.g., sqlite3.Connection)
 97            schema: Database schema information
 98            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
 99            db_type: Database type ("sqlite", "postgresql", "mysql")
100            content_columns: Columns to use for document content (None = all columns)
101            score_column: Column to use for relevance score (None = use default_score)
102            default_score: Default score when no score_column specified
103        """
104        self.db_connection = db_connection
105        self.schema = schema
106        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
107        self.db_type = db_type
108        self.content_columns = content_columns
109        self.score_column = score_column
110        self.default_score = default_score

Initialize SQL retriever.

Args: db_connection: Database connection object (e.g., sqlite3.Connection). schema: Database schema information. text_to_sql_fn: Function to convert text to SQL (None = use simple rules). db_type: Database type ("sqlite", "postgresql", "mysql"). content_columns: Columns to use for document content (None = all columns). score_column: Column to use for relevance score (None = use default_score). default_score: Default score when no score_column is specified.

db_connection
schema
text_to_sql_fn
db_type
content_columns
score_column
default_score
def retrieve( self, query: RetrievalQuery, **kwargs) -> List[SearchResult]:
285    def retrieve(
286        self,
287        query: RetrievalQuery,
288        **kwargs
289    ) -> List[SearchResult]:
290        """
291        Perform SQL-based retrieval.
292        
293        Args:
294            query: The retrieval query
295            **kwargs: Additional arguments
296            
297        Returns:
298            List of SearchResult objects from database query
299        """
300        if not query.text:
301            raise ValueError("SQLRetriever requires query.text")
302        
303        # Convert text to SQL
304        sql_query = self.text_to_sql_fn(query.text, self.schema)
305        
306        logger.info(f"Generated SQL: {sql_query.sql}")
307        if sql_query.explanation:
308            logger.info(f"Explanation: {sql_query.explanation}")
309        
310        # Execute SQL
311        rows = self._execute_sql(sql_query)
312        
313        if not rows:
314            logger.warning("SQL query returned no results")
315            return []
316        
317        # Convert rows to Documents
318        documents = [
319            self._row_to_document(row, i)
320            for i, row in enumerate(rows)
321        ]
322        
323        # Create SearchResult objects
324        results = [
325            SearchResult(
326                document=doc,
327                score=doc.metadata.get("sql_score", self.default_score),
328                rank=i + 1
329            )
330            for i, doc in enumerate(documents)
331        ]
332        
333        # Limit to top_k
334        results = results[:query.top_k]
335        
336        # Update ranks
337        for i, result in enumerate(results):
338            result.rank = i + 1
339        
340        logger.info(f"SQL retrieval complete: {len(results)} results")
341        
342        return results

Perform SQL-based retrieval.

Args: query: The retrieval query. **kwargs: Additional arguments.

Returns: List of SearchResult objects from database query

@dataclass
class SQLSchema:
@dataclass
class SQLSchema:
    """Represents a database schema: one table and the columns it exposes."""

    # Name of the table that generated SQL selects from.
    table_name: str
    # One dict per column. SQLRetriever reads the "name" and "type" keys;
    # "description" is informational, e.g.
    # [{"name": "col1", "type": "int", "description": "..."}]
    columns: List[Dict[str, str]]
    # Column whose value identifies a row, or None when there is none.
    primary_key: Optional[str] = None
    # Free-text description of the table's contents.
    description: Optional[str] = None

Represents a database schema.

SQLSchema( table_name: str, columns: List[Dict[str, str]], primary_key: Optional[str] = None, description: Optional[str] = None)
table_name: str
columns: List[Dict[str, str]]
primary_key: Optional[str] = None
description: Optional[str] = None
@dataclass
class SQLQuery:
@dataclass
class SQLQuery:
    """A SQL statement together with its bound parameters and provenance."""

    # SQL text to execute; may contain driver-specific placeholders.
    sql: str
    # Values for the placeholders in `sql`; each instance gets its own dict.
    parameters: Dict[str, Any] = field(default_factory=dict)
    # Optional human-readable note on how the SQL was produced.
    explanation: Optional[str] = None

Represents a generated SQL query.

SQLQuery( sql: str, parameters: Dict[str, Any] = <factory>, explanation: Optional[str] = None)
sql: str
parameters: Dict[str, Any]
explanation: Optional[str] = None
class MultiIndexRetriever(gmf_forge_ai_data.BaseRetriever):
 31class MultiIndexRetriever(BaseRetriever):
 32    """
 33    Retrieve from multiple indices/sources and merge results.
 34    
 35    This retriever:
 36    1. Queries multiple independent retrievers in parallel (or sequentially)
 37    2. Merges results from all sources
 38    3. Applies source-specific weights and boosts
 39    4. Re-ranks using RRF or weighted scoring
 40    
 41    Use cases:
 42    - Multi-domain search (HR + Finance + IT)
 43    - Federated search across multiple databases
 44    - Combining different data sources (documents + SQL + graph)
 45    
 46    Example:
 47        ```python
 48        # Create retrievers for different sources
 49        hr_retriever = VectorRetriever(hr_store, embedder)
 50        finance_retriever = VectorRetriever(finance_store, embedder)
 51        it_retriever = VectorRetriever(it_store, embedder)
 52        
 53        # Create multi-index retriever
 54        retriever = MultiIndexRetriever(
 55            sources=[
 56                SourceConfig("HR", hr_retriever, weight=1.0),
 57                SourceConfig("Finance", finance_retriever, weight=1.5),
 58                SourceConfig("IT", it_retriever, weight=1.0)
 59            ],
 60            fusion_strategy="rrf"
 61        )
 62        
 63        # Retrieve across all sources
 64        query = RetrievalQuery(text="employee benefits policy", top_k=10)
 65        results = retriever.retrieve(query)
 66        ```
 67    """
 68    
 69    def __init__(
 70        self,
 71        sources: List[SourceConfig],
 72        fusion_strategy: str = "rrf",
 73        rrf_k: int = 60,
 74        normalize_scores: bool = True
 75    ):
 76        """
 77        Initialize multi-index retriever.
 78        
 79        Args:
 80            sources: List of SourceConfig objects defining retrievers
 81            fusion_strategy: Strategy for merging results:
 82                - "rrf": Reciprocal Rank Fusion
 83                - "weighted_average": Weighted average of normalized scores
 84                - "max_score": Take maximum score across sources
 85            rrf_k: Constant for RRF formula (default 60)
 86            normalize_scores: Normalize scores to [0,1] before fusion
 87        """
 88        self.sources = sources
 89        self.fusion_strategy = fusion_strategy
 90        self.rrf_k = rrf_k
 91        self.normalize_scores = normalize_scores
 92        
 93        # Validate sources
 94        if not sources:
 95            raise ValueError("At least one source must be provided")
 96        
 97        # Normalize weights
 98        total_weight = sum(s.weight for s in sources if s.enabled)
 99        if total_weight > 0:
100            for source in sources:
101                source.weight = source.weight / total_weight
102    
103    def _normalize_scores_for_source(
104        self,
105        results: List[SearchResult]
106    ) -> List[SearchResult]:
107        """
108        Normalize scores to [0, 1] range.
109        
110        Args:
111            results: Search results
112            
113        Returns:
114            Results with normalized scores
115        """
116        if not results:
117            return results
118        
119        scores = [r.score for r in results]
120        min_score = min(scores)
121        max_score = max(scores)
122        
123        if max_score - min_score == 0:
124            # All scores are the same
125            for result in results:
126                result.score = 1.0
127        else:
128            for result in results:
129                result.score = (result.score - min_score) / (max_score - min_score)
130        
131        return results
132    
133    def _rrf_fusion(
134        self,
135        source_results: Dict[str, List[SearchResult]]
136    ) -> List[SearchResult]:
137        """
138        Merge results using Reciprocal Rank Fusion.
139        
140        Args:
141            source_results: Dict mapping source names to their results
142            
143        Returns:
144            Merged and re-ranked results
145        """
146        # Track all unique documents and their RRF scores
147        doc_rrf_scores: Dict[str, float] = {}
148        doc_objects: Dict[str, SearchResult] = {}
149        
150        # Calculate RRF scores
151        for source_name, results in source_results.items():
152            # Get source config
153            source_config = next((s for s in self.sources if s.name == source_name), None)
154            if not source_config or not source_config.enabled:
155                continue
156            
157            weight = source_config.weight
158            boost = source_config.boost_factor
159            
160            for result in results:
161                doc_id = result.document.id
162                
163                # RRF score: weight / (k + rank)
164                rrf_score = (weight * boost) / (self.rrf_k + result.rank)
165                
166                if doc_id not in doc_rrf_scores:
167                    doc_rrf_scores[doc_id] = rrf_score
168                    doc_objects[doc_id] = result
169                    # Add source metadata
170                    result.document.metadata["retrieval_source"] = source_name
171                else:
172                    doc_rrf_scores[doc_id] += rrf_score
173                    # If document appears in multiple sources, mark it
174                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
175                    if source_name not in existing_source:
176                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"
177        
178        # Create final results
179        merged_results = []
180        for doc_id, rrf_score in doc_rrf_scores.items():
181            result = doc_objects[doc_id]
182            result.score = rrf_score
183            merged_results.append(result)
184        
185        # Sort by RRF score
186        merged_results.sort(key=lambda x: x.score, reverse=True)
187        
188        return merged_results
189    
190    def _weighted_average_fusion(
191        self,
192        source_results: Dict[str, List[SearchResult]]
193    ) -> List[SearchResult]:
194        """
195        Merge results using weighted average of scores.
196        
197        Args:
198            source_results: Dict mapping source names to their results
199            
200        Returns:
201            Merged and re-ranked results
202        """
203        # Track all unique documents and their weighted scores
204        doc_weighted_scores: Dict[str, float] = {}
205        doc_score_counts: Dict[str, int] = {}
206        doc_objects: Dict[str, SearchResult] = {}
207        
208        # Calculate weighted scores
209        for source_name, results in source_results.items():
210            # Get source config
211            source_config = next((s for s in self.sources if s.name == source_name), None)
212            if not source_config or not source_config.enabled:
213                continue
214            
215            # Normalize scores for this source
216            if self.normalize_scores:
217                results = self._normalize_scores_for_source(results)
218            
219            weight = source_config.weight
220            boost = source_config.boost_factor
221            
222            for result in results:
223                doc_id = result.document.id
224                weighted_score = result.score * weight * boost
225                
226                if doc_id not in doc_weighted_scores:
227                    doc_weighted_scores[doc_id] = weighted_score
228                    doc_score_counts[doc_id] = 1
229                    doc_objects[doc_id] = result
230                    result.document.metadata["retrieval_source"] = source_name
231                else:
232                    doc_weighted_scores[doc_id] += weighted_score
233                    doc_score_counts[doc_id] += 1
234                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
235                    if source_name not in existing_source:
236                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"
237        
238        # Create final results with averaged scores
239        merged_results = []
240        for doc_id, total_score in doc_weighted_scores.items():
241            result = doc_objects[doc_id]
242            # Average the weighted scores
243            result.score = total_score / doc_score_counts[doc_id]
244            merged_results.append(result)
245        
246        # Sort by score
247        merged_results.sort(key=lambda x: x.score, reverse=True)
248        
249        return merged_results
250    
251    def _max_score_fusion(
252        self,
253        source_results: Dict[str, List[SearchResult]]
254    ) -> List[SearchResult]:
255        """
256        Merge results using maximum score across sources.
257        
258        Args:
259            source_results: Dict mapping source names to their results
260            
261        Returns:
262            Merged and re-ranked results
263        """
264        # Track all unique documents and their max scores
265        doc_max_scores: Dict[str, float] = {}
266        doc_objects: Dict[str, SearchResult] = {}
267        
268        # Find max scores
269        for source_name, results in source_results.items():
270            # Get source config
271            source_config = next((s for s in self.sources if s.name == source_name), None)
272            if not source_config or not source_config.enabled:
273                continue
274            
275            # Normalize scores for this source
276            if self.normalize_scores:
277                results = self._normalize_scores_for_source(results)
278            
279            boost = source_config.boost_factor
280            
281            for result in results:
282                doc_id = result.document.id
283                boosted_score = result.score * boost
284                
285                if doc_id not in doc_max_scores or boosted_score > doc_max_scores[doc_id]:
286                    doc_max_scores[doc_id] = boosted_score
287                    doc_objects[doc_id] = result
288                    result.document.metadata["retrieval_source"] = source_name
289                else:
290                    # Update source metadata if from multiple sources
291                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
292                    if source_name not in existing_source:
293                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"
294        
295        # Create final results
296        merged_results = []
297        for doc_id, max_score in doc_max_scores.items():
298            result = doc_objects[doc_id]
299            result.score = max_score
300            merged_results.append(result)
301        
302        # Sort by score
303        merged_results.sort(key=lambda x: x.score, reverse=True)
304        
305        return merged_results
306    
307    def retrieve(
308        self,
309        query: RetrievalQuery,
310        **kwargs
311    ) -> List[SearchResult]:
312        """
313        Perform multi-source retrieval.
314        
315        Args:
316            query: The retrieval query
317            **kwargs: Additional arguments passed to each retriever
318            
319        Returns:
320            List of SearchResult objects merged from all sources
321        """
322        # Retrieve from each source
323        source_results: Dict[str, List[SearchResult]] = {}
324        
325        for source in self.sources:
326            if not source.enabled:
327                logger.info(f"Skipping disabled source: {source.name}")
328                continue
329            
330            try:
331                logger.info(f"Retrieving from source: {source.name}")
332                results = source.retriever.retrieve(query, **kwargs)
333                source_results[source.name] = results
334                logger.info(f"Source {source.name} returned {len(results)} results")
335                
336            except Exception as e:
337                logger.warning(f"Retrieval from source {source.name} failed: {e}")
338                source_results[source.name] = []
339        
340        # Check if we got any results
341        total_results = sum(len(results) for results in source_results.values())
342        if total_results == 0:
343            logger.warning("No results from any source")
344            return []
345        
346        # Merge results based on fusion strategy
347        logger.info(f"Merging results using {self.fusion_strategy} strategy")
348        
349        if self.fusion_strategy == "rrf":
350            merged_results = self._rrf_fusion(source_results)
351        elif self.fusion_strategy == "weighted_average":
352            merged_results = self._weighted_average_fusion(source_results)
353        elif self.fusion_strategy == "max_score":
354            merged_results = self._max_score_fusion(source_results)
355        else:
356            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")
357        
358        # Limit to top_k
359        final_results = merged_results[:query.top_k]
360        
361        # Update ranks
362        for i, result in enumerate(final_results):
363            result.rank = i + 1
364        
365        logger.info(
366            f"Multi-index retrieval complete: {len(final_results)} results "
367            f"from {len(source_results)} sources"
368        )
369        
370        return final_results

Retrieve from multiple indices/sources and merge results.

This retriever:

  1. Queries multiple independent retrievers sequentially
  2. Merges results from all sources
  3. Applies source-specific weights and boosts
  4. Re-ranks using RRF, weighted-average, or max-score fusion

Use cases:

  • Multi-domain search (HR + Finance + IT)
  • Federated search across multiple databases
  • Combining different data sources (documents + SQL + graph)

Example:

# Create retrievers for different sources
hr_retriever = VectorRetriever(hr_store, embedder)
finance_retriever = VectorRetriever(finance_store, embedder)
it_retriever = VectorRetriever(it_store, embedder)

# Create multi-index retriever
retriever = MultiIndexRetriever(
    sources=[
        SourceConfig("HR", hr_retriever, weight=1.0),
        SourceConfig("Finance", finance_retriever, weight=1.5),
        SourceConfig("IT", it_retriever, weight=1.0)
    ],
    fusion_strategy="rrf"
)

# Retrieve across all sources
query = RetrievalQuery(text="employee benefits policy", top_k=10)
results = retriever.retrieve(query)
MultiIndexRetriever( sources: List[SourceConfig], fusion_strategy: str = 'rrf', rrf_k: int = 60, normalize_scores: bool = True)
 69    def __init__(
 70        self,
 71        sources: List[SourceConfig],
 72        fusion_strategy: str = "rrf",
 73        rrf_k: int = 60,
 74        normalize_scores: bool = True
 75    ):
 76        """
 77        Initialize multi-index retriever.
 78        
 79        Args:
 80            sources: List of SourceConfig objects defining retrievers
 81            fusion_strategy: Strategy for merging results:
 82                - "rrf": Reciprocal Rank Fusion
 83                - "weighted_average": Weighted average of normalized scores
 84                - "max_score": Take maximum score across sources
 85            rrf_k: Constant for RRF formula (default 60)
 86            normalize_scores: Normalize scores to [0,1] before fusion
 87        """
 88        self.sources = sources
 89        self.fusion_strategy = fusion_strategy
 90        self.rrf_k = rrf_k
 91        self.normalize_scores = normalize_scores
 92        
 93        # Validate sources
 94        if not sources:
 95            raise ValueError("At least one source must be provided")
 96        
 97        # Normalize weights
 98        total_weight = sum(s.weight for s in sources if s.enabled)
 99        if total_weight > 0:
100            for source in sources:
101                source.weight = source.weight / total_weight

Initialize multi-index retriever.

Args: sources: List of SourceConfig objects defining retrievers. fusion_strategy: Strategy for merging results — "rrf" (Reciprocal Rank Fusion), "weighted_average" (weighted average of normalized scores), or "max_score" (maximum score across sources). rrf_k: Constant for the RRF formula (default 60). normalize_scores: Normalize scores to [0, 1] before fusion.

sources
fusion_strategy
rrf_k
normalize_scores
def retrieve( self, query: RetrievalQuery, **kwargs) -> List[SearchResult]:
307    def retrieve(
308        self,
309        query: RetrievalQuery,
310        **kwargs
311    ) -> List[SearchResult]:
312        """
313        Perform multi-source retrieval.
314        
315        Args:
316            query: The retrieval query
317            **kwargs: Additional arguments passed to each retriever
318            
319        Returns:
320            List of SearchResult objects merged from all sources
321        """
322        # Retrieve from each source
323        source_results: Dict[str, List[SearchResult]] = {}
324        
325        for source in self.sources:
326            if not source.enabled:
327                logger.info(f"Skipping disabled source: {source.name}")
328                continue
329            
330            try:
331                logger.info(f"Retrieving from source: {source.name}")
332                results = source.retriever.retrieve(query, **kwargs)
333                source_results[source.name] = results
334                logger.info(f"Source {source.name} returned {len(results)} results")
335                
336            except Exception as e:
337                logger.warning(f"Retrieval from source {source.name} failed: {e}")
338                source_results[source.name] = []
339        
340        # Check if we got any results
341        total_results = sum(len(results) for results in source_results.values())
342        if total_results == 0:
343            logger.warning("No results from any source")
344            return []
345        
346        # Merge results based on fusion strategy
347        logger.info(f"Merging results using {self.fusion_strategy} strategy")
348        
349        if self.fusion_strategy == "rrf":
350            merged_results = self._rrf_fusion(source_results)
351        elif self.fusion_strategy == "weighted_average":
352            merged_results = self._weighted_average_fusion(source_results)
353        elif self.fusion_strategy == "max_score":
354            merged_results = self._max_score_fusion(source_results)
355        else:
356            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")
357        
358        # Limit to top_k
359        final_results = merged_results[:query.top_k]
360        
361        # Update ranks
362        for i, result in enumerate(final_results):
363            result.rank = i + 1
364        
365        logger.info(
366            f"Multi-index retrieval complete: {len(final_results)} results "
367            f"from {len(source_results)} sources"
368        )
369        
370        return final_results

Perform multi-source retrieval.

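Args:

  • query: The retrieval query.
  • **kwargs: Additional arguments passed to each retriever.

Returns: List of SearchResult objects merged from all enabled sources, truncated to query.top_k and re-ranked starting at 1. A source whose retriever raises is logged as a warning and contributes no results. If no source returns anything, an empty list is returned.

Raises: ValueError: If fusion_strategy is not one of "rrf", "weighted_average" or "max_score".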

@dataclass
class SourceConfig:
21@dataclass
22class SourceConfig:
23    """Configuration for a retrieval source."""
24    name: str
25    retriever: BaseRetriever
26    weight: float = 1.0
27    boost_factor: float = 1.0
28    enabled: bool = True

Configuration for a retrieval source.

SourceConfig( name: str, retriever: BaseRetriever, weight: float = 1.0, boost_factor: float = 1.0, enabled: bool = True)
name: str
retriever: BaseRetriever
weight: float = 1.0
boost_factor: float = 1.0
enabled: bool = True
class BaseConnector(abc.ABC):
20class BaseConnector(ABC):
21    """
22    Abstract base class for all data source connectors.
23
24    Connectors handle one concern only: sourcing raw content and converting
25    it to Documents. They do NOT chunk, embed, or index — those steps follow.
26
27    Contract for all implementations:
28    - Every returned Document must have a non-empty ``id`` and ``content``.
29    - ``embedding`` is always ``None`` — the caller is responsible for embedding.
30    - Source-specific metadata (path, URL, container, etc.) must be stored in
31      ``document.metadata`` under consistent, documented keys.
32    - Files that cannot be read should be skipped with a printed warning rather
33      than crashing the entire load.
34    """
35
36    @abstractmethod
37    def load(self) -> List[Document]:
38        """
39        Load documents from the data source.
40
41        Returns:
42            List of Document objects with ``id``, ``content``, ``timestamp``,
43            and ``metadata`` populated. ``embedding`` is always ``None``.
44        """

Abstract base class for all data source connectors.

Connectors handle one concern only: sourcing raw content and converting it to Documents. They do NOT chunk, embed, or index — those steps follow.

Contract for all implementations:

  • Every returned Document must have a non-empty id and content.
  • embedding is always None — the caller is responsible for embedding.
  • Source-specific metadata (path, URL, container, etc.) must be stored in document.metadata under consistent, documented keys.
  • Files that cannot be read should be skipped with a logged warning rather than crashing the entire load.
@abstractmethod
def load(self) -> List[Document]:
36    @abstractmethod
37    def load(self) -> List[Document]:
38        """
39        Load documents from the data source.
40
41        Returns:
42            List of Document objects with ``id``, ``content``, ``timestamp``,
43            and ``metadata`` populated. ``embedding`` is always ``None``.
44        """

Load documents from the data source.

Returns: List of Document objects with id, content, timestamp, and metadata populated. embedding is always None.

class FilesystemConnector(gmf_forge_ai_data.BaseConnector):
 59class FilesystemConnector(BaseConnector):
 60    """
 61    Loads documents from a local directory into Document objects.
 62
 63    Scans a root directory (optionally recursively) and converts each matching
 64    file into a Document. The document ``id`` is a stable MD5 hash of the
 65    absolute file path, so re-runs produce consistent IDs for upsert workflows.
 66
 67    Supported formats:
 68    - Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .js, .ts,
 69      .java, .cpp, .c, .h, .cs, .html, .htm, .xml, .toml, .ini, .cfg, .env
 70    - PDF (.pdf) — requires ``pip install pypdf``
 71    - Word Document (.docx) — requires ``pip install python-docx``
 72
 73    Metadata keys set on every returned Document:
 74        ``source``       Absolute file path as a string
 75        ``file_name``    File name including extension
 76        ``extension``    Lowercase extension including dot (e.g. ``".md"``)
 77        ``size_bytes``   File size in bytes
 78        ``modified_at``  Last modification time in ISO 8601 format
 79
 80    Example:
 81        ```python
 82        from gmf_forge_ai_data.connectors import FilesystemConnector
 83
 84        connector = FilesystemConnector(
 85            root_path="/data/docs",
 86            extensions=[".txt", ".md", ".pdf"],
 87            recursive=True,
 88        )
 89        docs = connector.load()
 90        # docs is List[Document] — pass to a chunker next
 91        ```
 92    """
 93
 94    def __init__(
 95        self,
 96        root_path: Union[str, Path],
 97        extensions: Optional[List[str]] = None,
 98        recursive: bool = True,
 99        encoding: str = "utf-8-sig",
100        skip_empty: bool = True,
101    ):
102        """
103        Args:
104            root_path:   Root directory to scan.
105            extensions:  Explicit list of extensions to include (e.g.
106                         ``[".txt", ".md"]``). Include the leading dot.
107                         If ``None``, all supported formats are loaded.
108            recursive:   If ``True`` (default), scan subdirectories recursively.
109            encoding:    Text encoding for native text files (default ``"utf-8"``).
110            skip_empty:  If ``True`` (default), skip files that produce no
111                         text content after stripping whitespace.
112        """
113        self.root_path = Path(root_path).resolve()
114        self.extensions: Optional[Set[str]] = (
115            {ext.lower() for ext in extensions}
116            if extensions is not None
117            else None  # None = all supported formats
118        )
119        self.recursive = recursive
120        self.encoding = encoding
121        self.skip_empty = skip_empty
122        self._logger = BasicLogger(__name__)
123
124    def load(self) -> List[Document]:
125        """
126        Scan the root directory and return one Document per matching file.
127
128        Files that raise an error during reading are skipped with a warning
129        printed to stdout so the rest of the load continues uninterrupted.
130
131        Returns:
132            List of Document objects, one per successfully loaded file,
133            sorted by file path for deterministic ordering.
134
135        Raises:
136            FileNotFoundError: If ``root_path`` does not exist.
137            NotADirectoryError: If ``root_path`` is not a directory.
138        """
139        if not self.root_path.exists():
140            raise FileNotFoundError(
141                f"Root path does not exist: {self.root_path}"
142            )
143        if not self.root_path.is_dir():
144            raise NotADirectoryError(
145                f"Root path is not a directory: {self.root_path}"
146            )
147
148        pattern = "**/*" if self.recursive else "*"
149        all_files = sorted(p for p in self.root_path.glob(pattern) if p.is_file())
150
151        documents: List[Document] = []
152        for file_path in all_files:
153            ext = file_path.suffix.lower()
154            if not self._is_accepted(ext):
155                continue
156
157            try:
158                content = self._read_file(file_path, ext)
159            except Exception as e:
160                self._logger.warning("Skipping file", file=file_path.name, error=str(e))
161                continue
162
163            if self.skip_empty and not content.strip():
164                continue
165
166            stat = file_path.stat()
167            doc_id = "fs_" + hashlib.md5(str(file_path).encode()).hexdigest()[:12]
168            documents.append(Document(
169                id=doc_id,
170                content=content.strip(),
171                timestamp=datetime.fromtimestamp(stat.st_mtime),
172                metadata={
173                    "source": str(file_path),
174                    "file_name": file_path.name,
175                    "extension": ext,
176                    "size_bytes": stat.st_size,
177                    "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat(),
178                },
179            ))
180
181        return documents
182
183    # ── Private helpers ──────────────────────────────────────────────────────
184
185    def _is_accepted(self, ext: str) -> bool:
186        """Return True if this extension should be loaded."""
187        if self.extensions is not None:
188            return ext in self.extensions
189        # No filter specified: accept all natively supported + optional formats
190        return ext in _NATIVE_TEXT_EXTENSIONS or ext in {".pdf", ".docx"}
191
192    def _read_file(self, path: Path, ext: str) -> str:
193        if ext == ".pdf":
194            return _read_pdf(path)
195        if ext == ".docx":
196            return _read_docx(path)
197        return path.read_text(encoding=self.encoding, errors="replace")

Loads documents from a local directory into Document objects.

Scans a root directory (optionally recursively) and converts each matching file into a Document. The document id is "fs_" followed by the first 12 hex characters of an MD5 hash of the absolute file path, so re-runs produce consistent IDs for upsert workflows.

Supported formats:

  • Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .js, .ts, .java, .cpp, .c, .h, .cs, .html, .htm, .xml, .toml, .ini, .cfg, .env
  • PDF (.pdf) — requires pip install pypdf
  • Word Document (.docx) — requires pip install python-docx

Metadata keys set on every returned Document:

  • source: Absolute file path as a string
  • file_name: File name including extension
  • extension: Lowercase extension including the dot (e.g. ".md")
  • size_bytes: File size in bytes
  • modified_at: Last modification time in ISO 8601 format

Example:

from gmf_forge_ai_data.connectors import FilesystemConnector

connector = FilesystemConnector(
    root_path="/data/docs",
    extensions=[".txt", ".md", ".pdf"],
    recursive=True,
)
docs = connector.load()
# docs is List[Document] — pass to a chunker next
FilesystemConnector( root_path: Union[str, pathlib._local.Path], extensions: Optional[List[str]] = None, recursive: bool = True, encoding: str = 'utf-8-sig', skip_empty: bool = True)
 94    def __init__(
 95        self,
 96        root_path: Union[str, Path],
 97        extensions: Optional[List[str]] = None,
 98        recursive: bool = True,
 99        encoding: str = "utf-8-sig",
100        skip_empty: bool = True,
101    ):
102        """
103        Args:
104            root_path:   Root directory to scan.
105            extensions:  Explicit list of extensions to include (e.g.
106                         ``[".txt", ".md"]``). Include the leading dot.
107                         If ``None``, all supported formats are loaded.
108            recursive:   If ``True`` (default), scan subdirectories recursively.
109            encoding:    Text encoding for native text files (default ``"utf-8"``).
110            skip_empty:  If ``True`` (default), skip files that produce no
111                         text content after stripping whitespace.
112        """
113        self.root_path = Path(root_path).resolve()
114        self.extensions: Optional[Set[str]] = (
115            {ext.lower() for ext in extensions}
116            if extensions is not None
117            else None  # None = all supported formats
118        )
119        self.recursive = recursive
120        self.encoding = encoding
121        self.skip_empty = skip_empty
122        self._logger = BasicLogger(__name__)

Args:

  • root_path: Root directory to scan.
  • extensions: Explicit list of extensions to include (e.g. [".txt", ".md"]), including the leading dot. If None, all supported formats are loaded.
  • recursive: If True (default), scan subdirectories recursively.
  • encoding: Text encoding for native text files (default "utf-8-sig").
  • skip_empty: If True (default), skip files that produce no text content after stripping whitespace.

root_path
extensions: Optional[Set[str]]
recursive
encoding
skip_empty
def load(self) -> List[Document]:
124    def load(self) -> List[Document]:
125        """
126        Scan the root directory and return one Document per matching file.
127
128        Files that raise an error during reading are skipped with a warning
129        printed to stdout so the rest of the load continues uninterrupted.
130
131        Returns:
132            List of Document objects, one per successfully loaded file,
133            sorted by file path for deterministic ordering.
134
135        Raises:
136            FileNotFoundError: If ``root_path`` does not exist.
137            NotADirectoryError: If ``root_path`` is not a directory.
138        """
139        if not self.root_path.exists():
140            raise FileNotFoundError(
141                f"Root path does not exist: {self.root_path}"
142            )
143        if not self.root_path.is_dir():
144            raise NotADirectoryError(
145                f"Root path is not a directory: {self.root_path}"
146            )
147
148        pattern = "**/*" if self.recursive else "*"
149        all_files = sorted(p for p in self.root_path.glob(pattern) if p.is_file())
150
151        documents: List[Document] = []
152        for file_path in all_files:
153            ext = file_path.suffix.lower()
154            if not self._is_accepted(ext):
155                continue
156
157            try:
158                content = self._read_file(file_path, ext)
159            except Exception as e:
160                self._logger.warning("Skipping file", file=file_path.name, error=str(e))
161                continue
162
163            if self.skip_empty and not content.strip():
164                continue
165
166            stat = file_path.stat()
167            doc_id = "fs_" + hashlib.md5(str(file_path).encode()).hexdigest()[:12]
168            documents.append(Document(
169                id=doc_id,
170                content=content.strip(),
171                timestamp=datetime.fromtimestamp(stat.st_mtime),
172                metadata={
173                    "source": str(file_path),
174                    "file_name": file_path.name,
175                    "extension": ext,
176                    "size_bytes": stat.st_size,
177                    "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat(),
178                },
179            ))
180
181        return documents

Scan the root directory and return one Document per matching file.

Files that raise an error during reading are skipped and reported as a logged warning, so the rest of the load continues uninterrupted.

Returns: List of Document objects, one per successfully loaded file, sorted by file path for deterministic ordering.

Raises: FileNotFoundError: If root_path does not exist. NotADirectoryError: If root_path is not a directory.

class SharePointConnector(gmf_forge_ai_data.BaseConnector):
 34class SharePointConnector(BaseConnector):
 35    """
 36    Loads files from a SharePoint document library via the Microsoft Graph API.
 37
 38    Authenticates using OAuth2 client credentials — no user interaction or
 39    browser required. Recursively lists all files under ``folder_path``,
 40    downloads their content, and extracts text. Unsupported binary formats
 41    are skipped automatically.
 42
 43    Supported file types:
 44    - Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .xml
 45    - PDF (.pdf) — requires ``pip install pypdf``
 46    - Word Document (.docx) — requires ``pip install python-docx``
 47
 48    Metadata keys set on every returned Document:
 49        ``source``         Microsoft Graph download URL
 50        ``file_name``      File name including extension
 51        ``extension``      Lowercase extension including dot
 52        ``size_bytes``     File size in bytes
 53        ``modified_at``    Last modification time (ISO 8601 string)
 54        ``sharepoint_id``  SharePoint item ID
 55
 56    Example:
 57        ```python
 58        from gmf_forge_ai_data.connectors import SharePointConnector
 59
 60        connector = SharePointConnector(
 61            tenant_id="your-tenant-id",
 62            client_id="your-client-id",
 63            client_secret="your-client-secret",
 64            site_id="your-site-id",
 65            folder_path="/Shared Documents/KnowledgeBase",
 66        )
 67        docs = connector.load()
 68        ```
 69    """
 70
 71    _GRAPH_BASE = "https://graph.microsoft.com/v1.0"
 72    _TOKEN_URL = "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
 73    _SUPPORTED_EXTENSIONS = {
 74        ".txt", ".md", ".rst", ".csv", ".json", ".yaml", ".yml",
 75        ".py", ".html", ".htm", ".xml", ".pdf", ".docx",
 76    }
 77
 78    def __init__(
 79        self,
 80        tenant_id: str,
 81        client_id: str,
 82        client_secret: str,
 83        site_id: str,
 84        folder_path: str = "/",
 85        drive_id: Optional[str] = None,
 86        ssl_cert_path: Optional[str] = None,
 87    ):
 88        """
 89        Args:
 90            tenant_id:      Azure AD tenant ID.
 91            client_id:      App registration (service principal) client ID.
 92            client_secret:  App registration client secret.
 93            site_id:        SharePoint site ID. Use ``"root"`` for the tenant
 94                            root site or provide the full GUID from the Graph API.
 95            folder_path:    Path inside the drive to load from. Use ``"/"``
 96                            (default) for the entire drive root.
 97            drive_id:       Drive ID or ``None`` (default) to use the site's
 98                            default document library (``root`` drive).
 99            ssl_cert_path:  Optional path to a CA bundle PEM file for
100                            environments with corporate SSL inspection.
101        """
102        self.tenant_id = tenant_id
103        self.client_id = client_id
104        self.client_secret = client_secret
105        self.site_id = site_id
106        self.folder_path = folder_path.rstrip("/") or "/"
107        self.drive_id = drive_id or "root"
108        self.ssl_cert_path = ssl_cert_path
109        self._token: Optional[str] = None
110        self._logger = BasicLogger(__name__)
111
112    def load(self) -> List[Document]:
113        """
114        Authenticate and load all supported files from the configured folder.
115
116        Returns:
117            List of Document objects, one per successfully loaded file.
118
119        Raises:
120            ImportError: If the ``requests`` package is not installed.
121            requests.HTTPError: If authentication or a Graph API call fails.
122        """
123        try:
124            import requests  # type: ignore
125        except ImportError:
126            raise ImportError(
127                "requests is required for SharePointConnector. "
128                "Install it with: pip install requests"
129            )
130
131        self._token = self._acquire_token(requests)
132        items = self._list_items(requests, self.folder_path)
133
134        documents: List[Document] = []
135        for item in items:
136            name: str = item.get("name", "")
137            ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""
138            if ext not in self._SUPPORTED_EXTENSIONS:
139                continue
140
141            try:
142                content = self._download_item(requests, item, ext)
143            except Exception as e:
144                self._logger.warning("Skipping file", file=name, error=str(e))
145                continue
146
147            if not content.strip():
148                continue
149
150            modified = item.get("lastModifiedDateTime", "")
151            size = item.get("size", 0)
152            download_url = item.get("@microsoft.graph.downloadUrl", "")
153            item_id = item.get("id", "")
154            doc_id = "sp_" + hashlib.md5(item_id.encode()).hexdigest()[:12]
155
156            documents.append(Document(
157                id=doc_id,
158                content=content.strip(),
159                timestamp=(
160                    datetime.fromisoformat(modified.replace("Z", "+00:00"))
161                    if modified else datetime.now()
162                ),
163                metadata={
164                    "source": download_url,
165                    "file_name": name,
166                    "extension": ext,
167                    "size_bytes": size,
168                    "modified_at": modified,
169                    "sharepoint_id": item_id,
170                },
171            ))
172
173        return documents
174
175    # ── Private helpers ──────────────────────────────────────────────────────
176
177    def _acquire_token(self, requests) -> str:
178        """Acquire an OAuth2 access token via client credentials flow."""
179        url = self._TOKEN_URL.format(tenant_id=self.tenant_id)
180        resp = requests.post(
181            url,
182            data={
183                "grant_type": "client_credentials",
184                "client_id": self.client_id,
185                "client_secret": self.client_secret,
186                "scope": "https://graph.microsoft.com/.default",
187            },
188            verify=self.ssl_cert_path or True,
189            timeout=30,
190        )
191        resp.raise_for_status()
192        return resp.json()["access_token"]
193
194    def _headers(self) -> Dict[str, str]:
195        return {"Authorization": f"Bearer {self._token}"}
196
197    def _list_items(self, requests, folder_path: str) -> List[Dict]:
198        """Recursively list all file items under folder_path."""
199        # Graph API uses /drive/ (singular) for the default document library
200        # and /drives/{id}/ when a specific drive ID is given.
201        drive_segment = (
202            "drive" if self.drive_id == "root"
203            else f"drives/{self.drive_id}"
204        )
205        if folder_path == "/":
206            url = (
207                f"{self._GRAPH_BASE}/sites/{self.site_id}"
208                f"/{drive_segment}/root/children"
209            )
210        else:
211            encoded = folder_path.lstrip("/")
212            url = (
213                f"{self._GRAPH_BASE}/sites/{self.site_id}"
214                f"/{drive_segment}/root:/{encoded}:/children"
215            )
216
217        items: List[Dict] = []
218        while url:
219            resp = requests.get(
220                url,
221                headers=self._headers(),
222                verify=self.ssl_cert_path or True,
223                timeout=30,
224            )
225            resp.raise_for_status()
226            data = resp.json()
227            for item in data.get("value", []):
228                if "folder" in item:
229                    # Recurse into sub-folders
230                    child_path = folder_path.rstrip("/") + "/" + item["name"]
231                    self._logger.info("Scanning subfolder", path=child_path)
232                    try:
233                        items.extend(self._list_items(requests, child_path))
234                    except Exception as e:
235                        self._logger.warning("Skipping subfolder", path=child_path, error=str(e))
236                else:
237                    items.append(item)
238            url = data.get("@odata.nextLink")  # follow pagination if present
239
240        return items
241
242    def _download_item(self, requests, item: Dict, ext: str) -> str:
243        """Download a file item's raw bytes and decode to text."""
244        download_url = item.get("@microsoft.graph.downloadUrl")
245        if not download_url:
246            raise ValueError(f"No download URL for item: {item.get('name')}")
247        resp = requests.get(
248            download_url,
249            verify=self.ssl_cert_path or True,
250            timeout=60,
251        )
252        resp.raise_for_status()
253        raw: bytes = resp.content
254        if ext == ".pdf":
255            return self._extract_pdf(raw)
256        if ext == ".docx":
257            return self._extract_docx(raw)
258        return raw.decode("utf-8-sig", errors="replace")
259
260    @staticmethod
261    def _extract_pdf(raw: bytes) -> str:
262        try:
263            import pypdf  # type: ignore
264        except ImportError:
265            raise ImportError(
266                "pypdf required for PDF support: pip install pypdf"
267            )
268        import logging
269        logging.getLogger("pypdf").setLevel(logging.ERROR)
270        reader = pypdf.PdfReader(io.BytesIO(raw))
271        return "\n".join(page.extract_text() or "" for page in reader.pages)
272
273    @staticmethod
274    def _extract_docx(raw: bytes) -> str:
275        try:
276            import docx  # type: ignore
277        except ImportError:
278            raise ImportError(
279                "python-docx required for DOCX support: pip install python-docx"
280            )
281        doc = docx.Document(io.BytesIO(raw))
282        return "\n".join(p.text for p in doc.paragraphs if p.text.strip())

Loads files from a SharePoint document library via the Microsoft Graph API.

Authenticates using OAuth2 client credentials — no user interaction or browser required. Recursively lists all files under folder_path, downloads their content, and extracts text. Files whose extension is not in the supported list below are skipped automatically.

Supported file types:

  • Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .htm, .xml
  • PDF (.pdf) — requires pip install pypdf
  • Word Document (.docx) — requires pip install python-docx

Metadata keys set on every returned Document:

  • source: Microsoft Graph download URL
  • file_name: File name including extension
  • extension: Lowercase extension including the dot
  • size_bytes: File size in bytes
  • modified_at: Last modification time (ISO 8601 string)
  • sharepoint_id: SharePoint item ID

Example:

from gmf_forge_ai_data.connectors import SharePointConnector

connector = SharePointConnector(
    tenant_id="your-tenant-id",
    client_id="your-client-id",
    client_secret="your-client-secret",
    site_id="your-site-id",
    folder_path="/Shared Documents/KnowledgeBase",
)
docs = connector.load()
SharePointConnector( tenant_id: str, client_id: str, client_secret: str, site_id: str, folder_path: str = '/', drive_id: Optional[str] = None, ssl_cert_path: Optional[str] = None)
 78    def __init__(
 79        self,
 80        tenant_id: str,
 81        client_id: str,
 82        client_secret: str,
 83        site_id: str,
 84        folder_path: str = "/",
 85        drive_id: Optional[str] = None,
 86        ssl_cert_path: Optional[str] = None,
 87    ):
 88        """
 89        Args:
 90            tenant_id:      Azure AD tenant ID.
 91            client_id:      App registration (service principal) client ID.
 92            client_secret:  App registration client secret.
 93            site_id:        SharePoint site ID. Use ``"root"`` for the tenant
 94                            root site or provide the full GUID from the Graph API.
 95            folder_path:    Path inside the drive to load from. Use ``"/"``
 96                            (default) for the entire drive root.
 97            drive_id:       Drive ID or ``None`` (default) to use the site's
 98                            default document library (``root`` drive).
 99            ssl_cert_path:  Optional path to a CA bundle PEM file for
100                            environments with corporate SSL inspection.
101        """
102        self.tenant_id = tenant_id
103        self.client_id = client_id
104        self.client_secret = client_secret
105        self.site_id = site_id
106        self.folder_path = folder_path.rstrip("/") or "/"
107        self.drive_id = drive_id or "root"
108        self.ssl_cert_path = ssl_cert_path
109        self._token: Optional[str] = None
110        self._logger = BasicLogger(__name__)

Args:

  • tenant_id: Azure AD tenant ID.
  • client_id: App registration (service principal) client ID.
  • client_secret: App registration client secret.
  • site_id: SharePoint site ID. Use "root" for the tenant root site, or provide the full GUID from the Graph API.
  • folder_path: Path inside the drive to load from. Use "/" (default) for the entire drive root.
  • drive_id: Drive ID, or None (default) to use the site's default document library (root drive).
  • ssl_cert_path: Optional path to a CA bundle PEM file, for environments with corporate SSL inspection.

tenant_id
client_id
client_secret
site_id
folder_path
drive_id
ssl_cert_path
def load(self) -> List[Document]:
112    def load(self) -> List[Document]:
113        """
114        Authenticate and load all supported files from the configured folder.
115
116        Returns:
117            List of Document objects, one per successfully loaded file.
118
119        Raises:
120            ImportError: If the ``requests`` package is not installed.
121            requests.HTTPError: If authentication or a Graph API call fails.
122        """
123        try:
124            import requests  # type: ignore
125        except ImportError:
126            raise ImportError(
127                "requests is required for SharePointConnector. "
128                "Install it with: pip install requests"
129            )
130
131        self._token = self._acquire_token(requests)
132        items = self._list_items(requests, self.folder_path)
133
134        documents: List[Document] = []
135        for item in items:
136            name: str = item.get("name", "")
137            ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""
138            if ext not in self._SUPPORTED_EXTENSIONS:
139                continue
140
141            try:
142                content = self._download_item(requests, item, ext)
143            except Exception as e:
144                self._logger.warning("Skipping file", file=name, error=str(e))
145                continue
146
147            if not content.strip():
148                continue
149
150            modified = item.get("lastModifiedDateTime", "")
151            size = item.get("size", 0)
152            download_url = item.get("@microsoft.graph.downloadUrl", "")
153            item_id = item.get("id", "")
154            doc_id = "sp_" + hashlib.md5(item_id.encode()).hexdigest()[:12]
155
156            documents.append(Document(
157                id=doc_id,
158                content=content.strip(),
159                timestamp=(
160                    datetime.fromisoformat(modified.replace("Z", "+00:00"))
161                    if modified else datetime.now()
162                ),
163                metadata={
164                    "source": download_url,
165                    "file_name": name,
166                    "extension": ext,
167                    "size_bytes": size,
168                    "modified_at": modified,
169                    "sharepoint_id": item_id,
170                },
171            ))
172
173        return documents

Authenticate and load all supported files from the configured folder.

Returns: List of Document objects, one per successfully loaded file.

Raises:

  • ImportError: If the requests package is not installed.
  • requests.HTTPError: If authentication or a Graph API call fails.

class BlobStorageConnector(gmf_forge_ai_data.BaseConnector):
 35class BlobStorageConnector(BaseConnector):
 36    """
 37    Loads blobs from an Azure Blob Storage container into Document objects.
 38
 39    Lists all blobs in the container (optionally filtered by a path prefix),
 40    downloads their content, and extracts text. Blobs whose extension is not
 41    in the supported set are skipped automatically.
 42
 43    Supported file types:
 44    - Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .xml
 45    - PDF (.pdf) — requires ``pip install pypdf``
 46    - Word Document (.docx) — requires ``pip install python-docx``
 47
 48    Metadata keys set on every returned Document:
 49        ``source``       Full blob URL
 50        ``file_name``    Last path segment of the blob name
 51        ``blob_name``    Full blob path inside the container
 52        ``extension``    Lowercase extension including dot
 53        ``size_bytes``   Blob content length in bytes
 54        ``modified_at``  Last modified time in ISO 8601 format
 55        ``container``    Container name
 56
 57    Example::
 58
 59        from gmf_forge_ai_data.connectors import BlobStorageConnector
 60
 61        connector = BlobStorageConnector(
 62            account_name="myaccount",
 63            access_key="your-storage-account-access-key",
 64            container_name="documents",
 65            prefix="knowledge-base/",
 66        )
 67        docs = connector.load()
 68    """
 69
 70    def __init__(
 71        self,
 72        account_name: str,
 73        access_key: str,
 74        container_name: str,
 75        prefix: str = "",
 76        ssl_cert_path: Optional[str] = None,
 77    ):
 78        """
 79        Args:
 80            account_name:  Storage account name.
 81            access_key:    Storage account access key.
 82            container_name: Name of the blob container to read from.
 83            prefix:        Optional blob name prefix for filtering, acting
 84                           like a folder path (e.g. ``"knowledge-base/"``).
 85                           Pass ``""`` (default) to list the entire container.
 86            ssl_cert_path: Optional path to a CA bundle PEM file for
 87                           environments with corporate SSL inspection.
 88        """
 89        self.account_name = account_name
 90        self.access_key = access_key
 91        self.container_name = container_name
 92        self.connection_string = (
 93            f"DefaultEndpointsProtocol=https;"
 94            f"AccountName={account_name};"
 95            f"AccountKey={access_key};"
 96            f"EndpointSuffix=core.windows.net"
 97        )
 98        self.account_url = f"https://{account_name}.blob.core.windows.net"
 99        self.prefix = prefix
100        self.ssl_cert_path = ssl_cert_path
101        self._logger = BasicLogger(__name__)
102
103    def load(self) -> List[Document]:
104        """
105        List and download all supported blobs in the container under ``prefix``.
106
107        Returns:
108            List of Document objects, one per successfully loaded blob.
109
110        Raises:
111            ImportError: If the ``azure-storage-blob`` package is not installed.
112        """
113        try:
114            from azure.storage.blob import BlobServiceClient  # type: ignore
115        except ImportError:
116            raise ImportError(
117                "azure-storage-blob is required for BlobStorageConnector. "
118                "Install it with: pip install azure-storage-blob"
119            )
120
121        client = BlobServiceClient.from_connection_string(self.connection_string)
122        container_client = client.get_container_client(self.container_name)
123
124        documents: List[Document] = []
125        for blob in container_client.list_blobs(name_starts_with=self.prefix or None):
126            name: str = blob.name
127            ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""
128            if ext not in _SUPPORTED_EXTENSIONS:
129                continue
130
131            try:
132                content = self._download_blob(container_client, name, ext)
133            except Exception as e:
134                self._logger.warning("Skipping blob", blob=name, error=str(e))
135                continue
136
137            if not content.strip():
138                continue
139
140            modified = blob.last_modified
141            size = blob.size or 0
142            blob_url = f"{self.account_url}/{self.container_name}/{name}"
143            doc_id = "blob_" + hashlib.md5(blob_url.encode()).hexdigest()[:12]
144            file_name = name.split("/")[-1]
145
146            documents.append(Document(
147                id=doc_id,
148                content=content.strip(),
149                timestamp=(
150                    modified if isinstance(modified, datetime) else datetime.now()
151                ),
152                metadata={
153                    "source": blob_url,
154                    "file_name": file_name,
155                    "blob_name": name,
156                    "extension": ext,
157                    "size_bytes": size,
158                    "modified_at": modified.isoformat() if modified else "",
159                    "container": self.container_name,
160                },
161            ))
162
163        return documents
164
165    # ── Private helpers ──────────────────────────────────────────────────────
166
167    def _download_blob(self, container_client, name: str, ext: str) -> str:
168        raw: bytes = container_client.download_blob(name).readall()
169        if ext == ".pdf":
170            return self._extract_pdf(raw)
171        if ext == ".docx":
172            return self._extract_docx(raw)
173        return raw.decode("utf-8-sig", errors="replace")
174
175    @staticmethod
176    def _extract_pdf(raw: bytes) -> str:
177        try:
178            import pypdf  # type: ignore
179        except ImportError:
180            raise ImportError(
181                "pypdf required for PDF support: pip install pypdf"
182            )
183        import logging
184        logging.getLogger("pypdf").setLevel(logging.ERROR)
185        reader = pypdf.PdfReader(io.BytesIO(raw))
186        return "\n".join(page.extract_text() or "" for page in reader.pages)
187
188    @staticmethod
189    def _extract_docx(raw: bytes) -> str:
190        try:
191            import docx  # type: ignore
192        except ImportError:
193            raise ImportError(
194                "python-docx required for DOCX support: pip install python-docx"
195            )
196        doc = docx.Document(io.BytesIO(raw))
197        return "\n".join(p.text for p in doc.paragraphs if p.text.strip())

Loads blobs from an Azure Blob Storage container into Document objects.

Lists all blobs in the container (optionally filtered by a path prefix), downloads their content, and extracts text. Blobs whose extension is not in the supported set are skipped automatically.

Supported file types:

  • Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .xml
  • PDF (.pdf) — requires pip install pypdf
  • Word Document (.docx) — requires pip install python-docx

Metadata keys set on every returned Document:

  • source: Full blob URL
  • file_name: Last path segment of the blob name
  • blob_name: Full blob path inside the container
  • extension: Lowercase extension, including the leading dot
  • size_bytes: Blob content length in bytes
  • modified_at: Last-modified time in ISO 8601 format (empty string if unavailable)
  • container: Container name

Example:

from gmf_forge_ai_data.connectors import BlobStorageConnector

connector = BlobStorageConnector(
    account_name="myaccount",
    access_key="your-storage-account-access-key",
    container_name="documents",
    prefix="knowledge-base/",
)
docs = connector.load()
BlobStorageConnector( account_name: str, access_key: str, container_name: str, prefix: str = '', ssl_cert_path: Optional[str] = None)
 70    def __init__(
 71        self,
 72        account_name: str,
 73        access_key: str,
 74        container_name: str,
 75        prefix: str = "",
 76        ssl_cert_path: Optional[str] = None,
 77    ):
 78        """
 79        Args:
 80            account_name:  Storage account name.
 81            access_key:    Storage account access key.
 82            container_name: Name of the blob container to read from.
 83            prefix:        Optional blob name prefix for filtering, acting
 84                           like a folder path (e.g. ``"knowledge-base/"``).
 85                           Pass ``""`` (default) to list the entire container.
 86            ssl_cert_path: Optional path to a CA bundle PEM file for
 87                           environments with corporate SSL inspection.
 88        """
 89        self.account_name = account_name
 90        self.access_key = access_key
 91        self.container_name = container_name
 92        self.connection_string = (
 93            f"DefaultEndpointsProtocol=https;"
 94            f"AccountName={account_name};"
 95            f"AccountKey={access_key};"
 96            f"EndpointSuffix=core.windows.net"
 97        )
 98        self.account_url = f"https://{account_name}.blob.core.windows.net"
 99        self.prefix = prefix
100        self.ssl_cert_path = ssl_cert_path
101        self._logger = BasicLogger(__name__)

Args:

  • account_name: Storage account name.
  • access_key: Storage account access key.
  • container_name: Name of the blob container to read from.
  • prefix: Optional blob-name prefix used for filtering; it acts like a folder path (e.g. "knowledge-base/"). Pass "" (the default) to list the entire container.
  • ssl_cert_path: Optional path to a CA bundle PEM file, for environments with corporate SSL inspection. Note: the load() implementation listed on this page stores this value but does not pass it to BlobServiceClient. Confirm it takes effect before relying on it.

account_name
access_key
container_name
connection_string
account_url
prefix
ssl_cert_path
def load(self) -> List[Document]:
103    def load(self) -> List[Document]:
104        """
105        List and download all supported blobs in the container under ``prefix``.
106
107        Returns:
108            List of Document objects, one per successfully loaded blob.
109
110        Raises:
111            ImportError: If the ``azure-storage-blob`` package is not installed.
112        """
113        try:
114            from azure.storage.blob import BlobServiceClient  # type: ignore
115        except ImportError:
116            raise ImportError(
117                "azure-storage-blob is required for BlobStorageConnector. "
118                "Install it with: pip install azure-storage-blob"
119            )
120
121        client = BlobServiceClient.from_connection_string(self.connection_string)
122        container_client = client.get_container_client(self.container_name)
123
124        documents: List[Document] = []
125        for blob in container_client.list_blobs(name_starts_with=self.prefix or None):
126            name: str = blob.name
127            ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""
128            if ext not in _SUPPORTED_EXTENSIONS:
129                continue
130
131            try:
132                content = self._download_blob(container_client, name, ext)
133            except Exception as e:
134                self._logger.warning("Skipping blob", blob=name, error=str(e))
135                continue
136
137            if not content.strip():
138                continue
139
140            modified = blob.last_modified
141            size = blob.size or 0
142            blob_url = f"{self.account_url}/{self.container_name}/{name}"
143            doc_id = "blob_" + hashlib.md5(blob_url.encode()).hexdigest()[:12]
144            file_name = name.split("/")[-1]
145
146            documents.append(Document(
147                id=doc_id,
148                content=content.strip(),
149                timestamp=(
150                    modified if isinstance(modified, datetime) else datetime.now()
151                ),
152                metadata={
153                    "source": blob_url,
154                    "file_name": file_name,
155                    "blob_name": name,
156                    "extension": ext,
157                    "size_bytes": size,
158                    "modified_at": modified.isoformat() if modified else "",
159                    "container": self.container_name,
160                },
161            ))
162
163        return documents

List and download all supported blobs in the container under prefix.

Returns: List of Document objects, one per successfully loaded blob.

Raises: ImportError: If the azure-storage-blob package is not installed.

class BaseIndexBuilder(abc.ABC):
47class BaseIndexBuilder(ABC):
48    """
49    Abstract base class for index builders.
50
51    Each backend (Azure AI Search, Cosmos DB, MongoDB) provides a concrete
52    subclass that exposes backend-specific tuning parameters while sharing
53    the same management interface.
54    """
55
56    # ------------------------------------------------------------------ #
57    # Core lifecycle                                                       #
58    # ------------------------------------------------------------------ #
59
60    @abstractmethod
61    def create_index(self) -> None:
62        """Create the index if it does not already exist.
63
64        Safe to call multiple times — must be a no-op when the index exists.
65        Use this for idempotent provisioning (CI/CD pipelines, first-run
66        setup scripts).
67        """
68
69    @abstractmethod
70    def create_or_replace_index(self) -> None:
71        """Delete the index if it exists, then create it fresh.
72
73        Use this when you need to apply schema changes that cannot be done
74        via an in-place update (e.g. changing HNSW parameters or adding a
75        new vector field).
76
77        Warning: All documents are lost.  Only use in dev/staging or after a
78        full re-ingestion has been planned.
79        """
80
81    @abstractmethod
82    def delete_index(self) -> None:
83        """Permanently delete the index and all its documents.
84
85        Raises:
86            RuntimeError: If the index does not exist.
87        """
88
89    @abstractmethod
90    def index_exists(self) -> bool:
91        """Return True if the index currently exists, False otherwise."""
92
93    @abstractmethod
94    def list_indexes(self) -> List[str]:
95        """Return the names of all indexes on this backend/service."""

Abstract base class for index builders.

Each backend (Azure AI Search, Cosmos DB, MongoDB) provides a concrete subclass that exposes backend-specific tuning parameters while sharing the same management interface.

@abstractmethod
def create_index(self) -> None:
60    @abstractmethod
61    def create_index(self) -> None:
62        """Create the index if it does not already exist.
63
64        Safe to call multiple times — must be a no-op when the index exists.
65        Use this for idempotent provisioning (CI/CD pipelines, first-run
66        setup scripts).
67        """

Create the index if it does not already exist.

Safe to call multiple times — must be a no-op when the index exists. Use this for idempotent provisioning (CI/CD pipelines, first-run setup scripts).

@abstractmethod
def create_or_replace_index(self) -> None:
69    @abstractmethod
70    def create_or_replace_index(self) -> None:
71        """Delete the index if it exists, then create it fresh.
72
73        Use this when you need to apply schema changes that cannot be done
74        via an in-place update (e.g. changing HNSW parameters or adding a
75        new vector field).
76
77        Warning: All documents are lost.  Only use in dev/staging or after a
78        full re-ingestion has been planned.
79        """

Delete the index if it exists, then create it fresh.

Use this when you need to apply schema changes that cannot be done via an in-place update (e.g. changing HNSW parameters or adding a new vector field).

Warning: All documents are lost. Only use in dev/staging or after a full re-ingestion has been planned.

@abstractmethod
def delete_index(self) -> None:
81    @abstractmethod
82    def delete_index(self) -> None:
83        """Permanently delete the index and all its documents.
84
85        Raises:
86            RuntimeError: If the index does not exist.
87        """

Permanently delete the index and all its documents.

Raises: RuntimeError: If the index does not exist.

@abstractmethod
def index_exists(self) -> bool:
89    @abstractmethod
90    def index_exists(self) -> bool:
91        """Return True if the index currently exists, False otherwise."""

Return True if the index currently exists, False otherwise.

@abstractmethod
def list_indexes(self) -> List[str]:
93    @abstractmethod
94    def list_indexes(self) -> List[str]:
95        """Return the names of all indexes on this backend/service."""

Return the names of all indexes on this backend/service.

class AzureAISearchIndexBuilder(gmf_forge_ai_data.BaseIndexBuilder):
 91class AzureAISearchIndexBuilder(BaseIndexBuilder):
 92    """
 93    Builds and manages Azure AI Search indexes with full developer control.
 94
 95    The builder owns *schema* concerns only.  Document operations (add,
 96    search, delete) belong to ``AzureAISearchVectorStore``.
 97
 98    Parameters
 99    ----------
100    endpoint:
101        Azure AI Search service endpoint URL.
102    api_key:
103        Azure AI Search admin API key. Use for local development or
104        when managed identity is not available.
105    token_provider:
106        Zero-argument callable that returns a bearer token string.
107        Use for managed identity / workload identity scenarios.
108        The callable must request the **Azure AI Search** scope::
109
110            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
111            token_provider = get_bearer_token_provider(
112                DefaultAzureCredential(),
113                "https://search.azure.com/.default"
114            )
115
116        Note: this scope is different from Azure OpenAI / Cognitive Services
117        (``https://cognitiveservices.azure.com/.default``) — each service
118        requires its own token_provider.
119    index_name:
120        Name of the index to create / manage.
121    embedding_dimension:
122        Number of dimensions in the embedding vectors (must match the
123        embedding model — e.g. 1536 for text-embedding-ada-002, 3072 for
124        text-embedding-3-large).
125    document_type:
126        Optional Document subclass.  When provided, all dataclass fields
127        not in the base Document are automatically added as indexed fields
128        (filterable, sortable, facetable where appropriate).
129    hnsw_m:
130        Number of bi-directional links created per node.  Higher = better
131        recall but more memory.  Typical range 4–16.  Default: 4.
132    hnsw_ef_construction:
133        Size of the candidate list during index construction.  Higher =
134        better recall, slower build time.  Typical range 100–800.
135        Default: 400.
136    hnsw_ef_search:
137        Size of the candidate list during search.  Higher = better recall,
138        slower queries.  Typical range 100–1000.  Default: 500.
139    metric:
140        Similarity metric.  One of ``"cosine"``, ``"euclidean"``,
141        ``"dotProduct"``.  Default: ``"cosine"``.
142    ssl_cert_path:
143        Optional path to a PEM certificate bundle for corporate SSL
144        inspection proxies.  Sets ``REQUESTS_CA_BUNDLE`` and
145        ``SSL_CERT_FILE`` environment variables before building the client.
146    semantic_config:
147        Optional semantic search configuration.  When provided the index is
148        provisioned with a ``SemanticSearch`` configuration that enables
149        Azure AI semantic reranking (``BoostedRerankerScore``).
150
151        Expected keys:
152
153        - ``name`` (str) — semantic config name (default
154          ``"default-semantic-config"``)
155        - ``title_field`` (str, optional) — field used as the document title
156        - ``content_fields`` (list[str]) — primary body content fields
157        - ``keyword_fields`` (list[str], optional) — keyword/facet fields
158
159        Example::
160
161            {
162                "name": "policyhub-semantic-config",
163                "title_field": "document_name",
164                "content_fields": ["content"],
165                "keyword_fields": ["language", "locale", "source"],
166            }
167    """
168
169    def __init__(
170        self,
171        endpoint: str,
172        index_name: str,
173        api_key: Optional[str] = None,
174        token_provider: Optional[Callable[[], str]] = None,
175        embedding_dimension: int = 1536,
176        document_type: Type[Document] = Document,
177        hnsw_m: int = 4,
178        hnsw_ef_construction: int = 400,
179        hnsw_ef_search: int = 500,
180        metric: str = "cosine",
181        ssl_cert_path: Optional[str] = None,
182        semantic_config: Optional[dict] = None,
183    ) -> None:
184        self.index_name = index_name
185        self.embedding_dimension = embedding_dimension
186        self.document_type = document_type
187        self.hnsw_m = hnsw_m
188        self.hnsw_ef_construction = hnsw_ef_construction
189        self.hnsw_ef_search = hnsw_ef_search
190        self.metric = metric
191        self.semantic_config = semantic_config
192
193        if not api_key and not token_provider:
194            raise ValueError(
195                "Either api_key or token_provider must be supplied to AzureAISearchIndexBuilder."
196            )
197
198        if ssl_cert_path:
199            import os as _os
200            _os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path)
201            _os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path)
202
203        if token_provider:
204            credential = _TokenProviderCredential(token_provider)
205        else:
206            credential = AzureKeyCredential(api_key)
207        self._index_client = SearchIndexClient(
208            endpoint=endpoint,
209            credential=credential,
210        )
211
212    # ------------------------------------------------------------------ #
213    # BaseIndexBuilder interface                                           #
214    # ------------------------------------------------------------------ #
215
216    def create_index(self) -> None:
217        """Create the index if it does not already exist (idempotent)."""
218        if self.index_exists():
219            logger.info("Index already exists — skipping creation", index=self.index_name)
220            return
221        self._create(self.index_name)
222        logger.info("Index created successfully", index=self.index_name)
223
224    def create_or_replace_index(self) -> None:
225        """Delete the existing index (if any) then create it fresh.
226
227        Warning: All documents are permanently lost.
228        """
229        if self.index_exists():
230            self._index_client.delete_index(self.index_name)
231            logger.info("Index deleted for replacement", index=self.index_name)
232        self._create(self.index_name)
233        logger.info("Index created (replaced)", index=self.index_name)
234
235    def delete_index(self) -> None:
236        """Permanently delete the index and all its documents.
237
238        Raises:
239            RuntimeError: If the index does not exist.
240        """
241        if not self.index_exists():
242            raise RuntimeError(
243                f"Cannot delete index '{self.index_name}': it does not exist."
244            )
245        self._index_client.delete_index(self.index_name)
246        logger.info("Index deleted", index=self.index_name)
247
248    def index_exists(self) -> bool:
249        """Return True if the index currently exists."""
250        try:
251            self._index_client.get_index(self.index_name)
252            return True
253        except ResourceNotFoundError:
254            return False
255        except Exception:
256            # Treat any other error as non-existence to keep callers safe
257            return False
258
259    def list_indexes(self) -> List[str]:
260        """Return the names of all indexes on this Azure AI Search service."""
261        return [idx.name for idx in self._index_client.list_indexes()]
262
263    # ------------------------------------------------------------------ #
264    # Internal helpers                                                     #
265    # ------------------------------------------------------------------ #
266
267    def _build_fields(self) -> list:
268        """Build the Azure Search field list from base + document_type fields."""
269        fields = [
270            SimpleField(
271                name="id",
272                type=SearchFieldDataType.String,
273                key=True,
274                filterable=True,
275            ),
276            SearchableField(
277                name="content",
278                type=SearchFieldDataType.String,
279                searchable=True,
280            ),
281            SearchField(
282                name="embedding",
283                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
284                searchable=True,
285                vector_search_dimensions=self.embedding_dimension,
286                vector_search_profile_name="default-vector-profile",
287            ),
288            SimpleField(
289                name="timestamp",
290                type=SearchFieldDataType.DateTimeOffset,
291                filterable=True,
292                sortable=True,
293            ),
294            # Stores serialised metadata dict and any non-indexed custom fields
295            SimpleField(
296                name="document_data",
297                type=SearchFieldDataType.String,
298                filterable=False,
299            ),
300        ]
301
302        # Infer custom fields from the document_type dataclass
303        if dataclasses.is_dataclass(self.document_type):
304            base_field_names = {"id", "content", "embedding", "timestamp", "metadata"}
305            for field in dataclasses.fields(self.document_type):
306                if field.name in base_field_names:
307                    continue
308                azure_type = self._map_python_type(field.type)
309                if azure_type is None:
310                    continue
311                scalar_types = {
312                    SearchFieldDataType.String,
313                    SearchFieldDataType.Int32,
314                    SearchFieldDataType.Int64,
315                    SearchFieldDataType.Double,
316                    SearchFieldDataType.DateTimeOffset,
317                    SearchFieldDataType.Boolean,
318                }
319                # Per-field overrides from dataclass field metadata.
320                # Falls back to the original defaults when not specified, so
321                # existing document types without metadata are unaffected.
322                meta = field.metadata
323                is_searchable = meta.get("searchable", False)
324                is_filterable = meta.get("filterable", True)
325                is_sortable = meta.get("sortable", azure_type in scalar_types)
326                is_facetable = meta.get(
327                    "facetable",
328                    azure_type in {SearchFieldDataType.String, SearchFieldDataType.Boolean},
329                )
330                if is_searchable:
331                    fields.append(
332                        SearchableField(
333                            name=field.name,
334                            filterable=is_filterable,
335                            sortable=is_sortable,
336                            facetable=is_facetable,
337                        )
338                    )
339                else:
340                    fields.append(
341                        SimpleField(
342                            name=field.name,
343                            type=azure_type,
344                            filterable=is_filterable,
345                            sortable=is_sortable,
346                            facetable=is_facetable,
347                        )
348                    )
349                logger.info(
350                    "Added indexed field",
351                    field=field.name,
352                    azure_type=str(azure_type),
353                    searchable=is_searchable,
354                )
355
356        return fields
357
358    def _create(self, index_name: str) -> None:
359        """Internal: build and submit the index definition to Azure."""
360        fields = self._build_fields()
361
362        vector_search = VectorSearch(
363            algorithms=[
364                HnswAlgorithmConfiguration(
365                    name="default-hnsw",
366                    parameters={
367                        "m": self.hnsw_m,
368                        "efConstruction": self.hnsw_ef_construction,
369                        "efSearch": self.hnsw_ef_search,
370                        "metric": self.metric,
371                    },
372                )
373            ],
374            profiles=[
375                VectorSearchProfile(
376                    name="default-vector-profile",
377                    algorithm_configuration_name="default-hnsw",
378                )
379            ],
380        )
381
382        semantic_search = None
383        if self.semantic_config:
384            sc = self.semantic_config
385            title_field = (
386                SemanticField(field_name=sc["title_field"])
387                if sc.get("title_field") else None
388            )
389            semantic_search = SemanticSearch(
390                configurations=[
391                    SemanticConfiguration(
392                        name=sc.get("name", "default-semantic-config"),
393                        prioritized_fields=SemanticPrioritizedFields(
394                            title_field=title_field,
395                            content_fields=[
396                                SemanticField(field_name=f)
397                                for f in sc.get("content_fields", [])
398                            ],
399                            keywords_fields=[
400                                SemanticField(field_name=f)
401                                for f in sc.get("keyword_fields", [])
402                            ],
403                        ),
404                    )
405                ]
406            )
407
408        index = SearchIndex(
409            name=index_name,
410            fields=fields,
411            vector_search=vector_search,
412            semantic_search=semantic_search,
413        )
414
415        self._index_client.create_index(index)
416        logger.info(
417            "Azure AI Search index provisioned",
418            index=index_name,
419            dim=self.embedding_dimension,
420            metric=self.metric,
421            hnsw_m=self.hnsw_m,
422            ef_construction=self.hnsw_ef_construction,
423            ef_search=self.hnsw_ef_search,
424            fields=len(fields),
425        )
426
427    @staticmethod
428    def _map_python_type(python_type) -> Optional[SearchFieldDataType]:
429        """Map a Python / dataclass field type to an Azure Search field type."""
430        _map = {
431            str: SearchFieldDataType.String,
432            int: SearchFieldDataType.Int64,
433            float: SearchFieldDataType.Double,
434            bool: SearchFieldDataType.Boolean,
435            datetime: SearchFieldDataType.DateTimeOffset,
436        }
437
438        # Handle Optional[X]  →  extract X
439        if hasattr(python_type, "__origin__"):
440            args = getattr(python_type, "__args__", ())
441            for arg in args:
442                if arg is type(None):
443                    continue
444                return _map.get(arg, SearchFieldDataType.String)
445
446        # Handle string annotations
447        if isinstance(python_type, str):
448            s = python_type.lower()
449            if "datetime" in s:
450                return SearchFieldDataType.DateTimeOffset
451            if "int" in s:
452                return SearchFieldDataType.Int64
453            if "float" in s or "double" in s:
454                return SearchFieldDataType.Double
455            if "bool" in s:
456                return SearchFieldDataType.Boolean
457            return SearchFieldDataType.String
458
459        return _map.get(python_type, SearchFieldDataType.String)

Builds and manages Azure AI Search indexes with full developer control.

The builder owns schema concerns only. Document operations (add, search, delete) belong to AzureAISearchVectorStore.

Parameters

endpoint: Azure AI Search service endpoint URL.
api_key: Azure AI Search admin API key. Use for local development or when managed identity is not available.
token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. At least one of api_key or token_provider must be supplied (otherwise ValueError is raised); if both are given, token_provider takes precedence. The callable must request the Azure AI Search scope::

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://search.azure.com/.default"
    )

Note: this scope is different from Azure OpenAI / Cognitive Services
(``https://cognitiveservices.azure.com/.default``) — each service
requires its own token_provider.

index_name: Name of the index to create / manage.
embedding_dimension: Number of dimensions in the embedding vectors. Must match the embedding model (e.g. 1536 for text-embedding-ada-002, 3072 for text-embedding-3-large).
document_type: Optional Document subclass. When provided, all dataclass fields not in the base Document are automatically added as indexed fields (filterable, sortable, facetable where appropriate).
hnsw_m: Number of bi-directional links created per node. Higher = better recall but more memory. Azure AI Search accepts values from 4 to 10. Default: 4.
hnsw_ef_construction: Size of the candidate list during index construction. Higher = better recall, slower build time. Azure AI Search accepts values from 100 to 1000. Default: 400.
hnsw_ef_search: Size of the candidate list during search. Higher = better recall, slower queries. Azure AI Search accepts values from 100 to 1000. Default: 500.
metric: Similarity metric. One of "cosine", "euclidean", "dotProduct". Default: "cosine".
ssl_cert_path: Optional path to a PEM certificate bundle for corporate SSL inspection proxies. Sets the REQUESTS_CA_BUNDLE and SSL_CERT_FILE environment variables (only if they are not already set) before building the client. Because these are process-wide environment variables, the setting also affects other HTTP clients in the same process.
semantic_config: Optional semantic search configuration. When provided, the index is provisioned with a SemanticSearch configuration that enables Azure AI semantic reranking (BoostedRerankerScore).

Expected keys:

- ``name`` (str) — semantic config name (default
  ``"default-semantic-config"``)
- ``title_field`` (str, optional) — field used as the document title
- ``content_fields`` (list[str]) — primary body content fields
- ``keyword_fields`` (list[str], optional) — keyword/facet fields

Example::

    {
        "name": "policyhub-semantic-config",
        "title_field": "document_name",
        "content_fields": ["content"],
        "keyword_fields": ["language", "locale", "source"],
    }
AzureAISearchIndexBuilder( endpoint: str, index_name: str, api_key: Optional[str] = None, token_provider: Optional[Callable[[], str]] = None, embedding_dimension: int = 1536, document_type: Type[Document] = <class 'Document'>, hnsw_m: int = 4, hnsw_ef_construction: int = 400, hnsw_ef_search: int = 500, metric: str = 'cosine', ssl_cert_path: Optional[str] = None, semantic_config: Optional[dict] = None)
169    def __init__(
170        self,
171        endpoint: str,
172        index_name: str,
173        api_key: Optional[str] = None,
174        token_provider: Optional[Callable[[], str]] = None,
175        embedding_dimension: int = 1536,
176        document_type: Type[Document] = Document,
177        hnsw_m: int = 4,
178        hnsw_ef_construction: int = 400,
179        hnsw_ef_search: int = 500,
180        metric: str = "cosine",
181        ssl_cert_path: Optional[str] = None,
182        semantic_config: Optional[dict] = None,
183    ) -> None:
184        self.index_name = index_name
185        self.embedding_dimension = embedding_dimension
186        self.document_type = document_type
187        self.hnsw_m = hnsw_m
188        self.hnsw_ef_construction = hnsw_ef_construction
189        self.hnsw_ef_search = hnsw_ef_search
190        self.metric = metric
191        self.semantic_config = semantic_config
192
193        if not api_key and not token_provider:
194            raise ValueError(
195                "Either api_key or token_provider must be supplied to AzureAISearchIndexBuilder."
196            )
197
198        if ssl_cert_path:
199            import os as _os
200            _os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path)
201            _os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path)
202
203        if token_provider:
204            credential = _TokenProviderCredential(token_provider)
205        else:
206            credential = AzureKeyCredential(api_key)
207        self._index_client = SearchIndexClient(
208            endpoint=endpoint,
209            credential=credential,
210        )
index_name
embedding_dimension
document_type
hnsw_m
hnsw_ef_construction
hnsw_ef_search
metric
semantic_config
def create_index(self) -> None:
216    def create_index(self) -> None:
217        """Create the index if it does not already exist (idempotent)."""
218        if self.index_exists():
219            logger.info("Index already exists — skipping creation", index=self.index_name)
220            return
221        self._create(self.index_name)
222        logger.info("Index created successfully", index=self.index_name)

Create the index if it does not already exist (idempotent).

def create_or_replace_index(self) -> None:
224    def create_or_replace_index(self) -> None:
225        """Delete the existing index (if any) then create it fresh.
226
227        Warning: All documents are permanently lost.
228        """
229        if self.index_exists():
230            self._index_client.delete_index(self.index_name)
231            logger.info("Index deleted for replacement", index=self.index_name)
232        self._create(self.index_name)
233        logger.info("Index created (replaced)", index=self.index_name)

Delete the existing index (if any) then create it fresh.

Warning: All documents are permanently lost.

def delete_index(self) -> None:
235    def delete_index(self) -> None:
236        """Permanently delete the index and all its documents.
237
238        Raises:
239            RuntimeError: If the index does not exist.
240        """
241        if not self.index_exists():
242            raise RuntimeError(
243                f"Cannot delete index '{self.index_name}': it does not exist."
244            )
245        self._index_client.delete_index(self.index_name)
246        logger.info("Index deleted", index=self.index_name)

Permanently delete the index and all its documents.

Raises: RuntimeError: If the index does not exist.

def index_exists(self) -> bool:
248    def index_exists(self) -> bool:
249        """Return True if the index currently exists."""
250        try:
251            self._index_client.get_index(self.index_name)
252            return True
253        except ResourceNotFoundError:
254            return False
255        except Exception:
256            # Treat any other error as non-existence to keep callers safe
257            return False

Return True if the index currently exists.

def list_indexes(self) -> List[str]:
259    def list_indexes(self) -> List[str]:
260        """Return the names of all indexes on this Azure AI Search service."""
261        return [idx.name for idx in self._index_client.list_indexes()]

Return the names of all indexes on this Azure AI Search service.

class CosmosDBIndexBuilder(gmf_forge_ai_data.BaseIndexBuilder):
 67class CosmosDBIndexBuilder(BaseIndexBuilder):
 68    """
 69    Builds and manages Cosmos DB databases and containers for vector search.
 70
 71    The builder owns *schema / provisioning* concerns only.  Document
 72    operations (add, search, delete) belong to ``AzureCosmosDBVectorStore``.
 73
 74    Parameters
 75    ----------
 76    endpoint:
 77        Cosmos DB account endpoint URL.
 78    api_key:
 79        Cosmos DB account primary or secondary key.
 80    database_name:
 81        Name of the Cosmos DB database to create / manage.
 82    container_name:
 83        Name of the container to create / manage.
 84    embedding_dimension:
 85        Number of dimensions in the embedding vectors.
 86    distance_function:
 87        Vector similarity function.  ``"cosine"`` (default), ``"euclidean"``,
 88        or ``"dotproduct"``.
 89    vector_index_type:
 90        Index structure for vector search.  ``"quantizedFlat"`` (default,
 91        lower memory) or ``"diskANN"`` (higher recall on large datasets).
 92    partition_key:
 93        Cosmos DB partition key path.  Default: ``"/id"``.
 94    throughput:
 95        Manual RU/s throughput for the container.  ``None`` uses the Cosmos
 96        DB account default.  Ignored if the container already exists.
 97    ssl_cert_path:
 98        Optional path to a PEM certificate bundle for corporate SSL
 99        inspection proxies.
100    """
101
102    def __init__(
103        self,
104        endpoint: str,
105        api_key: str,
106        database_name: str,
107        container_name: str,
108        embedding_dimension: int = 1536,
109        distance_function: DistanceFunction = "cosine",
110        vector_index_type: VectorIndexType = "quantizedFlat",
111        partition_key: str = "/id",
112        throughput: Optional[int] = None,
113        ssl_cert_path: Optional[str] = None,
114    ) -> None:
115        self.database_name = database_name
116        self.container_name = container_name
117        self.embedding_dimension = embedding_dimension
118        self.distance_function = distance_function
119        self.vector_index_type = vector_index_type
120        self.partition_key = partition_key
121        self.throughput = throughput
122        self._ssl_cert_path = ssl_cert_path
123        self._endpoint = endpoint
124        self._api_key = api_key
125
126        self._client = self._build_client(endpoint, api_key, ssl_cert_path)
127
128    # ------------------------------------------------------------------ #
129    # BaseIndexBuilder interface                                           #
130    # ------------------------------------------------------------------ #
131
132    def create_index(self) -> None:
133        """Create the Cosmos DB database and container if they don't exist.
134
135        Safe to call multiple times — no-op if both already exist.
136        """
137        self._ensure_database()
138        if self.index_exists():
139            logger.info(
140                "Cosmos DB container already exists — skipping creation",
141                database=self.database_name,
142                container=self.container_name,
143            )
144            return
145        self._create_container()
146        logger.info(
147            "Cosmos DB container created",
148            database=self.database_name,
149            container=self.container_name,
150            dim=self.embedding_dimension,
151            distance_function=self.distance_function,
152            vector_index_type=self.vector_index_type,
153        )
154
155    def create_or_replace_index(self) -> None:
156        """Delete the container if it exists then create it fresh.
157
158        Warning: All documents are permanently lost.
159        """
160        self._ensure_database()
161        if self.index_exists():
162            db = self._client.get_database_client(self.database_name)
163            db.delete_container(self.container_name)
164            logger.info(
165                "Cosmos DB container deleted for replacement",
166                database=self.database_name,
167                container=self.container_name,
168            )
169        self._create_container()
170        logger.info(
171            "Cosmos DB container created (replaced)",
172            database=self.database_name,
173            container=self.container_name,
174        )
175
176    def delete_index(self) -> None:
177        """Permanently delete the container and all its documents.
178
179        Raises:
180            RuntimeError: If the container does not exist.
181        """
182        if not self.index_exists():
183            raise RuntimeError(
184                f"Cannot delete container '{self.database_name}/{self.container_name}': "
185                "it does not exist."
186            )
187        db = self._client.get_database_client(self.database_name)
188        db.delete_container(self.container_name)
189        logger.info(
190            "Cosmos DB container deleted",
191            database=self.database_name,
192            container=self.container_name,
193        )
194
195    def index_exists(self) -> bool:
196        """Return True if the container currently exists."""
197        try:
198            db = self._client.get_database_client(self.database_name)
199            db.get_container_client(self.container_name).read()
200            return True
201        except Exception:
202            return False
203
204    def list_indexes(self) -> List[str]:
205        """Return the names of all containers in the database."""
206        try:
207            db = self._client.get_database_client(self.database_name)
208            return [c["id"] for c in db.list_containers()]
209        except Exception:
210            return []
211
212    # ------------------------------------------------------------------ #
213    # Internal helpers                                                     #
214    # ------------------------------------------------------------------ #
215
216    @staticmethod
217    def _build_client(endpoint: str, api_key: str, ssl_cert_path: Optional[str]):
218        """Build a CosmosClient, optionally with a corporate SSL bundle."""
219        from azure.cosmos import CosmosClient
220        kwargs = {"url": endpoint, "credential": api_key}
221        if ssl_cert_path:
222            import ssl, os
223            os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path)
224            os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path)
225        return CosmosClient(**kwargs)
226
227    def _ensure_database(self) -> None:
228        """Create the database if it does not already exist."""
229        self._client.create_database_if_not_exists(self.database_name)
230
231    def _create_container(self) -> None:
232        """Create the container with vector embedding and indexing policies."""
233        from azure.cosmos import PartitionKey
234        from azure.cosmos.exceptions import CosmosHttpResponseError
235
236        db = self._client.get_database_client(self.database_name)
237
238        vector_embedding_policy = {
239            "vectorEmbeddings": [
240                {
241                    "path": "/embedding",
242                    "dataType": "float32",
243                    "distanceFunction": self.distance_function,
244                    "dimensions": self.embedding_dimension,
245                }
246            ]
247        }
248
249        indexing_policy = {
250            "includedPaths": [{"path": "/*"}],
251            "excludedPaths": [{"path": "/embedding/*"}],
252            "vectorIndexes": [
253                {"path": "/embedding", "type": self.vector_index_type}
254            ],
255        }
256
257        kwargs = dict(
258            id=self.container_name,
259            partition_key=PartitionKey(path=self.partition_key),
260            vector_embedding_policy=vector_embedding_policy,
261            indexing_policy=indexing_policy,
262        )
263        if self.throughput is not None:
264            kwargs["offer_throughput"] = self.throughput
265
266        try:
267            db.create_container(**kwargs)
268        except CosmosHttpResponseError as exc:
269            if "Vector Policy" in str(exc) or "capability" in str(exc):
270                raise RuntimeError(
271                    "Vector Search capability is not enabled on this Cosmos DB account. "
272                    "Enable via: az cosmosdb update --resource-group <RG> "
273                    "--name <ACCOUNT> --capabilities EnableNoSQLVectorSearch"
274                ) from exc
275            raise

Builds and manages Cosmos DB databases and containers for vector search.

The builder owns schema / provisioning concerns only. Document operations (add, search, delete) belong to AzureCosmosDBVectorStore.

Parameters

endpoint: Cosmos DB account endpoint URL.
api_key: Cosmos DB account primary or secondary key.
database_name: Name of the Cosmos DB database to create / manage.
container_name: Name of the container to create / manage.
embedding_dimension: Number of dimensions in the embedding vectors.
distance_function: Vector similarity function. "cosine" (default), "euclidean", or "dotproduct".
vector_index_type: Index structure for vector search. "quantizedFlat" (default; searches over compressed, quantized vectors) or "diskANN" (approximate nearest-neighbor index designed for low-latency search on large datasets).
partition_key: Cosmos DB partition key path. Default: "/id".
throughput: Manual RU/s throughput for the container. None uses the Cosmos DB account default. Ignored if the container already exists.
ssl_cert_path: Optional path to a PEM certificate bundle for corporate SSL inspection proxies. Sets the REQUESTS_CA_BUNDLE and SSL_CERT_FILE environment variables (only if they are not already set) before building the client.

CosmosDBIndexBuilder( endpoint: str, api_key: str, database_name: str, container_name: str, embedding_dimension: int = 1536, distance_function: Literal['cosine', 'euclidean', 'dotproduct'] = 'cosine', vector_index_type: Literal['quantizedFlat', 'diskANN'] = 'quantizedFlat', partition_key: str = '/id', throughput: Optional[int] = None, ssl_cert_path: Optional[str] = None)
102    def __init__(
103        self,
104        endpoint: str,
105        api_key: str,
106        database_name: str,
107        container_name: str,
108        embedding_dimension: int = 1536,
109        distance_function: DistanceFunction = "cosine",
110        vector_index_type: VectorIndexType = "quantizedFlat",
111        partition_key: str = "/id",
112        throughput: Optional[int] = None,
113        ssl_cert_path: Optional[str] = None,
114    ) -> None:
115        self.database_name = database_name
116        self.container_name = container_name
117        self.embedding_dimension = embedding_dimension
118        self.distance_function = distance_function
119        self.vector_index_type = vector_index_type
120        self.partition_key = partition_key
121        self.throughput = throughput
122        self._ssl_cert_path = ssl_cert_path
123        self._endpoint = endpoint
124        self._api_key = api_key
125
126        self._client = self._build_client(endpoint, api_key, ssl_cert_path)
database_name
container_name
embedding_dimension
distance_function
vector_index_type
partition_key
throughput
def create_index(self) -> None:
132    def create_index(self) -> None:
133        """Create the Cosmos DB database and container if they don't exist.
134
135        Safe to call multiple times — no-op if both already exist.
136        """
137        self._ensure_database()
138        if self.index_exists():
139            logger.info(
140                "Cosmos DB container already exists — skipping creation",
141                database=self.database_name,
142                container=self.container_name,
143            )
144            return
145        self._create_container()
146        logger.info(
147            "Cosmos DB container created",
148            database=self.database_name,
149            container=self.container_name,
150            dim=self.embedding_dimension,
151            distance_function=self.distance_function,
152            vector_index_type=self.vector_index_type,
153        )

Create the Cosmos DB database and container if they don't exist.

Safe to call multiple times — no-op if both already exist.

def create_or_replace_index(self) -> None:
155    def create_or_replace_index(self) -> None:
156        """Delete the container if it exists then create it fresh.
157
158        Warning: All documents are permanently lost.
159        """
160        self._ensure_database()
161        if self.index_exists():
162            db = self._client.get_database_client(self.database_name)
163            db.delete_container(self.container_name)
164            logger.info(
165                "Cosmos DB container deleted for replacement",
166                database=self.database_name,
167                container=self.container_name,
168            )
169        self._create_container()
170        logger.info(
171            "Cosmos DB container created (replaced)",
172            database=self.database_name,
173            container=self.container_name,
174        )

Delete the container if it exists then create it fresh.

Warning: All documents are permanently lost.

def delete_index(self) -> None:
176    def delete_index(self) -> None:
177        """Permanently delete the container and all its documents.
178
179        Raises:
180            RuntimeError: If the container does not exist.
181        """
182        if not self.index_exists():
183            raise RuntimeError(
184                f"Cannot delete container '{self.database_name}/{self.container_name}': "
185                "it does not exist."
186            )
187        db = self._client.get_database_client(self.database_name)
188        db.delete_container(self.container_name)
189        logger.info(
190            "Cosmos DB container deleted",
191            database=self.database_name,
192            container=self.container_name,
193        )

Permanently delete the container and all its documents.

Raises: RuntimeError: If the container does not exist.

def index_exists(self) -> bool:
195    def index_exists(self) -> bool:
196        """Return True if the container currently exists."""
197        try:
198            db = self._client.get_database_client(self.database_name)
199            db.get_container_client(self.container_name).read()
200            return True
201        except Exception:
202            return False

Return True if the container currently exists.

def list_indexes(self) -> List[str]:
204    def list_indexes(self) -> List[str]:
205        """Return the names of all containers in the database."""
206        try:
207            db = self._client.get_database_client(self.database_name)
208            return [c["id"] for c in db.list_containers()]
209        except Exception:
210            return []

Return the names of all containers in the database.

class MongoDBIndexBuilder(gmf_forge_ai_data.BaseIndexBuilder):
 63class MongoDBIndexBuilder(BaseIndexBuilder):
 64    """
 65    Builds and manages Atlas Vector Search and text indexes for a MongoDB
 66    collection.
 67
 68    The builder owns *schema / provisioning* concerns only.  Document
 69    operations belong to ``MongoDBVectorStore``.
 70
 71    Parameters
 72    ----------
 73    connection_string:
 74        MongoDB Atlas connection string, e.g.
 75        ``"mongodb+srv://user:pass@cluster.mongodb.net/"``.
 76    database_name:
 77        Name of the MongoDB database.
 78    collection_name:
 79        Name of the collection to index.
 80    embedding_dimension:
 81        Number of dimensions in the embedding vectors.
 82    document_type:
 83        Document dataclass whose *extra* fields will be added as Atlas
 84        filter fields (fields other than ``id``, ``content``, ``embedding``,
 85        ``timestamp``, and ``metadata``).
 86    vector_index_name:
 87        Name of the Atlas Vector Search index.  Must match the
 88        ``vector_index_name`` used when constructing ``MongoDBVectorStore``.
 89        Default: ``"vector_index"``.
 90    similarity:
 91        Similarity metric for the vector index.  ``"cosine"`` (default),
 92        ``"euclidean"``, or ``"dotProduct"``.
 93    extra_filter_paths:
 94        Additional document paths to register as Atlas filter fields beyond
 95        those inferred from *document_type*.  Useful for arbitrary metadata
 96        fields stored in the ``metadata`` sub-document.
 97    text_index_fields:
 98        Fields to include in the MongoDB ``$text`` full-text index.
 99        Default: ``["content"]``.
100    ssl_cert_path:
101        Path to a CA certificate bundle (PEM) for TLS verification in
102        corporate environments with custom certificate authorities.
103    """
104
105    _BASE_KEYS = frozenset({"id", "content", "embedding", "timestamp", "metadata"})
106
107    def __init__(
108        self,
109        connection_string: str,
110        database_name: str,
111        collection_name: str,
112        embedding_dimension: int = 1536,
113        document_type: Type[Document] = Document,
114        vector_index_name: str = "vector_index",
115        similarity: Similarity = "cosine",
116        extra_filter_paths: Optional[List[str]] = None,
117        text_index_fields: Optional[List[str]] = None,
118        ssl_cert_path: Optional[str] = None,
119    ) -> None:
120        try:
121            import pymongo  # noqa: F401
122        except ImportError as exc:
123            raise ImportError(
124                "pymongo is required for MongoDBIndexBuilder. "
125                "Install it with:  pip install pymongo"
126            ) from exc
127
128        import pymongo
129
130        self.database_name = database_name
131        self.collection_name = collection_name
132        self.embedding_dimension = embedding_dimension
133        self.document_type = document_type
134        self.vector_index_name = vector_index_name
135        self.similarity = similarity
136        self.extra_filter_paths: List[str] = extra_filter_paths or []
137        self.text_index_fields: List[str] = text_index_fields or ["content"]
138
139        client_kwargs: Dict[str, Any] = {}
140        if ssl_cert_path:
141            client_kwargs["tlsCAFile"] = ssl_cert_path
142
143        self._client = pymongo.MongoClient(connection_string, **client_kwargs)
144        self._db = self._client[database_name]
145        self._collection = self._db[collection_name]
146
147    # ------------------------------------------------------------------ #
148    # BaseIndexBuilder interface                                           #
149    # ------------------------------------------------------------------ #
150
151    def create_index(self) -> None:
152        """Create the Atlas Vector Search index and the text index if they
153        don't already exist.
154
155        Safe to call multiple times — each component is idempotent.
156        """
157        self._create_vector_index(replace=False)
158        self._create_text_index()
159
160    def create_or_replace_index(self) -> None:
161        """Drop the Atlas Vector Search index if it exists, then create it
162        fresh alongside the text index.
163
164        The text index is not recreated if it already exists (text indexes
165        are schema-agnostic and need no replacement).
166
167        Warning: Existing vector index data is lost.
168        """
169        import time
170
171        if self.index_exists():
172            self._collection.drop_search_index(self.vector_index_name)
173            logger.info(
174                "Atlas Vector Search index dropped for replacement",
175                index=self.vector_index_name,
176            )
177            # Atlas drops are asynchronous — poll until the index is gone
178            # before submitting the creation request with the same name.
179            deadline = time.monotonic() + 60
180            while time.monotonic() < deadline:
181                if self.vector_index_name not in [
182                    idx["name"] for idx in self._collection.list_search_indexes()
183                ]:
184                    break
185                time.sleep(2)
186            else:
187                raise RuntimeError(
188                    f"Timed out waiting for Atlas to finish dropping index "
189                    f"'{self.vector_index_name}'. Try again in a moment."
190                )
191        self._create_vector_index(replace=True)
192        self._create_text_index()
193
194    def delete_index(self) -> None:
195        """Drop the Atlas Vector Search index.
196
197        The MongoDB text index (``content_text``) is left in place because
198        it is independent of vector dimensionality.
199
200        Raises:
201            RuntimeError: If the vector index does not exist.
202        """
203        if not self.index_exists():
204            raise RuntimeError(
205                f"Cannot delete vector index '{self.vector_index_name}' on "
206                f"'{self.database_name}.{self.collection_name}': it does not exist."
207            )
208        self._collection.drop_search_index(self.vector_index_name)
209        logger.info(
210            "Atlas Vector Search index deleted",
211            index=self.vector_index_name,
212            database=self.database_name,
213            collection=self.collection_name,
214        )
215
216    def index_exists(self) -> bool:
217        """Return True if the Atlas Vector Search index currently exists."""
218        existing = [idx["name"] for idx in self._collection.list_search_indexes()]
219        return self.vector_index_name in existing
220
221    def list_indexes(self) -> List[str]:
222        """Return the names of all Atlas Vector Search indexes on the collection."""
223        return [idx["name"] for idx in self._collection.list_search_indexes()]
224
225    # ------------------------------------------------------------------ #
226    # Additional helpers                                                   #
227    # ------------------------------------------------------------------ #
228
229    def list_text_indexes(self) -> List[str]:
230        """Return the names of all standard MongoDB indexes on the collection."""
231        return [idx["name"] for idx in self._collection.list_indexes()]
232
233    # ------------------------------------------------------------------ #
234    # Internal helpers                                                     #
235    # ------------------------------------------------------------------ #
236
237    def _build_filter_fields(self) -> List[Dict[str, str]]:
238        """Build the list of Atlas filter field definitions.
239
240        Includes custom fields inferred from *document_type* plus any
241        *extra_filter_paths* provided at construction.
242        """
243        paths = set(self.extra_filter_paths)
244
245        if dataclasses.is_dataclass(self.document_type):
246            for field in dataclasses.fields(self.document_type):
247                if field.name not in self._BASE_KEYS:
248                    paths.add(field.name)
249
250        return [{"type": "filter", "path": p} for p in sorted(paths)]
251
252    def _create_vector_index(self, replace: bool = False) -> None:
253        """Submit the Atlas Vector Search index creation request."""
254        if not replace and self.index_exists():
255            logger.info(
256                "Atlas Vector Search index already exists — skipping",
257                index=self.vector_index_name,
258                database=self.database_name,
259                collection=self.collection_name,
260            )
261            return
262
263        filter_fields = self._build_filter_fields()
264
265        index_spec: Dict[str, Any] = {
266            "name": self.vector_index_name,
267            "type": "vectorSearch",
268            "definition": {
269                "fields": [
270                    {
271                        "type": "vector",
272                        "path": "embedding",
273                        "numDimensions": self.embedding_dimension,
274                        "similarity": self.similarity,
275                    },
276                    *filter_fields,
277                ]
278            },
279        }
280
281        self._collection.create_search_index(index_spec)
282        logger.info(
283            "Atlas Vector Search index created",
284            index=self.vector_index_name,
285            database=self.database_name,
286            collection=self.collection_name,
287            dim=self.embedding_dimension,
288            similarity=self.similarity,
289            filter_fields=len(filter_fields),
290        )
291
292    def _create_text_index(self) -> None:
293        """Create a MongoDB ``$text`` index if it does not already exist."""
294        existing = [idx["name"] for idx in self._collection.list_indexes()]
295        if "content_text" in existing:
296            logger.info(
297                "Text index already exists — skipping",
298                index="content_text",
299                database=self.database_name,
300                collection=self.collection_name,
301            )
302            return
303
304        keys = [(field, "text") for field in self.text_index_fields]
305        self._collection.create_index(keys, name="content_text")
306        logger.info(
307            "Text index created",
308            index="content_text",
309            database=self.database_name,
310            collection=self.collection_name,
311            fields=self.text_index_fields,
312        )

Builds and manages Atlas Vector Search and text indexes for a MongoDB collection.

The builder owns schema / provisioning concerns only. Document operations belong to MongoDBVectorStore.

Parameters

  • connection_string: MongoDB Atlas connection string, e.g. "mongodb+srv://user:pass@cluster.mongodb.net/".
  • database_name: Name of the MongoDB database.
  • collection_name: Name of the collection to index.
  • embedding_dimension: Number of dimensions in the embedding vectors. Default: 1536.
  • document_type: Document dataclass whose extra fields will be added as Atlas filter fields (fields other than id, content, embedding, timestamp, and metadata).
  • vector_index_name: Name of the Atlas Vector Search index. Must match the vector_index_name used when constructing MongoDBVectorStore. Default: "vector_index".
  • similarity: Similarity metric for the vector index: "cosine" (default), "euclidean", or "dotProduct".
  • extra_filter_paths: Additional document paths to register as Atlas filter fields beyond those inferred from document_type. Useful for arbitrary metadata fields stored in the metadata sub-document.
  • text_index_fields: Fields to include in the MongoDB $text full-text index. Default: ["content"].
  • ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS verification in corporate environments with custom certificate authorities.

MongoDBIndexBuilder( connection_string: str, database_name: str, collection_name: str, embedding_dimension: int = 1536, document_type: Type[Document] = <class 'Document'>, vector_index_name: str = 'vector_index', similarity: Literal['cosine', 'euclidean', 'dotProduct'] = 'cosine', extra_filter_paths: Optional[List[str]] = None, text_index_fields: Optional[List[str]] = None, ssl_cert_path: Optional[str] = None)
107    def __init__(
108        self,
109        connection_string: str,
110        database_name: str,
111        collection_name: str,
112        embedding_dimension: int = 1536,
113        document_type: Type[Document] = Document,
114        vector_index_name: str = "vector_index",
115        similarity: Similarity = "cosine",
116        extra_filter_paths: Optional[List[str]] = None,
117        text_index_fields: Optional[List[str]] = None,
118        ssl_cert_path: Optional[str] = None,
119    ) -> None:
120        try:
121            import pymongo  # noqa: F401
122        except ImportError as exc:
123            raise ImportError(
124                "pymongo is required for MongoDBIndexBuilder. "
125                "Install it with:  pip install pymongo"
126            ) from exc
127
128        import pymongo
129
130        self.database_name = database_name
131        self.collection_name = collection_name
132        self.embedding_dimension = embedding_dimension
133        self.document_type = document_type
134        self.vector_index_name = vector_index_name
135        self.similarity = similarity
136        self.extra_filter_paths: List[str] = extra_filter_paths or []
137        self.text_index_fields: List[str] = text_index_fields or ["content"]
138
139        client_kwargs: Dict[str, Any] = {}
140        if ssl_cert_path:
141            client_kwargs["tlsCAFile"] = ssl_cert_path
142
143        self._client = pymongo.MongoClient(connection_string, **client_kwargs)
144        self._db = self._client[database_name]
145        self._collection = self._db[collection_name]
database_name
collection_name
embedding_dimension
document_type
vector_index_name
similarity
extra_filter_paths: List[str]
text_index_fields: List[str]
def create_index(self) -> None:
151    def create_index(self) -> None:
152        """Create the Atlas Vector Search index and the text index if they
153        don't already exist.
154
155        Safe to call multiple times — each component is idempotent.
156        """
157        self._create_vector_index(replace=False)
158        self._create_text_index()

Create the Atlas Vector Search index and the text index if they don't already exist.

Safe to call multiple times — each component is idempotent.

def create_or_replace_index(self) -> None:
160    def create_or_replace_index(self) -> None:
161        """Drop the Atlas Vector Search index if it exists, then create it
162        fresh alongside the text index.
163
164        The text index is not recreated if it already exists (text indexes
165        are schema-agnostic and need no replacement).
166
167        Warning: Existing vector index data is lost.
168        """
169        import time
170
171        if self.index_exists():
172            self._collection.drop_search_index(self.vector_index_name)
173            logger.info(
174                "Atlas Vector Search index dropped for replacement",
175                index=self.vector_index_name,
176            )
177            # Atlas drops are asynchronous — poll until the index is gone
178            # before submitting the creation request with the same name.
179            deadline = time.monotonic() + 60
180            while time.monotonic() < deadline:
181                if self.vector_index_name not in [
182                    idx["name"] for idx in self._collection.list_search_indexes()
183                ]:
184                    break
185                time.sleep(2)
186            else:
187                raise RuntimeError(
188                    f"Timed out waiting for Atlas to finish dropping index "
189                    f"'{self.vector_index_name}'. Try again in a moment."
190                )
191        self._create_vector_index(replace=True)
192        self._create_text_index()

Drop the Atlas Vector Search index if it exists, then create it fresh alongside the text index.

The text index is not recreated if it already exists (text indexes are schema-agnostic and need no replacement).

Warning: Existing vector index data is lost.

def delete_index(self) -> None:
194    def delete_index(self) -> None:
195        """Drop the Atlas Vector Search index.
196
197        The MongoDB text index (``content_text``) is left in place because
198        it is independent of vector dimensionality.
199
200        Raises:
201            RuntimeError: If the vector index does not exist.
202        """
203        if not self.index_exists():
204            raise RuntimeError(
205                f"Cannot delete vector index '{self.vector_index_name}' on "
206                f"'{self.database_name}.{self.collection_name}': it does not exist."
207            )
208        self._collection.drop_search_index(self.vector_index_name)
209        logger.info(
210            "Atlas Vector Search index deleted",
211            index=self.vector_index_name,
212            database=self.database_name,
213            collection=self.collection_name,
214        )

Drop the Atlas Vector Search index.

The MongoDB text index (content_text) is left in place because it is independent of vector dimensionality.

Raises: RuntimeError: If the vector index does not exist.

def index_exists(self) -> bool:
216    def index_exists(self) -> bool:
217        """Return True if the Atlas Vector Search index currently exists."""
218        existing = [idx["name"] for idx in self._collection.list_search_indexes()]
219        return self.vector_index_name in existing

Return True if the Atlas Vector Search index currently exists.

def list_indexes(self) -> List[str]:
221    def list_indexes(self) -> List[str]:
222        """Return the names of all Atlas Vector Search indexes on the collection."""
223        return [idx["name"] for idx in self._collection.list_search_indexes()]

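Return the names of all Atlas Search indexes on the collection, as reported by list_search_indexes(). This can include full-text Atlas Search indexes as well as vector search indexes.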

def list_text_indexes(self) -> List[str]:
229    def list_text_indexes(self) -> List[str]:
230        """Return the names of all standard MongoDB indexes on the collection."""
231        return [idx["name"] for idx in self._collection.list_indexes()]

Return the names of all standard MongoDB indexes on the collection.

class RelevanceFilter:
20class RelevanceFilter:
21    """
22    Filters retrieved documents by minimum similarity score.
23
24    Vector stores return results ordered by score but never automatically
25    remove low-confidence matches. The filter enforces a quality floor so
26    that weakly related documents are excluded before they reach the LLM.
27
28    The threshold depends on the embedding model and index:
29    - text-embedding-ada-002 / cosine: 0.75–0.85 is typical
30    - Lower threshold (0.6): broad recall, more noise
31    - Higher threshold (0.9): tight precision, fewer results
32
33    Example:
34        ```python
35        from gmf_forge_ai_data.context import RelevanceFilter
36
37        filter_ = RelevanceFilter(min_score=0.80)
38        kept = filter_.filter(retrieved_results)
39        # Only results with score >= 0.80 are returned
40        ```
41    """
42
43    def __init__(self, min_score: float = 0.75):
44        """
45        Args:
46            min_score: Minimum similarity score in [0, 1] to retain a result.
47                       Results with score < min_score are discarded.
48        """
49        if not 0.0 <= min_score <= 1.0:
50            raise ValueError(f"min_score must be in [0, 1], got {min_score}")
51        self.min_score = min_score
52
53    def filter(self, results: List[SearchResult]) -> List[SearchResult]:
54        """
55        Remove results below the minimum score threshold.
56
57        Rank values are NOT reassigned — they preserve the original retrieval
58        rank for traceability. Use a downstream component or sort afterward
59        if contiguous ranks are required.
60
61        Args:
62            results: Retrieved SearchResult list, typically sorted by score.
63
64        Returns:
65            Subset of results where score >= min_score, in original order.
66        """
67        return [r for r in results if r.score >= self.min_score]

Filters retrieved documents by minimum similarity score.

Vector stores return results ordered by score but never automatically remove low-confidence matches. The filter enforces a quality floor so that weakly related documents are excluded before they reach the LLM.

The threshold depends on the embedding model and index:

  • text-embedding-ada-002 / cosine: 0.75–0.85 is typical
  • Lower threshold (0.6): broad recall, more noise
  • Higher threshold (0.9): tight precision, fewer results

Example:

from gmf_forge_ai_data.context import RelevanceFilter

filter_ = RelevanceFilter(min_score=0.80)
kept = filter_.filter(retrieved_results)
# Only results with score >= 0.80 are returned
RelevanceFilter(min_score: float = 0.75)
43    def __init__(self, min_score: float = 0.75):
44        """
45        Args:
46            min_score: Minimum similarity score in [0, 1] to retain a result.
47                       Results with score < min_score are discarded.
48        """
49        if not 0.0 <= min_score <= 1.0:
50            raise ValueError(f"min_score must be in [0, 1], got {min_score}")
51        self.min_score = min_score

Args: min_score: Minimum similarity score in [0, 1] to retain a result. Results with score < min_score are discarded.

min_score
def filter( self, results: List[SearchResult]) -> List[SearchResult]:
53    def filter(self, results: List[SearchResult]) -> List[SearchResult]:
54        """
55        Remove results below the minimum score threshold.
56
57        Rank values are NOT reassigned — they preserve the original retrieval
58        rank for traceability. Use a downstream component or sort afterward
59        if contiguous ranks are required.
60
61        Args:
62            results: Retrieved SearchResult list, typically sorted by score.
63
64        Returns:
65            Subset of results where score >= min_score, in original order.
66        """
67        return [r for r in results if r.score >= self.min_score]

Remove results below the minimum score threshold.

Rank values are NOT reassigned — they preserve the original retrieval rank for traceability. If contiguous ranks are required, renumber them in a downstream step; sorting alone does not close the gaps left by removed results.

Args: results: Retrieved SearchResult list, typically sorted by score.

Returns: Subset of results where score >= min_score, in original order.

class ContextDeduplicator:
 20class ContextDeduplicator:
 21    """
 22    Removes near-duplicate passages from a retrieved result list.
 23
 24    Two documents are considered near-duplicates if their character trigram
 25    Jaccard similarity exceeds `similarity_threshold`. The higher-ranked
 26    (lower `rank` value / earlier in list) document is always kept.
 27
 28    Near-duplicates arise from:
 29    - Overlapping chunk windows (sliding-window chunking creates 30–50% overlap)
 30    - The same passage indexed in multiple documents
 31    - Repeated LLM-generated content in a knowledge base
 32
 33    Example:
 34        ```python
 35        from gmf_forge_ai_data.context import ContextDeduplicator
 36
 37        deduper = ContextDeduplicator(similarity_threshold=0.85)
 38        unique = deduper.deduplicate(retrieved_results)
 39        # Near-identical passages are removed, keeping the higher-ranked one
 40        ```
 41    """
 42
 43    def __init__(self, similarity_threshold: float = 0.85, ngram_size: int = 3):
 44        """
 45        Args:
 46            similarity_threshold: Jaccard similarity above which two documents
 47                                  are treated as duplicates. Range [0, 1].
 48                                  0.85 catches heavy overlaps; lower values
 49                                  (0.6–0.7) catch paraphrase-level duplicates.
 50            ngram_size:           Character n-gram size for fingerprinting.
 51                                  3 (trigrams) is a good default.
 52        """
 53        if not 0.0 <= similarity_threshold <= 1.0:
 54            raise ValueError(
 55                f"similarity_threshold must be in [0, 1], got {similarity_threshold}"
 56            )
 57        self.similarity_threshold = similarity_threshold
 58        self.ngram_size = ngram_size
 59
 60    def deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
 61        """
 62        Remove near-duplicate passages, keeping the highest-ranked copy.
 63
 64        Processes results in order (index 0 = highest rank). For each result,
 65        checks if it is a near-duplicate of any already-kept result. If so,
 66        it is discarded; otherwise it is added to the output.
 67
 68        Args:
 69            results: List of SearchResult objects, ordered by relevance rank.
 70
 71        Returns:
 72            Deduplicated list in the same relative order.
 73        """
 74        kept: List[SearchResult] = []
 75        kept_fingerprints: List[Set[str]] = []
 76
 77        for result in results:
 78            fp = self._fingerprint(result.document.content)
 79            if not any(
 80                self._jaccard(fp, existing) >= self.similarity_threshold
 81                for existing in kept_fingerprints
 82            ):
 83                kept.append(result)
 84                kept_fingerprints.append(fp)
 85
 86        return kept
 87
 88    def _fingerprint(self, text: str) -> Set[str]:
 89        """Build a character n-gram set for the text."""
 90        n = self.ngram_size
 91        text = text.lower()
 92        if len(text) < n:
 93            return {text}
 94        return {text[i:i + n] for i in range(len(text) - n + 1)}
 95
 96    @staticmethod
 97    def _jaccard(a: Set[str], b: Set[str]) -> float:
 98        if not a or not b:
 99            return 0.0
100        return len(a & b) / len(a | b)

Removes near-duplicate passages from a retrieved result list.

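Two documents are considered near-duplicates if the Jaccard similarity of their character n-gram sets (trigrams by default; see ngram_size) is greater than or equal to similarity_threshold. Of each duplicate pair, the document that appears earlier in the input list is kept. When the list is ordered by rank, as retrieval results normally are, this is the higher-ranked document.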

Near-duplicates arise from:

  • Overlapping chunk windows (sliding-window chunking repeats the configured overlap, e.g. 30–50%, across adjacent chunks)
  • The same passage indexed in multiple documents
  • Repeated LLM-generated content in a knowledge base

Example:

from gmf_forge_ai_data.context import ContextDeduplicator

deduper = ContextDeduplicator(similarity_threshold=0.85)
unique = deduper.deduplicate(retrieved_results)
# Near-identical passages are removed, keeping the higher-ranked one
ContextDeduplicator(similarity_threshold: float = 0.85, ngram_size: int = 3)
43    def __init__(self, similarity_threshold: float = 0.85, ngram_size: int = 3):
44        """
45        Args:
46            similarity_threshold: Jaccard similarity above which two documents
47                                  are treated as duplicates. Range [0, 1].
48                                  0.85 catches heavy overlaps; lower values
49                                  (0.6–0.7) catch paraphrase-level duplicates.
50            ngram_size:           Character n-gram size for fingerprinting.
51                                  3 (trigrams) is a good default.
52        """
53        if not 0.0 <= similarity_threshold <= 1.0:
54            raise ValueError(
55                f"similarity_threshold must be in [0, 1], got {similarity_threshold}"
56            )
57        self.similarity_threshold = similarity_threshold
58        self.ngram_size = ngram_size

Args: similarity_threshold: Jaccard similarity at or above which two documents are treated as duplicates. Range [0, 1]. 0.85 catches heavy overlaps; lower values (0.6–0.7) also catch lightly edited or reworded duplicates that still share most of their character n-grams. ngram_size: Character n-gram size for fingerprinting. 3 (trigrams) is a good default.

similarity_threshold
ngram_size
def deduplicate( self, results: List[SearchResult]) -> List[SearchResult]:
60    def deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
61        """
62        Remove near-duplicate passages, keeping the highest-ranked copy.
63
64        Processes results in order (index 0 = highest rank). For each result,
65        checks if it is a near-duplicate of any already-kept result. If so,
66        it is discarded; otherwise it is added to the output.
67
68        Args:
69            results: List of SearchResult objects, ordered by relevance rank.
70
71        Returns:
72            Deduplicated list in the same relative order.
73        """
74        kept: List[SearchResult] = []
75        kept_fingerprints: List[Set[str]] = []
76
77        for result in results:
78            fp = self._fingerprint(result.document.content)
79            if not any(
80                self._jaccard(fp, existing) >= self.similarity_threshold
81                for existing in kept_fingerprints
82            ):
83                kept.append(result)
84                kept_fingerprints.append(fp)
85
86        return kept

Remove near-duplicate passages, keeping the highest-ranked copy.

Processes results in order (index 0 = highest rank). For each result, checks if it is a near-duplicate of any already-kept result. If so, it is discarded; otherwise it is added to the output.

Args: results: List of SearchResult objects, ordered by relevance rank.

Returns: Deduplicated list in the same relative order.

class ContextReranker:
 22class ContextReranker:
 23    """
 24    Reranks retrieved documents by LLM relevance scoring.
 25
 26    Sends the query and all retrieved document passages to the LLM and asks it
 27    to order them from most to least relevant. Returns a new list of
 28    SearchResult objects in the LLM-determined order, with scores reassigned
 29    to reflect the new ranking.
 30
 31    The reranker is most valuable when:
 32    - Vector similarity scores are clustered closely (hard to choose top-k)
 33    - The query requires reasoning rather than lexical overlap
 34    - Final context window is small and every slot matters
 35
 36    Example:
 37        ```python
 38        from gmf_forge_ai_data.context import ContextReranker
 39
 40        reranker = ContextReranker(llm_gateway)
 41        reranked = await reranker.rerank(
 42            query="What penalties apply for antitrust violations?",
 43            results=retrieved_results,
 44            top_k=3,
 45        )
 46        # reranked[0] is now the most relevant doc by LLM judgement
 47        ```
 48    """
 49
 50    _RERANK_PROMPT = (
 51        "You are a document relevance ranking assistant for a RAG pipeline.\n\n"
 52        "Query: {query}\n\n"
 53        "Below are {count} retrieved documents.\n"
 54        "Return ONLY a comma-separated list of document indices ordered from MOST "
 55        "to LEAST relevant to the query. Include every index exactly once.\n"
 56        "Example for 4 documents: 2, 0, 3, 1\n\n"
 57        "Documents:\n{documents}\n\n"
 58        "Ranked indices (most to least relevant):"
 59    )
 60
 61    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
 62        """
 63        Args:
 64            llm_gateway: LLM gateway for relevance assessment.
 65            temperature: Sampling temperature (default 0.0 for deterministic
 66                         ranking). Keep low — ranking should be consistent.
 67        """
 68        self.llm_gateway = llm_gateway
 69        self.temperature = temperature
 70
 71    async def rerank(
 72        self,
 73        query: str,
 74        results: List[SearchResult],
 75        top_k: Optional[int] = None,
 76    ) -> List[SearchResult]:
 77        """
 78        Rerank retrieved documents by LLM relevance to the query.
 79
 80        If the LLM returns an unparseable response, the original order is
 81        preserved so the pipeline does not break.
 82
 83        Args:
 84            query:  The original user query context for relevance judgement.
 85            results: List of SearchResult objects from retrieval.
 86            top_k:  If provided, return only the top-k results after reranking.
 87
 88        Returns:
 89            SearchResult list in new LLM-determined relevance order, with
 90            scores reassigned as 1.0, (n-1)/n, (n-2)/n, ... reflecting rank.
 91        """
 92        if not results:
 93            return results
 94
 95        documents_text = "\n\n".join(
 96            f"[{i}] {r.document.content}"
 97            for i, r in enumerate(results)
 98        )
 99
100        prompt = self._RERANK_PROMPT.format(
101            query=query,
102            count=len(results),
103            documents=documents_text,
104        )
105
106        response = await self.llm_gateway.complete(
107            prompt=prompt,
108            temperature=self.temperature,
109            max_tokens=100,
110        )
111
112        ranked_indices = self._parse_ranked_indices(response.content, len(results))
113
114        n = len(ranked_indices)
115        reranked: List[SearchResult] = []
116        for new_rank, idx in enumerate(ranked_indices):
117            original = results[idx]
118            score = round(1.0 - (new_rank / n), 4) if n > 1 else 1.0
119            reranked.append(SearchResult(
120                document=original.document,
121                score=score,
122                rank=new_rank,
123            ))
124
125        return reranked[:top_k] if top_k is not None else reranked
126
127    @staticmethod
128    def _parse_ranked_indices(text: str, count: int) -> List[int]:
129        """
130        Parse a comma-separated index list from LLM output.
131
132        Falls back to the original order [0, 1, 2, ...] if the response
133        cannot be parsed or contains out-of-range indices.
134        """
135        try:
136            parts = [p.strip() for p in text.strip().split(",")]
137            indices = [int(p) for p in parts if p.isdigit()]
138            # Validate: all indices in range, no duplicates
139            if sorted(indices) == list(range(count)):
140                return indices
141            # Partial list: include missing indices at the end
142            seen = set(indices)
143            for i in range(count):
144                if i not in seen:
145                    indices.append(i)
146            return indices
147        except (ValueError, AttributeError):
148            return list(range(count))

Reranks retrieved documents by LLM relevance scoring.

Sends the query and all retrieved document passages to the LLM and asks it to order them from most to least relevant. Returns a new list of SearchResult objects in the LLM-determined order, with scores reassigned to reflect the new ranking.

The reranker is most valuable when:

  • Vector similarity scores are clustered closely (hard to choose top-k)
  • The query requires reasoning rather than lexical overlap
  • Final context window is small and every slot matters

Example:

from gmf_forge_ai_data.context import ContextReranker

reranker = ContextReranker(llm_gateway)
reranked = await reranker.rerank(
    query="What penalties apply for antitrust violations?",
    results=retrieved_results,
    top_k=3,
)
# reranked[0] is now the most relevant doc by LLM judgement
ContextReranker( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.0)
61    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
62        """
63        Args:
64            llm_gateway: LLM gateway for relevance assessment.
65            temperature: Sampling temperature (default 0.0 for deterministic
66                         ranking). Keep low — ranking should be consistent.
67        """
68        self.llm_gateway = llm_gateway
69        self.temperature = temperature

Args: llm_gateway: LLM gateway for relevance assessment. temperature: Sampling temperature (default 0.0 for deterministic ranking). Keep low — ranking should be consistent.

llm_gateway
temperature
async def rerank( self, query: str, results: List[SearchResult], top_k: Optional[int] = None) -> List[SearchResult]:
 71    async def rerank(
 72        self,
 73        query: str,
 74        results: List[SearchResult],
 75        top_k: Optional[int] = None,
 76    ) -> List[SearchResult]:
 77        """
 78        Rerank retrieved documents by LLM relevance to the query.
 79
 80        If the LLM returns an unparseable response, the original order is
 81        preserved so the pipeline does not break.
 82
 83        Args:
 84            query:  The original user query context for relevance judgement.
 85            results: List of SearchResult objects from retrieval.
 86            top_k:  If provided, return only the top-k results after reranking.
 87
 88        Returns:
 89            SearchResult list in new LLM-determined relevance order, with
 90            scores reassigned as 1.0, (n-1)/n, (n-2)/n, ... reflecting rank.
 91        """
 92        if not results:
 93            return results
 94
 95        documents_text = "\n\n".join(
 96            f"[{i}] {r.document.content}"
 97            for i, r in enumerate(results)
 98        )
 99
100        prompt = self._RERANK_PROMPT.format(
101            query=query,
102            count=len(results),
103            documents=documents_text,
104        )
105
106        response = await self.llm_gateway.complete(
107            prompt=prompt,
108            temperature=self.temperature,
109            max_tokens=100,
110        )
111
112        ranked_indices = self._parse_ranked_indices(response.content, len(results))
113
114        n = len(ranked_indices)
115        reranked: List[SearchResult] = []
116        for new_rank, idx in enumerate(ranked_indices):
117            original = results[idx]
118            score = round(1.0 - (new_rank / n), 4) if n > 1 else 1.0
119            reranked.append(SearchResult(
120                document=original.document,
121                score=score,
122                rank=new_rank,
123            ))
124
125        return reranked[:top_k] if top_k is not None else reranked

Rerank retrieved documents by LLM relevance to the query.

If the LLM returns an unparseable response, the original order is preserved so the pipeline does not break.

Args: query: The original user query context for relevance judgement. results: List of SearchResult objects from retrieval. top_k: If provided, return only the top-k results after reranking.

Returns: SearchResult list in new LLM-determined relevance order, with scores reassigned as 1.0, (n-1)/n, (n-2)/n, ... reflecting rank.

class ContextCompressor:
 24class ContextCompressor:
 25    """
 26    Compresses retrieved passages to only the query-relevant sentences.
 27
 28    Calls the LLM once per retrieved document, asking it to extract only the
 29    sentences that directly help answer the query. The original SearchResult
 30    is preserved — only the `content` field is replaced with the compressed
 31    version. Score and rank are unchanged.
 32
 33    When to use:
 34    - Chunks are long (>500 tokens) and only partially relevant
 35    - Context window budget is tight
 36    - You want to reduce LLM distraction from off-topic content in chunks
 37
 38    Example:
 39        ```python
 40        from gmf_forge_ai_data.context import ContextCompressor
 41
 42        compressor = ContextCompressor(llm_gateway)
 43        compressed = await compressor.compress(
 44            query="What is self-attention in transformers?",
 45            results=reranked_results,
 46        )
 47        # Each result.document.content now contains only the relevant sentences
 48        ```
 49    """
 50
 51    _COMPRESS_PROMPT = (
 52        "You are a document compression assistant for a RAG pipeline.\n\n"
 53        "Extract ONLY the sentences from the passage below that are directly "
 54        "relevant to answering the query. Drop redundant, tangential, or clearly "
 55        "off-topic sentences.\n"
 56        "Return ONLY the extracted text — no explanation, no bullet points, no markdown.\n"
 57        "If the entire passage is already focused and relevant, return it unchanged.\n\n"
 58        "STRICT RULES:\n"
 59        "- Copy sentences EXACTLY as they appear in the passage.\n"
 60        "- Do NOT paraphrase, summarise, reword, or generate any new text.\n"
 61        "- Only output sentences that exist verbatim in the passage.\n\n"
 62        "Query: {query}\n\n"
 63        "Passage:\n{content}\n\n"
 64        "Relevant sentences from the passage (verbatim):"
 65    )
 66
 67    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
 68        """
 69        Args:
 70            llm_gateway: LLM gateway for compression.
 71            temperature: Sampling temperature (default 0.0 for deterministic
 72                         extraction). Keep low to avoid hallucination.
 73        """
 74        self.llm_gateway = llm_gateway
 75        self.temperature = temperature
 76
 77    async def compress(
 78        self,
 79        query: str,
 80        results: List[SearchResult],
 81        min_length: int = 20,
 82    ) -> List[SearchResult]:
 83        """
 84        Compress each retrieved passage to query-relevant content only.
 85
 86        Calls the LLM once per result. If the LLM returns content shorter
 87        than `min_length` characters (likely an error), the original passage
 88        is kept unchanged.
 89
 90        Args:
 91            query:      The user query to guide extraction.
 92            results:    List of SearchResult objects to compress.
 93            min_length: Minimum character length for a valid compressed result.
 94                        If compression produces less, the original is kept.
 95
 96        Returns:
 97            New list of SearchResult objects with compressed document content.
 98        """
 99        compressed: List[SearchResult] = []
100        for result in results:
101            compressed_content = await self._compress_one(
102                query, result.document.content, min_length
103            )
104            new_doc = copy.copy(result.document)
105            new_doc.content = compressed_content
106            compressed.append(SearchResult(
107                document=new_doc,
108                score=result.score,
109                rank=result.rank,
110            ))
111        return compressed
112
113    async def _compress_one(
114        self, query: str, content: str, min_length: int
115    ) -> str:
116        prompt = self._COMPRESS_PROMPT.format(query=query, content=content)
117        response = await self.llm_gateway.complete(
118            prompt=prompt,
119            temperature=self.temperature,
120            max_tokens=500,
121        )
122        extracted = response.content.strip().strip('"').strip("'")
123        return extracted if len(extracted) >= min_length else content

Compresses retrieved passages to only the query-relevant sentences.

Calls the LLM once per retrieved document, asking it to extract only the sentences that directly help answer the query. The input SearchResult objects are not modified: each output result wraps a shallow copy of the original document in which only the content field is replaced with the compressed version. Score and rank are unchanged.

When to use:

  • Chunks are long (>500 tokens) and only partially relevant
  • Context window budget is tight
  • You want to reduce LLM distraction from off-topic content in chunks

Example:

from gmf_forge_ai_data.context import ContextCompressor

compressor = ContextCompressor(llm_gateway)
compressed = await compressor.compress(
    query="What is self-attention in transformers?",
    results=reranked_results,
)
# Each result.document.content now contains only the relevant sentences
ContextCompressor( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.0)
67    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
68        """
69        Args:
70            llm_gateway: LLM gateway for compression.
71            temperature: Sampling temperature (default 0.0 for deterministic
72                         extraction). Keep low to avoid hallucination.
73        """
74        self.llm_gateway = llm_gateway
75        self.temperature = temperature

Args: llm_gateway: LLM gateway for compression. temperature: Sampling temperature (default 0.0 for deterministic extraction). Keep low to avoid hallucination.

llm_gateway
temperature
async def compress( self, query: str, results: List[SearchResult], min_length: int = 20) -> List[SearchResult]:
 77    async def compress(
 78        self,
 79        query: str,
 80        results: List[SearchResult],
 81        min_length: int = 20,
 82    ) -> List[SearchResult]:
 83        """
 84        Compress each retrieved passage to query-relevant content only.
 85
 86        Calls the LLM once per result. If the LLM returns content shorter
 87        than `min_length` characters (likely an error), the original passage
 88        is kept unchanged.
 89
 90        Args:
 91            query:      The user query to guide extraction.
 92            results:    List of SearchResult objects to compress.
 93            min_length: Minimum character length for a valid compressed result.
 94                        If compression produces less, the original is kept.
 95
 96        Returns:
 97            New list of SearchResult objects with compressed document content.
 98        """
 99        compressed: List[SearchResult] = []
100        for result in results:
101            compressed_content = await self._compress_one(
102                query, result.document.content, min_length
103            )
104            new_doc = copy.copy(result.document)
105            new_doc.content = compressed_content
106            compressed.append(SearchResult(
107                document=new_doc,
108                score=result.score,
109                rank=result.rank,
110            ))
111        return compressed

Compress each retrieved passage to query-relevant content only.

Calls the LLM once per result. If the LLM returns content shorter than min_length characters (likely an error), the original passage is kept unchanged.

Args: query: The user query to guide extraction. results: List of SearchResult objects to compress. min_length: Minimum character length for a valid compressed result. If compression produces less, the original is kept.

Returns: New list of SearchResult objects with compressed document content.

class ContextWindowManager:
 42class ContextWindowManager:
 43    """
 44    Fits retrieved documents into a maximum token budget.
 45
 46    Works in two passes:
 47    1. Greedily adds full documents until the budget is exhausted.
 48    2. If `allow_truncation=True`, the last document that didn't fit is
 49       partially included using as many characters as the remaining budget
 50       allows.
 51
 52    Token estimation uses 4 characters ≈ 1 token (standard rule-of-thumb
 53    for English text with GPT tokenizers).
 54
 55    Example:
 56        ```python
 57        from gmf_forge_ai_data.context import ContextWindowManager
 58
 59        manager = ContextWindowManager(max_tokens=2000)
 60        window = manager.fit(reranked_and_compressed_results)
 61
 62        print(f"Using {window.total_tokens} / {window.budget} tokens")
 63        print(f"Docs included: {len(window.results)}, truncated: {window.truncated}")
 64
 65        # Build the prompt context string
 66        context_str = "\\n\\n".join(r.document.content for r in window.results)
 67        ```
 68    """
 69
 70    _CHARS_PER_TOKEN: float = 4.0
 71
 72    def __init__(
 73        self,
 74        max_tokens: int = 3000,
 75        allow_truncation: bool = True,
 76    ):
 77        """
 78        Args:
 79            max_tokens:        Maximum number of tokens for all document content
 80                               combined. Does not include the prompt template
 81                               overhead — subtract your template size from the
 82                               model's context limit before setting this.
 83            allow_truncation:  If True (default), the last document that would
 84                               overflow is truncated to fit remaining space.
 85                               If False, it is dropped entirely.
 86        """
 87        if max_tokens <= 0:
 88            raise ValueError(f"max_tokens must be positive, got {max_tokens}")
 89        self.max_tokens = max_tokens
 90        self.allow_truncation = allow_truncation
 91
 92    def fit(self, results: List[SearchResult]) -> WindowedContext:
 93        """
 94        Select and optionally truncate documents to fit within the token budget.
 95
 96        Args:
 97            results: List of SearchResult objects, ordered by priority
 98                     (highest priority first — e.g. after reranking).
 99
100        Returns:
101            WindowedContext with the selected results and budget metadata.
102        """
103        kept: List[SearchResult] = []
104        tokens_used = 0
105        truncated = False
106        dropped = 0
107
108        for result in results:
109            doc_tokens = self._estimate_tokens(result.document.content)
110            remaining = self.max_tokens - tokens_used
111
112            if doc_tokens <= remaining:
113                kept.append(result)
114                tokens_used += doc_tokens
115            elif self.allow_truncation and remaining > 0 and not truncated:
116                # Partially include this document to fill remaining budget
117                max_chars = int(remaining * self._CHARS_PER_TOKEN)
118                truncated_content = result.document.content[:max_chars]
119                new_doc = copy.copy(result.document)
120                new_doc.content = truncated_content
121                kept.append(SearchResult(
122                    document=new_doc,
123                    score=result.score,
124                    rank=result.rank,
125                ))
126                tokens_used += self._estimate_tokens(truncated_content)
127                truncated = True
128            else:
129                dropped += 1
130
131        return WindowedContext(
132            results=kept,
133            total_tokens=tokens_used,
134            budget=self.max_tokens,
135            truncated=truncated,
136            dropped=dropped,
137        )
138
139    def _estimate_tokens(self, text: str) -> int:
140        """Estimate token count using 4 characters-per-token heuristic."""
141        return max(1, round(len(text) / self._CHARS_PER_TOKEN))

Fits retrieved documents into a maximum token budget.

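Works in a single greedy pass over the results, in priority order:

  1. Adds each document whole if its estimated token count fits in the remaining budget.
  2. If allow_truncation=True, the first document that doesn't fit is partially included using as many characters as the remaining budget allows; this uses up the budget, so every later document is dropped. Other documents that don't fit are dropped; with allow_truncation=False, a later, smaller document can still be included if it fits.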

Token estimation uses 4 characters ≈ 1 token (standard rule-of-thumb for English text with GPT tokenizers).

Example:

from gmf_forge_ai_data.context import ContextWindowManager

manager = ContextWindowManager(max_tokens=2000)
window = manager.fit(reranked_and_compressed_results)

print(f"Using {window.total_tokens} / {window.budget} tokens")
print(f"Docs included: {len(window.results)}, truncated: {window.truncated}")

# Build the prompt context string
context_str = "\n\n".join(r.document.content for r in window.results)
ContextWindowManager(max_tokens: int = 3000, allow_truncation: bool = True)
72    def __init__(
73        self,
74        max_tokens: int = 3000,
75        allow_truncation: bool = True,
76    ):
77        """
78        Args:
79            max_tokens:        Maximum number of tokens for all document content
80                               combined. Does not include the prompt template
81                               overhead — subtract your template size from the
82                               model's context limit before setting this.
83            allow_truncation:  If True (default), the last document that would
84                               overflow is truncated to fit remaining space.
85                               If False, it is dropped entirely.
86        """
87        if max_tokens <= 0:
88            raise ValueError(f"max_tokens must be positive, got {max_tokens}")
89        self.max_tokens = max_tokens
90        self.allow_truncation = allow_truncation

Args: max_tokens: Maximum number of tokens for all document content combined. Does not include the prompt template overhead — subtract your template size from the model's context limit before setting this. allow_truncation: If True (default), the first document that would overflow is truncated to fit the remaining space. If False, it is dropped entirely.

max_tokens
allow_truncation
def fit( self, results: List[SearchResult]) -> WindowedContext:
 92    def fit(self, results: List[SearchResult]) -> WindowedContext:
 93        """
 94        Select and optionally truncate documents to fit within the token budget.
 95
 96        Args:
 97            results: List of SearchResult objects, ordered by priority
 98                     (highest priority first — e.g. after reranking).
 99
100        Returns:
101            WindowedContext with the selected results and budget metadata.
102        """
103        kept: List[SearchResult] = []
104        tokens_used = 0
105        truncated = False
106        dropped = 0
107
108        for result in results:
109            doc_tokens = self._estimate_tokens(result.document.content)
110            remaining = self.max_tokens - tokens_used
111
112            if doc_tokens <= remaining:
113                kept.append(result)
114                tokens_used += doc_tokens
115            elif self.allow_truncation and remaining > 0 and not truncated:
116                # Partially include this document to fill remaining budget
117                max_chars = int(remaining * self._CHARS_PER_TOKEN)
118                truncated_content = result.document.content[:max_chars]
119                new_doc = copy.copy(result.document)
120                new_doc.content = truncated_content
121                kept.append(SearchResult(
122                    document=new_doc,
123                    score=result.score,
124                    rank=result.rank,
125                ))
126                tokens_used += self._estimate_tokens(truncated_content)
127                truncated = True
128            else:
129                dropped += 1
130
131        return WindowedContext(
132            results=kept,
133            total_tokens=tokens_used,
134            budget=self.max_tokens,
135            truncated=truncated,
136            dropped=dropped,
137        )

Select and optionally truncate documents to fit within the token budget.

Args: results: List of SearchResult objects, ordered by priority (highest priority first — e.g. after reranking).

Returns: WindowedContext with the selected results and budget metadata.

@dataclass
class WindowedContext:
@dataclass
class WindowedContext:
    """
    Value object produced by ContextWindowManager.fit().

    Attributes:
        results:      Results selected for the window, in their input order.
        total_tokens: Estimated token count of the selected content.
        budget:       The token budget that was applied.
        truncated:    Whether one selected result had its content cut short.
        dropped:      How many input results were excluded entirely.
    """

    results: List[SearchResult]  # selected results, highest priority first
    total_tokens: int  # estimated tokens consumed by `results`
    budget: int  # configured maximum token budget
    truncated: bool = False  # True when a partially included document is present
    dropped: int = 0  # number of results left out of the window

Output of ContextWindowManager.fit().

Attributes: results: Documents that fit within the token budget, in order. total_tokens: Estimated token count of all included content. budget: The configured maximum token budget. truncated: True if the last document's content was truncated to fit. dropped: Number of documents that did not fit at all.

WindowedContext( results: List[SearchResult], total_tokens: int, budget: int, truncated: bool = False, dropped: int = 0)
results: List[SearchResult]
total_tokens: int
budget: int
truncated: bool = False
dropped: int = 0
class QueryDecomposer:
 32class QueryDecomposer:
 33    """
 34    Decomposes complex multi-part queries into focused sub-queries using LLM.
 35
 36    Multi-part queries ("What are the antitrust laws and what cases were filed in 2024?")
 37    are split into individual queries for better retrieval precision per component.
 38    The sub-queries can then be run in parallel with a retriever and results merged,
 39    similar to query expansion but targeting distinct question atoms rather than synonyms.
 40
 41    Example:
 42        ```python
 43        from gmf_forge_ai_data.query import QueryDecomposer
 44        from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway
 45
 46        gateway = UnifiedLLMGateway(default_provider=azure_provider)
 47        decomposer = QueryDecomposer(gateway)
 48
 49        result = await decomposer.decompose(
 50            "What are the antitrust laws and what cases were filed in 2024?"
 51        )
 52        # result.sub_queries = [
 53        #   "What are the antitrust laws?",
 54        #   "What cases were filed in 2024?",
 55        # ]
 56        ```
 57    """
 58
 59    _DECOMPOSE_PROMPT = (
 60        "You are a query decomposition assistant for a retrieval system.\n\n"
 61        "Break the following complex query into {max_sub_queries} or fewer "
 62        "focused sub-queries.\n"
 63        "Each sub-query must be self-contained and independently answerable.\n"
 64        "Return ONLY a numbered list, one sub-query per line. "
 65        "Do not add explanations.\n\n"
 66        "Query: {query}\n\n"
 67        "Sub-queries:"
 68    )
 69
 70    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
 71        """
 72        Args:
 73            llm_gateway: LLM gateway used for intelligent decomposition.
 74            temperature: Sampling temperature passed to the LLM (default 0.0 for
 75                         deterministic decomposition). Raise slightly (e.g. 0.2)
 76                         to get more varied sub-query boundaries.
 77        """
 78        self.llm_gateway = llm_gateway
 79        self.temperature = temperature
 80
 81    async def decompose(
 82        self,
 83        query: str,
 84        max_sub_queries: int = 3,
 85    ) -> DecomposedQuery:
 86        """
 87        Decompose a complex query into focused sub-queries using LLM.
 88
 89        Args:
 90            query:           The complex query to break apart.
 91            max_sub_queries: Maximum number of sub-queries to produce.
 92
 93        Returns:
 94            DecomposedQuery containing the original and list of sub-queries.
 95        """
 96        prompt = self._DECOMPOSE_PROMPT.format(
 97            query=query,
 98            max_sub_queries=max_sub_queries,
 99        )
100
101        response = await self.llm_gateway.complete(
102            prompt=prompt,
103            temperature=self.temperature,
104            max_tokens=300,
105        )
106
107        sub_queries = self._parse_numbered_list(response.content)
108
109        if not sub_queries:
110            return DecomposedQuery(
111                original=query,
112                sub_queries=[query],
113                reasoning=response.content,
114            )
115
116        return DecomposedQuery(
117            original=query,
118            sub_queries=sub_queries[:max_sub_queries],
119            reasoning=response.content,
120        )
121
122    @staticmethod
123    def _parse_numbered_list(text: str) -> List[str]:
124        """Parse '1. item\\n2. item', '1) item', '- item', '* item' from LLM output."""
125        lines = text.strip().split("\n")
126        results: List[str] = []
127        for line in lines:
128            match = re.match(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)", line)
129            if match:
130                results.append(match.group(1).strip())
131        return results

Decomposes complex multi-part queries into focused sub-queries using LLM.

Multi-part queries ("What are the antitrust laws and what cases were filed in 2024?") are split into individual queries for better retrieval precision per component. The sub-queries can then be run in parallel with a retriever and results merged, similar to query expansion but targeting distinct question atoms rather than synonyms.

Example:

from gmf_forge_ai_data.query import QueryDecomposer
from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway

gateway = UnifiedLLMGateway(default_provider=azure_provider)
decomposer = QueryDecomposer(gateway)

result = await decomposer.decompose(
    "What are the antitrust laws and what cases were filed in 2024?"
)
# result.sub_queries = [
#   "What are the antitrust laws?",
#   "What cases were filed in 2024?",
# ]
QueryDecomposer( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.0)
70    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
71        """
72        Args:
73            llm_gateway: LLM gateway used for intelligent decomposition.
74            temperature: Sampling temperature passed to the LLM (default 0.0 for
75                         deterministic decomposition). Raise slightly (e.g. 0.2)
76                         to get more varied sub-query boundaries.
77        """
78        self.llm_gateway = llm_gateway
79        self.temperature = temperature

Args: llm_gateway: LLM gateway used for intelligent decomposition. temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic decomposition). Raise slightly (e.g. 0.2) to get more varied sub-query boundaries.

llm_gateway
temperature
async def decompose( self, query: str, max_sub_queries: int = 3) -> DecomposedQuery:
 81    async def decompose(
 82        self,
 83        query: str,
 84        max_sub_queries: int = 3,
 85    ) -> DecomposedQuery:
 86        """
 87        Decompose a complex query into focused sub-queries using LLM.
 88
 89        Args:
 90            query:           The complex query to break apart.
 91            max_sub_queries: Maximum number of sub-queries to produce.
 92
 93        Returns:
 94            DecomposedQuery containing the original and list of sub-queries.
 95        """
 96        prompt = self._DECOMPOSE_PROMPT.format(
 97            query=query,
 98            max_sub_queries=max_sub_queries,
 99        )
100
101        response = await self.llm_gateway.complete(
102            prompt=prompt,
103            temperature=self.temperature,
104            max_tokens=300,
105        )
106
107        sub_queries = self._parse_numbered_list(response.content)
108
109        if not sub_queries:
110            return DecomposedQuery(
111                original=query,
112                sub_queries=[query],
113                reasoning=response.content,
114            )
115
116        return DecomposedQuery(
117            original=query,
118            sub_queries=sub_queries[:max_sub_queries],
119            reasoning=response.content,
120        )

Decompose a complex query into focused sub-queries using LLM.

Args: query: The complex query to break apart. max_sub_queries: Maximum number of sub-queries to produce.

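Returns: DecomposedQuery containing the original query, at most max_sub_queries sub-queries, and the raw LLM response in reasoning. If no numbered or bulleted items can be parsed from the LLM response, sub_queries falls back to [query].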

@dataclass
class DecomposedQuery:
17@dataclass
18class DecomposedQuery:
19    """
20    Result of query decomposition.
21
22    Attributes:
23        original:    The original complex query string.
24        sub_queries: List of focused sub-queries derived from the original.
25        reasoning:   Raw LLM response explaining the decomposition.
26    """
27    original: str
28    sub_queries: List[str]
29    reasoning: Optional[str] = None

Result of query decomposition.

Attributes: original: The original complex query string. sub_queries: List of focused sub-queries derived from the original. reasoning: Raw LLM response text from which the sub-queries were parsed.

DecomposedQuery( original: str, sub_queries: List[str], reasoning: Optional[str] = None)
original: str
sub_queries: List[str]
reasoning: Optional[str] = None
class QueryRouter:
 37class QueryRouter:
 38    """
 39    Routes queries to the appropriate retriever or index using LLM.
 40
 41    Each route has a name and a plain-English description of its content.
 42    The LLM selects the best-matching route for each incoming query.
 43
 44    Typical use in a multi-index RAG system: create one route per Azure AI
 45    Search index and let the router automatically direct queries without
 46    searching all indexes every time.
 47
 48    Example:
 49        ```python
 50        from gmf_forge_ai_data.query import QueryRouter
 51
 52        routes = {
 53            "legal_documents":   "Legal cases, court decisions, jurisdiction, antitrust, patent",
 54            "products":          "Products, prices, inventory, electronics, furniture, camera",
 55            "financial_reports": "Earnings, revenue, fiscal year, company financials, SEC filings",
 56            "ai_ml_knowledge":   "Machine learning, AI, neural networks, deep learning, NLP",
 57        }
 58
 59        router = QueryRouter(routes=routes, llm_gateway=gateway)
 60        decision = await router.route("What antitrust cases were filed in 2024?")
 61        # decision.target = "legal_documents"
 62        # decision.confidence = 0.9
 63        ```
 64    """
 65
 66    _ROUTE_PROMPT = (
 67        "You are a query routing assistant for a multi-domain retrieval system.\n\n"
 68        "Available indexes and what they contain:\n"
 69        "{routes_description}\n\n"
 70        "Given the user query below, output ONLY the name of the single best index "
 71        "to search. Do not add any explanation or punctuation.\n\n"
 72        "Query: {query}\n\n"
 73        "Best index:"
 74    )
 75
 76    def __init__(
 77        self,
 78        routes: Dict[str, str],
 79        llm_gateway: UnifiedLLMGateway,
 80        temperature: float = 0.0,
 81    ):
 82        """
 83        Args:
 84            routes:      Dict mapping route name → plain-English description of content.
 85            llm_gateway: LLM gateway for intelligent routing.
 86            temperature: Sampling temperature passed to the LLM (default 0.0 for
 87                         deterministic routing). Keep low — routing should be consistent.
 88        """
 89        self.routes = routes
 90        self.llm_gateway = llm_gateway
 91        self.temperature = temperature
 92
 93    async def route(self, query: str) -> RouteDecision:
 94        """
 95        Route a query to the best-matching index using LLM.
 96
 97        Args:
 98            query: The user query to route.
 99
100        Returns:
101            RouteDecision with the chosen target and confidence score.
102
103        Raises:
104            ValueError: If the LLM returns an unknown route name.
105        """
106        routes_description = "\n".join(
107            f"- {name}: {desc}" for name, desc in self.routes.items()
108        )
109        prompt = self._ROUTE_PROMPT.format(
110            routes_description=routes_description,
111            query=query,
112        )
113
114        response = await self.llm_gateway.complete(
115            prompt=prompt,
116            temperature=self.temperature,
117            max_tokens=50,
118        )
119
120        target = response.content.strip().strip('"').strip("'")
121
122        if target not in self.routes:
123            raise ValueError(
124                f"LLM returned unknown route '{target}'. "
125                f"Valid routes: {list(self.routes.keys())}"
126            )
127
128        alternatives = [(name, 0.0) for name in self.routes if name != target]
129
130        return RouteDecision(
131            query=query,
132            target=target,
133            confidence=0.9,
134            reasoning=response.content,
135            alternatives=alternatives,
136        )

Routes queries to the appropriate retriever or index using LLM.

Each route has a name and a plain-English description of its content. The LLM selects the best-matching route for each incoming query.

Typical use in a multi-index RAG system: create one route per Azure AI Search index and let the router send each query to the single best-matching index instead of searching every index.

Example:

from gmf_forge_ai_data.query import QueryRouter

routes = {
    "legal_documents":   "Legal cases, court decisions, jurisdiction, antitrust, patent",
    "products":          "Products, prices, inventory, electronics, furniture, camera",
    "financial_reports": "Earnings, revenue, fiscal year, company financials, SEC filings",
    "ai_ml_knowledge":   "Machine learning, AI, neural networks, deep learning, NLP",
}

router = QueryRouter(routes=routes, llm_gateway=gateway)
decision = await router.route("What antitrust cases were filed in 2024?")
# decision.target = "legal_documents"
# decision.confidence = 0.9
QueryRouter( routes: Dict[str, str], llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.0)
76    def __init__(
77        self,
78        routes: Dict[str, str],
79        llm_gateway: UnifiedLLMGateway,
80        temperature: float = 0.0,
81    ):
82        """
83        Args:
84            routes:      Dict mapping route name → plain-English description of content.
85            llm_gateway: LLM gateway for intelligent routing.
86            temperature: Sampling temperature passed to the LLM (default 0.0 for
87                         deterministic routing). Keep low — routing should be consistent.
88        """
89        self.routes = routes
90        self.llm_gateway = llm_gateway
91        self.temperature = temperature

Args: routes: Dict mapping route name → plain-English description of content. llm_gateway: LLM gateway for intelligent routing. temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic routing). Keep low — routing should be consistent.

routes
llm_gateway
temperature
async def route(self, query: str) -> RouteDecision:
 93    async def route(self, query: str) -> RouteDecision:
 94        """
 95        Route a query to the best-matching index using LLM.
 96
 97        Args:
 98            query: The user query to route.
 99
100        Returns:
101            RouteDecision with the chosen target and confidence score.
102
103        Raises:
104            ValueError: If the LLM returns an unknown route name.
105        """
106        routes_description = "\n".join(
107            f"- {name}: {desc}" for name, desc in self.routes.items()
108        )
109        prompt = self._ROUTE_PROMPT.format(
110            routes_description=routes_description,
111            query=query,
112        )
113
114        response = await self.llm_gateway.complete(
115            prompt=prompt,
116            temperature=self.temperature,
117            max_tokens=50,
118        )
119
120        target = response.content.strip().strip('"').strip("'")
121
122        if target not in self.routes:
123            raise ValueError(
124                f"LLM returned unknown route '{target}'. "
125                f"Valid routes: {list(self.routes.keys())}"
126            )
127
128        alternatives = [(name, 0.0) for name in self.routes if name != target]
129
130        return RouteDecision(
131            query=query,
132            target=target,
133            confidence=0.9,
134            reasoning=response.content,
135            alternatives=alternatives,
136        )

Route a query to the best-matching index using LLM.

Args: query: The user query to route.

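Returns: RouteDecision with the chosen target. Note that confidence is currently a fixed placeholder of 0.9 (not derived from the model), and every other route is listed in alternatives with a placeholder score of 0.0.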

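Raises: ValueError: If the LLM response — after stripping surrounding whitespace and quote characters — does not exactly match (case-sensitively) one of the configured route names.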

@dataclass
class RouteDecision:
18@dataclass
19class RouteDecision:
20    """
21    Result of query routing.
22
23    Attributes:
24        query:        The original query string.
25        target:       Name of the chosen route (retriever or index).
26        confidence:   Confidence score in [0, 1] for the chosen route.
27        reasoning:    Raw LLM output.
28        alternatives: Other routes with placeholder confidence scores.
29    """
30    query: str
31    target: str
32    confidence: float
33    reasoning: Optional[str] = None
34    alternatives: List[Tuple[str, float]] = field(default_factory=list)

Result of query routing.

Attributes: query: The original query string. target: Name of the chosen route (retriever or index). confidence: Confidence score in [0, 1] for the chosen route; QueryRouter.route currently always sets a fixed 0.9 rather than a model-derived value. reasoning: Raw LLM output. alternatives: Other routes with placeholder confidence scores (QueryRouter.route sets these to 0.0).

RouteDecision( query: str, target: str, confidence: float, reasoning: Optional[str] = None, alternatives: List[Tuple[str, float]] = <factory>)
query: str
target: str
confidence: float
reasoning: Optional[str] = None
alternatives: List[Tuple[str, float]]
class QueryExpander:
 34class QueryExpander:
 35    """
 36    Generates query variations to improve retrieval recall using LLM.
 37
 38    Uses an LLM to produce semantically equivalent re-phrasings of the original
 39    query. Expanded queries are intended to run in parallel with the original query
 40    via separate retriever calls, then merged with Reciprocal Rank Fusion (RRF)
 41    using EnsembleRetriever for best results.
 42
 43    Example:
 44        ```python
 45        from gmf_forge_ai_data.query import QueryExpander
 46
 47        expander = QueryExpander(llm_gateway)
 48        result = await expander.expand("antitrust violations", num_expansions=3)
 49        # result.expansions = [
 50        #   "competition law breaches",
 51        #   "monopoly infringement cases",
 52        #   "anti-competitive conduct",
 53        # ]
 54        ```
 55    """
 56
 57    _EXPAND_PROMPT = (
 58        "You are a search query expansion assistant.\n\n"
 59        "Generate {num_expansions} alternative phrasings for the search query below.\n"
 60        "Use synonyms, related terms, and different wording that conveys the same intent.\n"
 61        "Return ONLY a numbered list, one variation per line. "
 62        "Do NOT repeat the original query.\n\n"
 63        "Original query: {query}\n\n"
 64        "Alternative phrasings:"
 65    )
 66
 67    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.3):
 68        """
 69        Args:
 70            llm_gateway: LLM gateway for generating query variations.
 71            temperature: Sampling temperature passed to the LLM (default 0.3 for
 72                         creative variation). Raise toward 0.7 for more diverse
 73                         phrasings; lower toward 0.0 for tighter paraphrases.
 74        """
 75        self.llm_gateway = llm_gateway
 76        self.temperature = temperature
 77
 78    async def expand(
 79        self,
 80        query: str,
 81        num_expansions: int = 3,
 82    ) -> ExpandedQuery:
 83        """
 84        Expand a query into multiple variations using LLM.
 85
 86        Args:
 87            query:           The original query to expand.
 88            num_expansions:  Number of alternative phrasings to generate.
 89
 90        Returns:
 91            ExpandedQuery with original and list of variation strings.
 92        """
 93        prompt = self._EXPAND_PROMPT.format(
 94            query=query,
 95            num_expansions=num_expansions,
 96        )
 97
 98        response = await self.llm_gateway.complete(
 99            prompt=prompt,
100            temperature=self.temperature,
101            max_tokens=300,
102        )
103
104        expansions = self._parse_numbered_list(response.content)
105
106        return ExpandedQuery(
107            original=query,
108            expansions=expansions[:num_expansions],
109        )
110
111    @staticmethod
112    def _parse_numbered_list(text: str) -> List[str]:
113        """Parse '1. item\\n2. item', '1) item', '- item', '* item' from LLM output."""
114        lines = text.strip().split("\n")
115        results: List[str] = []
116        for line in lines:
117            match = re.match(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)", line)
118            if match:
119                results.append(match.group(1).strip())
120        return results

Generates query variations to improve retrieval recall using LLM.

Uses an LLM to produce semantically equivalent re-phrasings of the original query. Expanded queries are intended to run in parallel with the original query via separate retriever calls, then merged with Reciprocal Rank Fusion (RRF) using EnsembleRetriever for best results.

Example:

from gmf_forge_ai_data.query import QueryExpander

expander = QueryExpander(llm_gateway)
result = await expander.expand("antitrust violations", num_expansions=3)
# result.expansions = [
#   "competition law breaches",
#   "monopoly infringement cases",
#   "anti-competitive conduct",
# ]
QueryExpander( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.3)
67    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.3):
68        """
69        Args:
70            llm_gateway: LLM gateway for generating query variations.
71            temperature: Sampling temperature passed to the LLM (default 0.3 for
72                         creative variation). Raise toward 0.7 for more diverse
73                         phrasings; lower toward 0.0 for tighter paraphrases.
74        """
75        self.llm_gateway = llm_gateway
76        self.temperature = temperature

Args: llm_gateway: LLM gateway for generating query variations. temperature: Sampling temperature passed to the LLM (default 0.3 for creative variation). Raise toward 0.7 for more diverse phrasings; lower toward 0.0 for tighter paraphrases.

llm_gateway
temperature
async def expand( self, query: str, num_expansions: int = 3) -> ExpandedQuery:
 78    async def expand(
 79        self,
 80        query: str,
 81        num_expansions: int = 3,
 82    ) -> ExpandedQuery:
 83        """
 84        Expand a query into multiple variations using LLM.
 85
 86        Args:
 87            query:           The original query to expand.
 88            num_expansions:  Number of alternative phrasings to generate.
 89
 90        Returns:
 91            ExpandedQuery with original and list of variation strings.
 92        """
 93        prompt = self._EXPAND_PROMPT.format(
 94            query=query,
 95            num_expansions=num_expansions,
 96        )
 97
 98        response = await self.llm_gateway.complete(
 99            prompt=prompt,
100            temperature=self.temperature,
101            max_tokens=300,
102        )
103
104        expansions = self._parse_numbered_list(response.content)
105
106        return ExpandedQuery(
107            original=query,
108            expansions=expansions[:num_expansions],
109        )

Expand a query into multiple variations using LLM.

Args: query: The original query to expand. num_expansions: Number of alternative phrasings to generate.

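Returns: ExpandedQuery with the original query and at most num_expansions variation strings. The list may be shorter than requested — or empty — if the LLM response contains fewer numbered or bulleted lines.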

@dataclass
class ExpandedQuery:
21@dataclass
22class ExpandedQuery:
23    """
24    Result of query expansion.
25
26    Attributes:
27        original:   The original query string (not included in expansions list).
28        expansions: Alternative phrasings — run alongside the original query.
29    """
30    original: str
31    expansions: List[str]

Result of query expansion.

Attributes: original: The original query string (not included in expansions list). expansions: Alternative phrasings — run alongside the original query.

ExpandedQuery(original: str, expansions: List[str])
original: str
expansions: List[str]
class QueryRewriter:
 31class QueryRewriter:
 32    """
 33    Improves query quality before retrieval using LLM.
 34
 35    Handles:
 36    - Grammar and spelling fixes
 37    - Replacement of vague terms with specific, domain-appropriate ones
 38    - Removal of conversational filler ("tell me about", "can you find")
 39    - Clarification of ambiguous intent using optional domain context
 40
 41    Example:
 42        ```python
 43        from gmf_forge_ai_data.query import QueryRewriter
 44
 45        rewriter = QueryRewriter(llm_gateway)
 46
 47        result = await rewriter.rewrite(
 48            "tell me the stuff about that apple patent thing",
 49            context="legal documents database"
 50        )
 51        # result.rewritten = "Apple Inc. patent infringement case details"
 52        # result.changes   = ["LLM rewrote: '...' → '...'"]
 53        ```
 54    """
 55
 56    _REWRITE_PROMPT = (
 57        "You are a search query optimization assistant for a document retrieval system.\n\n"
 58        "Rewrite the following query to make it more precise and effective for retrieval:\n"
 59        "- Fix grammar and spelling errors\n"
 60        "- Replace vague or colloquial terms with specific, domain-appropriate ones\n"
 61        "- Remove conversational filler (e.g., 'tell me about', 'can you find')\n"
 62        "- Preserve the original semantic intent\n"
 63        "- Return ONLY the rewritten query — no explanation, no extra text\n\n"
 64        "{context_line}"
 65        "Query: {query}\n\n"
 66        "Rewritten query:"
 67    )
 68
 69    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
 70        """
 71        Args:
 72            llm_gateway: LLM gateway for intelligent query rewriting.
 73            temperature: Sampling temperature passed to the LLM (default 0.0 for
 74                         deterministic rewrites). Keep low — rewriting should
 75                         produce consistent, reproducible output.
 76        """
 77        self.llm_gateway = llm_gateway
 78        self.temperature = temperature
 79
 80    async def rewrite(
 81        self,
 82        query: str,
 83        context: Optional[str] = None,
 84    ) -> RewrittenQuery:
 85        """
 86        Rewrite a query for better retrieval using LLM.
 87
 88        Args:
 89            query:   The original query to improve.
 90            context: Optional domain hint passed to the LLM
 91                     (e.g., "legal documents", "financial filings database").
 92
 93        Returns:
 94            RewrittenQuery with improved text and list of changes made.
 95        """
 96        context_line = f"Domain context: {context}\n\n" if context else ""
 97        prompt = self._REWRITE_PROMPT.format(
 98            query=query,
 99            context_line=context_line,
100        )
101
102        response = await self.llm_gateway.complete(
103            prompt=prompt,
104            temperature=self.temperature,
105            max_tokens=150,
106        )
107
108        rewritten = response.content.strip().strip('"').strip("'")
109
110        if not rewritten or rewritten.lower() == query.lower():
111            return RewrittenQuery(
112                original=query,
113                rewritten=query,
114                changes=["No rewrite needed"],
115            )
116
117        return RewrittenQuery(
118            original=query,
119            rewritten=rewritten,
120            changes=[f"LLM rewrote: '{query}' → '{rewritten}'"],
121        )

Improves query quality before retrieval using LLM.

Handles:

  • Grammar and spelling fixes
  • Replacement of vague terms with specific, domain-appropriate ones
  • Removal of conversational filler ("tell me about", "can you find")
  • Clarification of ambiguous intent using optional domain context

Example:

from gmf_forge_ai_data.query import QueryRewriter

rewriter = QueryRewriter(llm_gateway)

result = await rewriter.rewrite(
    "tell me the stuff about that apple patent thing",
    context="legal documents database"
)
# result.rewritten = "Apple Inc. patent infringement case details"
# result.changes   = ["LLM rewrote: '...' → '...'"]
QueryRewriter( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, temperature: float = 0.0)
69    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
70        """
71        Args:
72            llm_gateway: LLM gateway for intelligent query rewriting.
73            temperature: Sampling temperature passed to the LLM (default 0.0 for
74                         deterministic rewrites). Keep low — rewriting should
75                         produce consistent, reproducible output.
76        """
77        self.llm_gateway = llm_gateway
78        self.temperature = temperature

Args: llm_gateway: LLM gateway for intelligent query rewriting. temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic rewrites). Keep low — rewriting should produce consistent, reproducible output.

llm_gateway
temperature
async def rewrite( self, query: str, context: Optional[str] = None) -> RewrittenQuery:
 80    async def rewrite(
 81        self,
 82        query: str,
 83        context: Optional[str] = None,
 84    ) -> RewrittenQuery:
 85        """
 86        Rewrite a query for better retrieval using LLM.
 87
 88        Args:
 89            query:   The original query to improve.
 90            context: Optional domain hint passed to the LLM
 91                     (e.g., "legal documents", "financial filings database").
 92
 93        Returns:
 94            RewrittenQuery with improved text and list of changes made.
 95        """
 96        context_line = f"Domain context: {context}\n\n" if context else ""
 97        prompt = self._REWRITE_PROMPT.format(
 98            query=query,
 99            context_line=context_line,
100        )
101
102        response = await self.llm_gateway.complete(
103            prompt=prompt,
104            temperature=self.temperature,
105            max_tokens=150,
106        )
107
108        rewritten = response.content.strip().strip('"').strip("'")
109
110        if not rewritten or rewritten.lower() == query.lower():
111            return RewrittenQuery(
112                original=query,
113                rewritten=query,
114                changes=["No rewrite needed"],
115            )
116
117        return RewrittenQuery(
118            original=query,
119            rewritten=rewritten,
120            changes=[f"LLM rewrote: '{query}' → '{rewritten}'"],
121        )

Rewrite a query for better retrieval using LLM.

Args: query: The original query to improve. context: Optional domain hint passed to the LLM (e.g., "legal documents", "financial filings database").

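Returns: RewrittenQuery with the improved text and a list of changes made. If the LLM returns an empty string, or text that matches the original ignoring case, the original query is returned unchanged with changes=["No rewrite needed"].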

@dataclass
class RewrittenQuery:
16@dataclass
17class RewrittenQuery:
18    """
19    Result of query rewriting.
20
21    Attributes:
22        original:  The original query string before rewriting.
23        rewritten: The improved query string after rewriting.
24        changes:   Human-readable list of transformations applied.
25    """
26    original: str
27    rewritten: str
28    changes: List[str] = field(default_factory=list)

Result of query rewriting.

Attributes: original: The original query string before rewriting. rewritten: The improved query string after rewriting. changes: Human-readable list of transformations applied.

RewrittenQuery(original: str, rewritten: str, changes: List[str] = <factory>)
original: str
rewritten: str
changes: List[str]
class HyDEGenerator:
 41class HyDEGenerator:
 42    """
 43    Hypothetical Document Embeddings (HyDE) generator.
 44
 45    Why this works:
 46    ---------------
 47    Short query strings ("antitrust cases 2024") and full answer passages live
 48    in very different regions of an embedding space. A hypothetical passage that
 49    ANSWERS the query occupies the same region as real answer documents, so
 50    cosine similarity between the HyDE embedding and indexed document embeddings
 51    is substantially higher than query-vs-document similarity.
 52
 53    Usage pattern:
 54    --------------
 55    1. Call generate_and_embed(query) → HypotheticalDocument (with embedding set).
 56    2. Feed the embedding into VectorRetriever via RetrievalQuery(embedding=...).
 57    3. Compare results against standard VectorRetriever on the same query.
 58
 59    Example:
 60        ```python
 61        from gmf_forge_ai_data.query import HyDEGenerator
 62        from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
 63
 64        hyde = HyDEGenerator(llm_gateway=gateway, embedder=embedder)
 65
 66        # Generate hypothetical doc and embed it
 67        hypo = await hyde.generate_and_embed(
 68            "What are the penalties for antitrust violations?",
 69            domain="legal documents"
 70        )
 71
 72        # Use HyDE embedding for retrieval
 73        query = RetrievalQuery(embedding=hypo.embedding, top_k=5)
 74        results = vector_retriever.retrieve(query)
 75        ```
 76    """
 77
 78    _HYDE_PROMPT = (
 79        "Write a concise, authoritative passage that directly answers the question below.\n"
 80        "Write it as if it were an excerpt from a reference document or knowledge base.\n"
 81        "{domain_line}"
 82        "Keep the passage under 150 words. "
 83        "Do not include meta-commentary or mention that this is hypothetical.\n\n"
 84        "Question: {query}\n\n"
 85        "Passage:"
 86    )
 87
 88    def __init__(
 89        self,
 90        llm_gateway: UnifiedLLMGateway,
 91        embedder: Optional[EmbeddingProvider] = None,
 92    ):
 93        """
 94        Initialize the HyDE generator.
 95
 96        Args:
 97            llm_gateway: LLM gateway used to generate the hypothetical document.
 98            embedder:    Embedding provider used to vectorize the hypothetical doc.
 99                         Required only for generate_and_embed(); optional for generate().
100        """
101        self.llm_gateway = llm_gateway
102        self.embedder = embedder
103
104    async def generate(
105        self,
106        query: str,
107        domain: Optional[str] = None,
108    ) -> HypotheticalDocument:
109        """
110        Generate a hypothetical document that would answer the query.
111
112        The returned HypotheticalDocument has embedding=None. Call
113        generate_and_embed() to also produce a vector in one step.
114
115        Args:
116            query:  The retrieval query to generate a passage for.
117            domain: Optional domain hint to guide the LLM style
118                    (e.g., "legal documents", "financial reports", "AI/ML knowledge base").
119
120        Returns:
121            HypotheticalDocument with hypothetical_doc text, embedding=None.
122        """
123        domain_line = f"Domain: {domain}\n" if domain else ""
124        prompt = self._HYDE_PROMPT.format(query=query, domain_line=domain_line)
125
126        response = await self.llm_gateway.complete(
127            prompt=prompt,
128            temperature=0.5,
129            max_tokens=200,
130        )
131
132        return HypotheticalDocument(
133            query=query,
134            hypothetical_doc=response.content.strip(),
135            domain=domain,
136        )
137
138    async def generate_and_embed(
139        self,
140        query: str,
141        domain: Optional[str] = None,
142    ) -> HypotheticalDocument:
143        """
144        Generate a hypothetical document and embed it in a single step.
145
146        Calls generate() then uses the configured embedder to vectorize the
147        resulting passage. The embedding can be passed directly to VectorRetriever
148        via RetrievalQuery(embedding=result.embedding, ...).
149
150        Args:
151            query:  The retrieval query.
152            domain: Optional domain hint for generation style.
153
154        Returns:
155            HypotheticalDocument with both hypothetical_doc and embedding populated.
156
157        Raises:
158            ValueError: If no embedder was provided at construction time.
159        """
160        if not self.embedder:
161            raise ValueError(
162                "An EmbeddingProvider is required for generate_and_embed(). "
163                "Pass embedder= to HyDEGenerator.__init__()."
164            )
165
166        result = await self.generate(query, domain)
167        result.embedding = self.embedder.embed_text(result.hypothetical_doc)
168        return result

Hypothetical Document Embeddings (HyDE) generator.

Why this works:

Short query strings ("antitrust cases 2024") and full answer passages tend to occupy quite different regions of an embedding space. A hypothetical passage that ANSWERS the query is likely to land closer to real answer documents, so cosine similarity between the HyDE embedding and the indexed document embeddings is often higher than the similarity between the raw query and those documents.

Usage pattern:

  1. Call generate_and_embed(query) → HypotheticalDocument (with embedding set).
  2. Feed the embedding into VectorRetriever via RetrievalQuery(embedding=...).
  3. Compare results against standard VectorRetriever on the same query.

Example:

from gmf_forge_ai_data.query import HyDEGenerator
from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery

hyde = HyDEGenerator(llm_gateway=gateway, embedder=embedder)

# Generate hypothetical doc and embed it
hypo = await hyde.generate_and_embed(
    "What are the penalties for antitrust violations?",
    domain="legal documents"
)

# Use HyDE embedding for retrieval
query = RetrievalQuery(embedding=hypo.embedding, top_k=5)
results = vector_retriever.retrieve(query)
HyDEGenerator( llm_gateway: gmf_forge_ai_shared_core.llm_gateway.UnifiedLLMGateway, embedder: Optional[EmbeddingProvider] = None)
 88    def __init__(
 89        self,
 90        llm_gateway: UnifiedLLMGateway,
 91        embedder: Optional[EmbeddingProvider] = None,
 92    ):
 93        """
 94        Initialize the HyDE generator.
 95
 96        Args:
 97            llm_gateway: LLM gateway used to generate the hypothetical document.
 98            embedder:    Embedding provider used to vectorize the hypothetical doc.
 99                         Required only for generate_and_embed(); optional for generate().
100        """
101        self.llm_gateway = llm_gateway
102        self.embedder = embedder

Initialize the HyDE generator.

Args:

- llm_gateway: LLM gateway used to generate the hypothetical document.
- embedder: Embedding provider used to vectorize the hypothetical doc. Required only for generate_and_embed(); optional for generate().

llm_gateway
embedder
async def generate( self, query: str, domain: Optional[str] = None) -> HypotheticalDocument:
104    async def generate(
105        self,
106        query: str,
107        domain: Optional[str] = None,
108    ) -> HypotheticalDocument:
109        """
110        Generate a hypothetical document that would answer the query.
111
112        The returned HypotheticalDocument has embedding=None. Call
113        generate_and_embed() to also produce a vector in one step.
114
115        Args:
116            query:  The retrieval query to generate a passage for.
117            domain: Optional domain hint to guide the LLM style
118                    (e.g., "legal documents", "financial reports", "AI/ML knowledge base").
119
120        Returns:
121            HypotheticalDocument with hypothetical_doc text, embedding=None.
122        """
123        domain_line = f"Domain: {domain}\n" if domain else ""
124        prompt = self._HYDE_PROMPT.format(query=query, domain_line=domain_line)
125
126        response = await self.llm_gateway.complete(
127            prompt=prompt,
128            temperature=0.5,
129            max_tokens=200,
130        )
131
132        return HypotheticalDocument(
133            query=query,
134            hypothetical_doc=response.content.strip(),
135            domain=domain,
136        )

Generate a hypothetical document that would answer the query.

The returned HypotheticalDocument has embedding=None. Call generate_and_embed() to also produce a vector in one step.

Args:

- query: The retrieval query to generate a passage for.
- domain: Optional domain hint to guide the LLM style (e.g., "legal documents", "financial reports", "AI/ML knowledge base").

Returns: HypotheticalDocument with hypothetical_doc text, embedding=None.

async def generate_and_embed( self, query: str, domain: Optional[str] = None) -> HypotheticalDocument:
138    async def generate_and_embed(
139        self,
140        query: str,
141        domain: Optional[str] = None,
142    ) -> HypotheticalDocument:
143        """
144        Generate a hypothetical document and embed it in a single step.
145
146        Calls generate() then uses the configured embedder to vectorize the
147        resulting passage. The embedding can be passed directly to VectorRetriever
148        via RetrievalQuery(embedding=result.embedding, ...).
149
150        Args:
151            query:  The retrieval query.
152            domain: Optional domain hint for generation style.
153
154        Returns:
155            HypotheticalDocument with both hypothetical_doc and embedding populated.
156
157        Raises:
158            ValueError: If no embedder was provided at construction time.
159        """
160        if not self.embedder:
161            raise ValueError(
162                "An EmbeddingProvider is required for generate_and_embed(). "
163                "Pass embedder= to HyDEGenerator.__init__()."
164            )
165
166        result = await self.generate(query, domain)
167        result.embedding = self.embedder.embed_text(result.hypothetical_doc)
168        return result

Generate a hypothetical document and embed it in a single step.

Calls generate() then uses the configured embedder to vectorize the resulting passage. The embedding can be passed directly to VectorRetriever via RetrievalQuery(embedding=result.embedding, ...).

Args:

- query: The retrieval query.
- domain: Optional domain hint for generation style.

Returns: HypotheticalDocument with both hypothetical_doc and embedding populated.

Raises: ValueError: If no embedder was provided at construction time.

@dataclass
class HypotheticalDocument:
24@dataclass
25class HypotheticalDocument:
26    """
27    Result of HyDE generation.
28
29    Attributes:
30        query:            The original retrieval query.
31        hypothetical_doc: LLM-generated passage that would answer the query.
32        embedding:        Vector embedding of hypothetical_doc (None until embedded).
33        domain:           Optional domain hint that was passed during generation.
34    """
35    query: str
36    hypothetical_doc: str
37    embedding: Optional[List[float]] = None
38    domain: Optional[str] = None

Result of HyDE generation.

Attributes:

- query: The original retrieval query.
- hypothetical_doc: LLM-generated passage that would answer the query.
- embedding: Vector embedding of hypothetical_doc (None until embedded).
- domain: Optional domain hint that was passed during generation.

HypotheticalDocument( query: str, hypothetical_doc: str, embedding: Optional[List[float]] = None, domain: Optional[str] = None)
query: str
hypothetical_doc: str
embedding: Optional[List[float]] = None
domain: Optional[str] = None
class DocumentIntelligenceLayout:
 89class DocumentIntelligenceLayout:
 90    """
 91    Wraps the Azure Document Intelligence ``prebuilt-layout`` model.
 92
 93    Converts PDF, DOCX, PPTX, XLSX, and image files into structured markdown
 94    that preserves headings, tables, lists, and page structure.  The markdown
 95    output is designed to be chunked with ``MarkdownChunker``.
 96
 97    Typical usage::
 98
 99        from gmf_forge_ai_data.layout import DocumentIntelligenceLayout
100        from gmf_forge_ai_data.chunkers import MarkdownChunker
101
102        # Managed identity (omit api_key) — required for Multiservices accounts
103        layout = DocumentIntelligenceLayout(
104            endpoint="https://my-resource.cognitiveservices.azure.com",
105        )
106
107        # API key — for standalone Document Intelligence resources
108        layout = DocumentIntelligenceLayout(
109            endpoint="https://my-resource.cognitiveservices.azure.com",
110            api_key="your-api-key",
111        )
112
113        # From a local file
114        result = layout.analyze_file("annual_report.pdf")
115
116        # From raw bytes (e.g. downloaded from BlobStorageConnector)
117        result = layout.analyze_bytes(blob_content, filename="report.pdf")
118
119        # From a URL
120        result = layout.analyze_url("https://example.com/policy.pdf")
121
122        # Chunk the markdown directly
123        chunker = MarkdownChunker(max_chunk_size=1500, min_header_level=2)
124        chunks  = chunker.chunk(result.markdown, metadata=result.metadata)
125
126    Args:
127        endpoint:    Azure Document Intelligence service endpoint URL.
128        api_key:     Azure Document Intelligence API key.  When provided,
129                     ``AzureKeyCredential`` is used.  When omitted (``None``),
130                     ``DefaultAzureCredential`` is used instead — required for
131                     Cognitive Services Multiservices accounts that do not
132                     expose API keys.
133        model_id:    Model to use (default: ``prebuilt-layout``).
134        api_version: REST API version (default: ``2024-11-30``).
135        logger:      Optional ``BasicLogger`` instance for structured logging.
136
137    Raises:
138        ImportError: If ``azure-ai-documentintelligence`` is not installed, or
139                     if ``api_key`` is omitted and ``azure-identity`` is not
140                     installed.
141        ValueError:  If ``endpoint`` is empty.
142    """
143
144    _DEFAULT_MODEL = "prebuilt-layout"
145    _DEFAULT_API_VERSION = "2024-11-30"
146
147    def __init__(
148        self,
149        endpoint: str,
150        api_key: Optional[str] = None,
151        model_id: str = _DEFAULT_MODEL,
152        api_version: str = _DEFAULT_API_VERSION,
153        logger: Optional[BasicLogger] = None,
154    ) -> None:
155        if not _SDK_AVAILABLE:
156            raise ImportError(
157                "azure-ai-documentintelligence is required: "
158                "pip install azure-ai-documentintelligence"
159            )
160        if not endpoint or not endpoint.strip():
161            raise ValueError("endpoint must not be empty")
162
163        self.endpoint = endpoint.rstrip("/")
164        self.model_id = model_id
165        self.api_version = api_version
166        self.logger = logger or _logger
167
168        if api_key:
169            # Standalone Document Intelligence resource with key-based auth enabled
170            credential = AzureKeyCredential(api_key)
171            auth_method = "api_key"
172        else:
173            # Cognitive Services Multiservices account — no API key exposed.
174            # DefaultAzureCredential resolves the auth chain automatically:
175            #   In Azure (AKS, VM, App Service): managed identity / workload identity
176            #   Locally: az login  or  VS Code Azure account extension
177            if not _IDENTITY_AVAILABLE:
178                raise ImportError(
179                    "azure-identity is required when api_key is not provided: "
180                    "pip install azure-identity"
181                )
182            credential = DefaultAzureCredential()
183            auth_method = "managed_identity"
184
185        self._client = DocumentIntelligenceClient(
186            endpoint=self.endpoint,
187            credential=credential,
188            api_version=self.api_version,
189        )
190
191        self.logger.info(
192            "DocumentIntelligenceLayout initialised",
193            endpoint=self.endpoint,
194            model_id=self.model_id,
195            api_version=self.api_version,
196            auth=auth_method,
197        )
198
199    # ------------------------------------------------------------------
200    # Public API
201    # ------------------------------------------------------------------
202
203    def analyze_file(self, file_path: str | Path) -> LayoutResult:
204        """
205        Analyse a local file and return its content as markdown.
206
207        Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF.
208
209        Args:
210            file_path: Path to the local document file.
211
212        Returns:
213            :class:`LayoutResult` with markdown content and metadata.
214
215        Raises:
216            FileNotFoundError: If the file does not exist.
217            ValueError:        If the file is empty.
218        """
219        path = Path(file_path)
220        if not path.exists():
221            raise FileNotFoundError(f"File not found: {file_path}")
222        if path.stat().st_size == 0:
223            raise ValueError(f"File is empty: {file_path}")
224
225        self.logger.info("Analysing file", file=str(path), model_id=self.model_id)
226
227        with open(path, "rb") as fh:
228            content = fh.read()
229
230        result = self._analyze_bytes_content(content)
231        result.metadata.update({
232            "source": str(path.resolve()),
233            "file_name": path.name,
234        })
235        return result
236
237    def analyze_bytes(self, content: bytes, filename: str = "") -> LayoutResult:
238        """
239        Analyse raw bytes and return content as markdown.
240
241        Use this when you already have the document bytes in memory — for
242        example, content downloaded via ``BlobStorageConnector`` or
243        ``SharePointConnector``.
244
245        Args:
246            content:  Raw document bytes.
247            filename: Original filename (used for metadata only, e.g. "report.pdf").
248
249        Returns:
250            :class:`LayoutResult` with markdown content and metadata.
251
252        Raises:
253            ValueError: If ``content`` is empty.
254        """
255        if not content:
256            raise ValueError("content must not be empty")
257
258        content_hash = hashlib.sha256(content).hexdigest()[:12]
259        source = f"bytes:{content_hash}"
260
261        self.logger.info(
262            "Analysing bytes",
263            size_bytes=len(content),
264            filename=filename or "(unnamed)",
265            model_id=self.model_id,
266        )
267
268        result = self._analyze_bytes_content(content)
269        result.metadata.update({
270            "source": source,
271            "file_name": filename,
272        })
273        return result
274
275    def analyze_url(self, url: str) -> LayoutResult:
276        """
277        Analyse a document at a publicly accessible URL.
278
279        The Azure Document Intelligence service fetches the document directly
280        from the URL — the bytes never pass through your application.
281
282        Args:
283            url: Publicly accessible URL pointing to a supported document.
284
285        Returns:
286            :class:`LayoutResult` with markdown content and metadata.
287
288        Raises:
289            ValueError: If ``url`` is empty.
290        """
291        if not url or not url.strip():
292            raise ValueError("url must not be empty")
293
294        self.logger.info("Analysing URL", url=url, model_id=self.model_id)
295
296        poller = self._client.begin_analyze_document(
297            self.model_id,
298            AnalyzeDocumentRequest(url_source=url),
299            output_content_format=DocumentContentFormat.MARKDOWN,
300            features=[DocumentAnalysisFeature.QUERY_FIELDS],
301        )
302        response = poller.result()
303
304        markdown = response.content or ""
305        page_count = len(response.pages) if response.pages else 0
306
307        self.logger.info(
308            "URL analysis complete",
309            url=url,
310            page_count=page_count,
311            markdown_length=len(markdown),
312        )
313
314        return LayoutResult(
315            markdown=markdown,
316            page_count=page_count,
317            metadata={
318                "source": url,
319                "file_name": url.split("/")[-1],
320                "model_id": self.model_id,
321                "analyzed_at": datetime.now(timezone.utc).isoformat(),
322            },
323        )
324
325    # ------------------------------------------------------------------
326    # Internal helpers
327    # ------------------------------------------------------------------
328
329    def _analyze_bytes_content(self, content: bytes) -> LayoutResult:
330        """Send raw bytes to the Document Intelligence service and return LayoutResult."""
331        poller = self._client.begin_analyze_document(
332            self.model_id,
333            AnalyzeDocumentRequest(bytes_source=content),
334            output_content_format=DocumentContentFormat.MARKDOWN,
335            features=[DocumentAnalysisFeature.QUERY_FIELDS],
336        )
337        response = poller.result()
338
339        markdown = response.content or ""
340        page_count = len(response.pages) if response.pages else 0
341
342        self.logger.info(
343            "Analysis complete",
344            model_id=self.model_id,
345            page_count=page_count,
346            markdown_length=len(markdown),
347        )
348
349        return LayoutResult(
350            markdown=markdown,
351            page_count=page_count,
352            metadata={
353                "model_id": self.model_id,
354                "analyzed_at": datetime.now(timezone.utc).isoformat(),
355            },
356        )

Wraps the Azure Document Intelligence prebuilt-layout model.

Converts PDF, DOCX, PPTX, XLSX, and image files into structured markdown that preserves headings, tables, lists, and page structure. The markdown output is designed to be chunked with MarkdownChunker.

Typical usage::

from gmf_forge_ai_data.layout import DocumentIntelligenceLayout
from gmf_forge_ai_data.chunkers import MarkdownChunker

# Managed identity (omit api_key) — required for Multiservices accounts
layout = DocumentIntelligenceLayout(
    endpoint="https://my-resource.cognitiveservices.azure.com",
)

# API key — for standalone Document Intelligence resources
layout = DocumentIntelligenceLayout(
    endpoint="https://my-resource.cognitiveservices.azure.com",
    api_key="your-api-key",
)

# From a local file
result = layout.analyze_file("annual_report.pdf")

# From raw bytes (e.g. downloaded from BlobStorageConnector)
result = layout.analyze_bytes(blob_content, filename="report.pdf")

# From a URL
result = layout.analyze_url("https://example.com/policy.pdf")

# Chunk the markdown directly
chunker = MarkdownChunker(max_chunk_size=1500, min_header_level=2)
chunks  = chunker.chunk(result.markdown, metadata=result.metadata)

Args:

- endpoint: Azure Document Intelligence service endpoint URL.
- api_key: Azure Document Intelligence API key. When provided, AzureKeyCredential is used. When omitted (None), DefaultAzureCredential is used instead — required for Cognitive Services Multiservices accounts that do not expose API keys.
- model_id: Model to use (default: prebuilt-layout).
- api_version: REST API version (default: 2024-11-30).
- logger: Optional BasicLogger instance for structured logging.

Raises:

- ImportError: If azure-ai-documentintelligence is not installed, or if api_key is omitted and azure-identity is not installed.
- ValueError: If endpoint is empty.

DocumentIntelligenceLayout( endpoint: str, api_key: Optional[str] = None, model_id: str = 'prebuilt-layout', api_version: str = '2024-11-30', logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None)
147    def __init__(
148        self,
149        endpoint: str,
150        api_key: Optional[str] = None,
151        model_id: str = _DEFAULT_MODEL,
152        api_version: str = _DEFAULT_API_VERSION,
153        logger: Optional[BasicLogger] = None,
154    ) -> None:
155        if not _SDK_AVAILABLE:
156            raise ImportError(
157                "azure-ai-documentintelligence is required: "
158                "pip install azure-ai-documentintelligence"
159            )
160        if not endpoint or not endpoint.strip():
161            raise ValueError("endpoint must not be empty")
162
163        self.endpoint = endpoint.rstrip("/")
164        self.model_id = model_id
165        self.api_version = api_version
166        self.logger = logger or _logger
167
168        if api_key:
169            # Standalone Document Intelligence resource with key-based auth enabled
170            credential = AzureKeyCredential(api_key)
171            auth_method = "api_key"
172        else:
173            # Cognitive Services Multiservices account — no API key exposed.
174            # DefaultAzureCredential resolves the auth chain automatically:
175            #   In Azure (AKS, VM, App Service): managed identity / workload identity
176            #   Locally: az login  or  VS Code Azure account extension
177            if not _IDENTITY_AVAILABLE:
178                raise ImportError(
179                    "azure-identity is required when api_key is not provided: "
180                    "pip install azure-identity"
181                )
182            credential = DefaultAzureCredential()
183            auth_method = "managed_identity"
184
185        self._client = DocumentIntelligenceClient(
186            endpoint=self.endpoint,
187            credential=credential,
188            api_version=self.api_version,
189        )
190
191        self.logger.info(
192            "DocumentIntelligenceLayout initialised",
193            endpoint=self.endpoint,
194            model_id=self.model_id,
195            api_version=self.api_version,
196            auth=auth_method,
197        )
endpoint
model_id
api_version
logger
def analyze_file( self, file_path: str | pathlib.Path) -> LayoutResult:
203    def analyze_file(self, file_path: str | Path) -> LayoutResult:
204        """
205        Analyse a local file and return its content as markdown.
206
207        Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF.
208
209        Args:
210            file_path: Path to the local document file.
211
212        Returns:
213            :class:`LayoutResult` with markdown content and metadata.
214
215        Raises:
216            FileNotFoundError: If the file does not exist.
217            ValueError:        If the file is empty.
218        """
219        path = Path(file_path)
220        if not path.exists():
221            raise FileNotFoundError(f"File not found: {file_path}")
222        if path.stat().st_size == 0:
223            raise ValueError(f"File is empty: {file_path}")
224
225        self.logger.info("Analysing file", file=str(path), model_id=self.model_id)
226
227        with open(path, "rb") as fh:
228            content = fh.read()
229
230        result = self._analyze_bytes_content(content)
231        result.metadata.update({
232            "source": str(path.resolve()),
233            "file_name": path.name,
234        })
235        return result

Analyse a local file and return its content as markdown.

Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF.

Args: file_path: Path to the local document file.

Returns: LayoutResult with markdown content and metadata.

Raises:

- FileNotFoundError: If the file does not exist.
- ValueError: If the file is empty.

def analyze_bytes( self, content: bytes, filename: str = '') -> LayoutResult:
237    def analyze_bytes(self, content: bytes, filename: str = "") -> LayoutResult:
238        """
239        Analyse raw bytes and return content as markdown.
240
241        Use this when you already have the document bytes in memory — for
242        example, content downloaded via ``BlobStorageConnector`` or
243        ``SharePointConnector``.
244
245        Args:
246            content:  Raw document bytes.
247            filename: Original filename (used for metadata only, e.g. "report.pdf").
248
249        Returns:
250            :class:`LayoutResult` with markdown content and metadata.
251
252        Raises:
253            ValueError: If ``content`` is empty.
254        """
255        if not content:
256            raise ValueError("content must not be empty")
257
258        content_hash = hashlib.sha256(content).hexdigest()[:12]
259        source = f"bytes:{content_hash}"
260
261        self.logger.info(
262            "Analysing bytes",
263            size_bytes=len(content),
264            filename=filename or "(unnamed)",
265            model_id=self.model_id,
266        )
267
268        result = self._analyze_bytes_content(content)
269        result.metadata.update({
270            "source": source,
271            "file_name": filename,
272        })
273        return result

Analyse raw bytes and return content as markdown.

Use this when you already have the document bytes in memory — for example, content downloaded via BlobStorageConnector or SharePointConnector.

Args:

- content: Raw document bytes.
- filename: Original filename (used for metadata only, e.g. "report.pdf").

Returns: LayoutResult with markdown content and metadata.

Raises: ValueError: If content is empty.

def analyze_url( self, url: str) -> LayoutResult:
275    def analyze_url(self, url: str) -> LayoutResult:
276        """
277        Analyse a document at a publicly accessible URL.
278
279        The Azure Document Intelligence service fetches the document directly
280        from the URL — the bytes never pass through your application.
281
282        Args:
283            url: Publicly accessible URL pointing to a supported document.
284
285        Returns:
286            :class:`LayoutResult` with markdown content and metadata.
287
288        Raises:
289            ValueError: If ``url`` is empty.
290        """
291        if not url or not url.strip():
292            raise ValueError("url must not be empty")
293
294        self.logger.info("Analysing URL", url=url, model_id=self.model_id)
295
296        poller = self._client.begin_analyze_document(
297            self.model_id,
298            AnalyzeDocumentRequest(url_source=url),
299            output_content_format=DocumentContentFormat.MARKDOWN,
300            features=[DocumentAnalysisFeature.QUERY_FIELDS],
301        )
302        response = poller.result()
303
304        markdown = response.content or ""
305        page_count = len(response.pages) if response.pages else 0
306
307        self.logger.info(
308            "URL analysis complete",
309            url=url,
310            page_count=page_count,
311            markdown_length=len(markdown),
312        )
313
314        return LayoutResult(
315            markdown=markdown,
316            page_count=page_count,
317            metadata={
318                "source": url,
319                "file_name": url.split("/")[-1],
320                "model_id": self.model_id,
321                "analyzed_at": datetime.now(timezone.utc).isoformat(),
322            },
323        )

Analyse a document at a publicly accessible URL.

The Azure Document Intelligence service fetches the document directly from the URL — the bytes never pass through your application.

Args: url: Publicly accessible URL pointing to a supported document.

Returns: LayoutResult with markdown content and metadata.

Raises: ValueError: If url is empty.

@dataclass
class LayoutResult:
66@dataclass
67class LayoutResult:
68    """
69    Result of a Document Intelligence layout analysis.
70
71    Attributes:
72        markdown:    Full document content as markdown.  Headers, tables, lists,
73                     page breaks (``<!-- PageBreak -->``) and figure captions
74                     (``<!-- FigureCaption -->``) are preserved exactly as
75                     produced by Azure Document Intelligence.  Pass directly to
76                     ``MarkdownChunker`` for header-aware chunking.
77        page_count:  Number of pages in the analysed document.
78        metadata:    Source information and analysis details:
79                     ``source``       — file path, URL, or "bytes:<hash>"
80                     ``file_name``    — basename of the source file (if known)
81                     ``model_id``     — Document Intelligence model used
82                     ``analyzed_at`` — ISO-8601 UTC timestamp of the analysis
83    """
84    markdown: str
85    page_count: int
86    metadata: Dict[str, Any] = field(default_factory=dict)

Result of a Document Intelligence layout analysis.

Attributes:

- markdown: Full document content as markdown. Headers, tables, lists, page breaks (<!-- PageBreak -->) and figure captions (<!-- FigureCaption -->) are preserved exactly as produced by Azure Document Intelligence. Pass directly to MarkdownChunker for header-aware chunking.
- page_count: Number of pages in the analysed document.
- metadata: Source information and analysis details:
  - source — file path, URL, or "bytes:<hash>"
  - file_name — basename of the source file (if known)
  - model_id — Document Intelligence model used
  - analyzed_at — ISO-8601 UTC timestamp of the analysis

LayoutResult(markdown: str, page_count: int, metadata: Dict[str, Any] = <factory>)
markdown: str
page_count: int
metadata: Dict[str, Any]