gmf_forge_ai_data
GMF Forge AI Data Layer Package.
This package provides data layer components for RAG applications including:
- Embeddings (Azure OpenAI)
- Document processing and chunking
- Retrieval strategies
- Query optimization
- Context management
- Indexing utilities
- Data connectors
For detailed documentation, see the README.md file.
"""
GMF Forge AI Data Layer Package.

This package provides data layer components for RAG applications including:
- Embeddings (Azure OpenAI)
- Document processing and chunking
- Retrieval strategies
- Query optimization
- Context management
- Indexing utilities
- Data connectors

For detailed documentation, see the README.md file.
"""

from gmf_forge_ai_data.version import __version__

# Embeddings module is available
from .embeddings import (
    EmbeddingProvider,
    AzureOpenAIEmbeddings,
    BatchEmbeddings,
)

# Chunking module is available
from .chunkers import (
    Chunk,
    BaseChunker,
    FixedSizeChunker,
    SemanticChunker,
    RecursiveChunker,
    SentenceChunker,
    MarkdownChunker,
    CodeChunker,
)

# Vector stores module is available
from .vector_stores import (
    BaseVectorStore,
    Document,
    SearchResult,
    InMemoryVectorStore,
    AzureAISearchVectorStore,
    AzureCosmosDBVectorStore,
    MongoDBVectorStore,
)

# Retrieval strategies module is available
from .retrieval import (
    BaseRetriever,
    RetrievalQuery,
    VectorRetriever,
    KeywordRetriever,
    HybridRetriever,
    MMRRetriever,
    ParentDocumentRetriever,
    EnsembleRetriever,
    HierarchicalRetriever,
    GraphRetriever,
    SQLRetriever,
    SQLSchema,
    SQLQuery,
    MultiIndexRetriever,
    SourceConfig,
)

# Data connectors module is available
from .connectors import (
    BaseConnector,
    FilesystemConnector,
    SharePointConnector,
    BlobStorageConnector,
)

# Indexing builders module is available
from .indexing import (
    BaseIndexBuilder,
    AzureAISearchIndexBuilder,
    CosmosDBIndexBuilder,
    MongoDBIndexBuilder,
)

# Context processing module is available
from .context import (
    RelevanceFilter,
    ContextDeduplicator,
    ContextReranker,
    ContextCompressor,
    ContextWindowManager,
    WindowedContext,
)

# Query optimization module is available
from .query import (
    QueryDecomposer,
    DecomposedQuery,
    QueryRouter,
    RouteDecision,
    QueryExpander,
    ExpandedQuery,
    QueryRewriter,
    RewrittenQuery,
    HyDEGenerator,
    HypotheticalDocument,
)

# Document Intelligence layout module is available
from .layout import (
    DocumentIntelligenceLayout,
    LayoutResult,
)

__all__ = [
    "__version__",
    # Embeddings
    "EmbeddingProvider",
    "AzureOpenAIEmbeddings",
    "BatchEmbeddings",
    # Chunking
    "Chunk",
    "BaseChunker",
    "FixedSizeChunker",
    "SemanticChunker",
    "RecursiveChunker",
    "SentenceChunker",
    "MarkdownChunker",
    "CodeChunker",
    # Vector Stores
    "BaseVectorStore",
    "Document",
    "SearchResult",
    "InMemoryVectorStore",
    "AzureAISearchVectorStore",
    "AzureCosmosDBVectorStore",
    "MongoDBVectorStore",
    # Retrieval Strategies
    "BaseRetriever",
    "RetrievalQuery",
    "VectorRetriever",
    "KeywordRetriever",
    "HybridRetriever",
    "MMRRetriever",
    "ParentDocumentRetriever",
    "EnsembleRetriever",
    "HierarchicalRetriever",
    "GraphRetriever",
    "SQLRetriever",
    "SQLSchema",
    "SQLQuery",
    "MultiIndexRetriever",
    "SourceConfig",
    # Data Connectors
    "BaseConnector",
    "FilesystemConnector",
    "SharePointConnector",
    "BlobStorageConnector",
    # Indexing Builders
    "BaseIndexBuilder",
    "AzureAISearchIndexBuilder",
    "CosmosDBIndexBuilder",
    "MongoDBIndexBuilder",
    # Context Processing
    "RelevanceFilter",
    "ContextDeduplicator",
    "ContextReranker",
    "ContextCompressor",
    "ContextWindowManager",
    "WindowedContext",
    # Query Optimization
    "QueryDecomposer",
    "DecomposedQuery",
    "QueryRouter",
    "RouteDecision",
    "QueryExpander",
    "ExpandedQuery",
    "QueryRewriter",
    "RewrittenQuery",
    "HyDEGenerator",
    "HypotheticalDocument",
    # Document Intelligence Layout
    "DocumentIntelligenceLayout",
    "LayoutResult",
]
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.

    All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should
    inherit from this class and implement its abstract methods.
    """

    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is empty or None
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        This method should be more efficient than calling embed_text() repeatedly
        as it can leverage the provider's batch API capabilities.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is empty or contains invalid entries
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            Integer dimension of the embedding space
            (e.g., 3072 for text-embedding-3-large at its default size)
        """
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get the name/identifier of the embedding model being used.

        Returns:
            String identifier for the model (e.g., "text-embedding-3-large")
        """
        pass

    def validate_text(self, text: str) -> None:
        """
        Validate input text before embedding.

        Args:
            text: The text to validate

        Raises:
            ValueError: If text is None, not a string, or empty/whitespace only
        """
        if text is None:
            raise ValueError("Text cannot be None")
        if not isinstance(text, str):
            raise ValueError(f"Text must be a string, got {type(text)}")
        if not text.strip():
            raise ValueError("Text cannot be empty or whitespace only")

    def validate_batch(self, texts: List[str]) -> None:
        """
        Validate a batch of texts before embedding.

        Args:
            texts: List of texts to validate

        Raises:
            ValueError: If texts is invalid or contains invalid entries
        """
        if texts is None:
            raise ValueError("Texts list cannot be None")
        if not isinstance(texts, list):
            raise ValueError(f"Texts must be a list, got {type(texts)}")
        if len(texts) == 0:
            raise ValueError("Texts list cannot be empty")

        for i, text in enumerate(texts):
            try:
                self.validate_text(text)
            except ValueError as e:
                # Re-raise with the offending index, keeping the original cause
                raise ValueError(f"Invalid text at index {i}: {str(e)}") from e
Abstract base class for embedding providers.
All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should inherit from this class and implement its abstract methods.
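As a rough illustration of what implementing the four abstract methods involves, here is a minimal sketch of an offline provider. The `EmbeddingProvider` class below is a reduced stand-in so the example runs without the package installed, and `HashEmbeddings` and the model name `sha256-demo` are hypothetical names invented for this example. Its vectors are deterministic but carry no semantic meaning, so it is only suitable for wiring tests.

```python
import hashlib
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProvider(ABC):
    """Stand-in for the package's base class, reduced to its abstract methods."""

    @abstractmethod
    def embed_text(self, text: str) -> List[float]: ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]: ...

    @abstractmethod
    def get_embedding_dimension(self) -> int: ...

    @abstractmethod
    def get_model_name(self) -> str: ...


class HashEmbeddings(EmbeddingProvider):
    """Hypothetical offline provider: maps text to a deterministic vector."""

    def __init__(self, dimension: int = 8) -> None:
        # A SHA-256 digest has 32 bytes, so that is the largest size supported.
        if not 1 <= dimension <= 32:
            raise ValueError("dimension must be between 1 and 32")
        self._dimension = dimension

    def embed_text(self, text: str) -> List[float]:
        if not isinstance(text, str) or not text.strip():
            raise ValueError("Text must be a non-empty string")
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Scale the first `dimension` bytes of the digest into [0, 1].
        return [byte / 255.0 for byte in digest[: self._dimension]]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        if not isinstance(texts, list) or not texts:
            raise ValueError("Texts must be a non-empty list")
        return [self.embed_text(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        return self._dimension

    def get_model_name(self) -> str:
        return "sha256-demo"
```

A subclass that omits any of the four methods cannot be instantiated, which is the main guarantee the abstract base class gives to code that accepts an `EmbeddingProvider`.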
    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is empty or None
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass
Generate embeddings for a single text string.
Args: text: The input text to embed
Returns: A list of floats representing the embedding vector
Raises:
    ValueError: If text is empty or None
    Exception: Provider-specific errors (rate limits, auth, etc.)
    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        This method should be more efficient than calling embed_text() repeatedly
        as it can leverage the provider's batch API capabilities.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is empty or contains invalid entries
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass
Generate embeddings for multiple texts in a batch.
This method should be more efficient than calling embed_text() repeatedly as it can leverage the provider's batch API capabilities.
Args: texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
    ValueError: If texts is empty or contains invalid entries
    Exception: Provider-specific errors (rate limits, auth, etc.)
    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            Integer dimension of the embedding space
            (e.g., 3072 for text-embedding-3-large at its default size)
        """
        pass
Get the dimensionality of the embedding vectors.
Returns: Integer dimension of the embedding space (e.g., 3072 for text-embedding-3-large at its default size)
    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get the name/identifier of the embedding model being used.

        Returns:
            String identifier for the model (e.g., "text-embedding-3-large")
        """
        pass
Get the name/identifier of the embedding model being used.
Returns: String identifier for the model (e.g., "text-embedding-3-large")
    def validate_text(self, text: str) -> None:
        """
        Validate input text before embedding.

        Args:
            text: The text to validate

        Raises:
            ValueError: If text is None, empty, or not a string
        """
        if text is None:
            raise ValueError("Text cannot be None")
        if not isinstance(text, str):
            raise ValueError(f"Text must be a string, got {type(text)}")
        if not text.strip():
            raise ValueError("Text cannot be empty or whitespace only")
Validate input text before embedding.
Args: text: The text to validate
Raises: ValueError: If text is None, not a string, or empty/whitespace only
    def validate_batch(self, texts: List[str]) -> None:
        """
        Validate a batch of texts before embedding.

        Args:
            texts: List of texts to validate

        Raises:
            ValueError: If texts is invalid or contains invalid entries
        """
        if texts is None:
            raise ValueError("Texts list cannot be None")
        if not isinstance(texts, list):
            raise ValueError(f"Texts must be a list, got {type(texts)}")
        if len(texts) == 0:
            raise ValueError("Texts list cannot be empty")

        for i, text in enumerate(texts):
            try:
                self.validate_text(text)
            except ValueError as e:
                # Re-raise with the offending index, keeping the original cause
                raise ValueError(f"Invalid text at index {i}: {str(e)}") from e
Validate a batch of texts before embedding.
Args: texts: List of texts to validate
Raises: ValueError: If texts is invalid or contains invalid entries
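To make the error reporting concrete, the sketch below restates the two validators as free functions and shows the message a caller sees when one entry in a batch is bad. The function `first_error` is a hypothetical helper added for the demonstration, and the `from e` chaining is a small addition that is not in the listing above.

```python
from typing import Any, List


def validate_text(text: Any) -> None:
    """Reject None, non-strings, and empty or whitespace-only strings."""
    if text is None:
        raise ValueError("Text cannot be None")
    if not isinstance(text, str):
        raise ValueError(f"Text must be a string, got {type(text)}")
    if not text.strip():
        raise ValueError("Text cannot be empty or whitespace only")


def validate_batch(texts: Any) -> None:
    """Validate the container first, then each entry, reporting the bad index."""
    if texts is None:
        raise ValueError("Texts list cannot be None")
    if not isinstance(texts, list):
        raise ValueError(f"Texts must be a list, got {type(texts)}")
    if len(texts) == 0:
        raise ValueError("Texts list cannot be empty")
    for i, text in enumerate(texts):
        try:
            validate_text(text)
        except ValueError as e:
            raise ValueError(f"Invalid text at index {i}: {e}") from e


def first_error(texts: Any) -> str:
    """Hypothetical helper: return the validation message, or '' if valid."""
    try:
        validate_batch(texts)
    except ValueError as e:
        return str(e)
    return ""
```

Validation stops at the first invalid entry, so a batch with several bad texts reports only the lowest index.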
import os
import time
from typing import Callable, List, Optional

import httpx
from openai import AzureOpenAI
from openai.types import CreateEmbeddingResponse

# EmbeddingProvider and BasicLogger are imported from elsewhere in the
# codebase; their module paths are not shown in this listing.


class AzureOpenAIEmbeddings(EmbeddingProvider):
    """
    Azure OpenAI embedding provider using text-embedding-3-large.

    This implementation provides:
    - SSL certificate handling for internal Azure endpoints
    - Automatic retry logic with exponential backoff
    - Optional logging via shared-core BasicLogger
    - Simple developer-friendly API

    Example (API key):
        >>> embedder = AzureOpenAIEmbeddings(
        ...     endpoint="https://your-resource.openai.azure.com",
        ...     api_key="your-api-key",
        ...     deployment_name="text-embedding-3-large"
        ... )
        >>> vector = embedder.embed_text("Hello world")
        >>> vectors = embedder.embed_batch(["Text 1", "Text 2"])

    Example (managed identity / token provider):
        >>> from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        >>> token_provider = get_bearer_token_provider(
        ...     DefaultAzureCredential(),
        ...     "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
        ... )
        >>> embedder = AzureOpenAIEmbeddings(
        ...     endpoint="https://your-resource.openai.azure.com",
        ...     deployment_name="text-embedding-3-large",
        ...     token_provider=token_provider
        ... )
    """

    def __init__(
        self,
        endpoint: str,
        deployment_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        model: str = "text-embedding-3-large",
        api_version: str = "2024-02-01",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        ssl_cert_path: Optional[str] = None,
        logger: Optional['BasicLogger'] = None,
    ):
        """
        Initialize Azure OpenAI embeddings provider.

        At least one of ``api_key`` or ``token_provider`` must be supplied.
        If both are given, ``token_provider`` takes precedence.

        Args:
            endpoint: Azure OpenAI endpoint URL
            deployment_name: Name of the deployment in Azure
            api_key: Azure OpenAI API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Cognitive Services** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://cognitiveservices.azure.com/.default"
                    )

                Note: this scope is different from Azure AI Search
                (``https://search.azure.com/.default``); each service
                requires its own token_provider.
            model: Model identifier (default: text-embedding-3-large)
            api_version: Azure OpenAI API version (default: 2024-02-01)
            max_retries: Maximum number of attempts, including the first (default: 3)
            retry_delay: Initial delay between retries in seconds (default: 1.0)
            ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
                If None, uses default SSL verification.
            logger: Optional BasicLogger instance for observability
        """
        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureOpenAIEmbeddings."
            )
        self.endpoint = endpoint
        self.deployment_name = deployment_name
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.logger = logger
        self.ssl_cert_path = ssl_cert_path

        # Log initialization details
        if self.logger:
            self.logger.info(
                "Initializing AzureOpenAIEmbeddings",
                endpoint=endpoint,
                deployment=deployment_name,
                model=model,
                ssl_cert_provided=ssl_cert_path is not None
            )

        # Handle SSL certificate if provided (matches shared-core pattern)
        http_client = None
        if ssl_cert_path and os.path.exists(ssl_cert_path):
            try:
                # Use httpx Client with verify parameter (same as shared-core AzureOpenAIProvider)
                http_client = httpx.Client(verify=str(ssl_cert_path), timeout=30.0)

                if self.logger:
                    self.logger.info(f"✓ Using SSL certificate: {ssl_cert_path}")
            except Exception as e:
                if self.logger:
                    self.logger.warning(f"Failed to configure SSL certificate: {e}")
        elif ssl_cert_path:
            if self.logger:
                self.logger.warning(f"SSL certificate path provided but file not found: {ssl_cert_path}")
        else:
            if self.logger:
                self.logger.debug("No SSL certificate provided, using default SSL verification")

        # Initialize Azure OpenAI client
        if token_provider:
            self.client = AzureOpenAI(
                azure_ad_token_provider=token_provider,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )
        else:
            self.client = AzureOpenAI(
                api_key=api_key,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )

        # Dimension is detected lazily from the first API response
        self._dimension: Optional[int] = None

        if self.logger:
            self.logger.info(
                f"✓ Initialized AzureOpenAIEmbeddings with model={model}, "
                f"deployment={deployment_name}"
            )

    def _call_with_retry(self, func, *args, **kwargs):
        """
        Execute a function with exponential backoff retry logic.

        Args:
            func: Function to call
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Result of the function call

        Raises:
            Exception: If all retries are exhausted
        """
        last_exception = None
        delay = self.retry_delay

        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                error_type = type(e).__name__
                error_msg = str(e)

                if attempt < self.max_retries - 1:
                    if self.logger:
                        self.logger.warning(
                            f"Attempt {attempt + 1} failed",
                            error_type=error_type,
                            error_message=error_msg,
                            retry_in=f"{delay}s"
                        )
                    time.sleep(delay)
                    delay *= 2  # Exponential backoff
                else:
                    if self.logger:
                        self.logger.error(
                            f"All {self.max_retries} attempts failed",
                            error_type=error_type,
                            error_message=error_msg
                        )

        raise last_exception

    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_text(text)

        if self.logger:
            self.logger.debug(f"Embedding single text (length={len(text)})")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=text,
                model=self.deployment_name,
            )
            return response.data[0].embedding

        embedding = self._call_with_retry(_embed)

        if self._dimension is None:
            self._dimension = len(embedding)

        if self.logger:
            self.logger.debug(f"Generated embedding with dimension={len(embedding)}")

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        Azure OpenAI supports batch requests which are more efficient than
        individual calls. This method sends the whole list in one request.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_batch(texts)

        if self.logger:
            self.logger.info(f"Embedding batch of {len(texts)} texts")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=texts,
                model=self.deployment_name,
            )
            # Ensure results are in the correct order
            sorted_data = sorted(response.data, key=lambda x: x.index)
            return [item.embedding for item in sorted_data]

        embeddings = self._call_with_retry(_embed)

        if self._dimension is None and embeddings:
            self._dimension = len(embeddings[0])

        if self.logger:
            self.logger.info(
                f"Generated {len(embeddings)} embeddings successfully"
            )

        return embeddings

    def get_embedding_dimension(self) -> Optional[int]:
        """
        Get the dimensionality of the embedding vectors.

        Returns the actual dimension observed from the first API response.
        Returns None if no embedding has been generated yet.
        """
        return self._dimension

    def get_model_name(self) -> str:
        """
        Get the name of the embedding model.

        Returns:
            The model identifier
        """
        return self.model
Azure OpenAI embedding provider using text-embedding-3-large.
This implementation provides:
- SSL certificate handling for internal Azure endpoints
- Automatic retry logic with exponential backoff
- Optional logging via shared-core BasicLogger
- Simple developer-friendly API
Example (API key):
embedder = AzureOpenAIEmbeddings(
    endpoint="https://your-resource.openai.azure.com",
    api_key="your-api-key",
    deployment_name="text-embedding-3-large"
)
vector = embedder.embed_text("Hello world")
vectors = embedder.embed_batch(["Text 1", "Text 2"])
Example (managed identity / token provider):
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
)
embedder = AzureOpenAIEmbeddings(
    endpoint="https://your-resource.openai.azure.com",
    deployment_name="text-embedding-3-large",
    token_provider=token_provider
)
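The "automatic retry logic with exponential backoff" listed among the features can be sketched in isolation as follows. This is a simplified restatement of the private `_call_with_retry` method, not a public API. The names `call_with_retry` and `make_flaky` and the injectable `sleep` parameter are assumptions added here so the behaviour can be observed without waiting. As in the class, `max_retries` counts total attempts, so the default of 3 means one initial call plus at most two retries.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retry(
    func: Callable[[], T],
    max_retries: int = 3,
    retry_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func up to max_retries times, doubling the delay after each failure."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    delay = retry_delay
    for attempt in range(max_retries):
        try:
            return func()
        except Exception:
            if attempt == max_retries - 1:
                raise  # attempts exhausted: surface the last error unchanged
            sleep(delay)
            delay *= 2  # exponential backoff: 1s, 2s, 4s, ...
    raise AssertionError("unreachable")


def make_flaky(failures: int) -> Callable[[], str]:
    """Hypothetical test helper: fails `failures` times, then returns 'ok'."""
    state = {"calls": 0}

    def flaky() -> str:
        state["calls"] += 1
        if state["calls"] <= failures:
            raise RuntimeError("transient failure")
        return "ok"

    return flaky
```

Like the method it mirrors, this sketch retries on every exception type, including ones that retrying will not fix, such as authentication failures. Narrowing the `except` clause to transient errors would be a reasonable refinement.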
    def __init__(
        self,
        endpoint: str,
        deployment_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        model: str = "text-embedding-3-large",
        api_version: str = "2024-02-01",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        ssl_cert_path: Optional[str] = None,
        logger: Optional['BasicLogger'] = None,
    ):
        """
        Initialize Azure OpenAI embeddings provider.

        At least one of ``api_key`` or ``token_provider`` must be supplied.
        If both are given, ``token_provider`` takes precedence.

        Args:
            endpoint: Azure OpenAI endpoint URL
            deployment_name: Name of the deployment in Azure
            api_key: Azure OpenAI API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Cognitive Services** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://cognitiveservices.azure.com/.default"
                    )

                Note: this scope is different from Azure AI Search
                (``https://search.azure.com/.default``); each service
                requires its own token_provider.
            model: Model identifier (default: text-embedding-3-large)
            api_version: Azure OpenAI API version (default: 2024-02-01)
            max_retries: Maximum number of attempts, including the first (default: 3)
            retry_delay: Initial delay between retries in seconds (default: 1.0)
            ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
                If None, uses default SSL verification.
            logger: Optional BasicLogger instance for observability
        """
        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureOpenAIEmbeddings."
            )
        self.endpoint = endpoint
        self.deployment_name = deployment_name
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.logger = logger
        self.ssl_cert_path = ssl_cert_path

        # Log initialization details
        if self.logger:
            self.logger.info(
                "Initializing AzureOpenAIEmbeddings",
                endpoint=endpoint,
                deployment=deployment_name,
                model=model,
                ssl_cert_provided=ssl_cert_path is not None
            )

        # Handle SSL certificate if provided (matches shared-core pattern)
        http_client = None
        if ssl_cert_path and os.path.exists(ssl_cert_path):
            try:
                # Use httpx Client with verify parameter (same as shared-core AzureOpenAIProvider)
                http_client = httpx.Client(verify=str(ssl_cert_path), timeout=30.0)

                if self.logger:
                    self.logger.info(f"✓ Using SSL certificate: {ssl_cert_path}")
            except Exception as e:
                if self.logger:
                    self.logger.warning(f"Failed to configure SSL certificate: {e}")
        elif ssl_cert_path:
            if self.logger:
                self.logger.warning(f"SSL certificate path provided but file not found: {ssl_cert_path}")
        else:
            if self.logger:
                self.logger.debug("No SSL certificate provided, using default SSL verification")

        # Initialize Azure OpenAI client
        if token_provider:
            self.client = AzureOpenAI(
                azure_ad_token_provider=token_provider,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )
        else:
            self.client = AzureOpenAI(
                api_key=api_key,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )

        # Dimension is detected lazily from the first API response
        self._dimension: Optional[int] = None

        if self.logger:
            self.logger.info(
                f"✓ Initialized AzureOpenAIEmbeddings with model={model}, "
                f"deployment={deployment_name}"
            )
Initialize Azure OpenAI embeddings provider.
At least one of api_key or token_provider must be supplied. If both are given, token_provider takes precedence and api_key is ignored.
Args:
    endpoint: Azure OpenAI endpoint URL
    deployment_name: Name of the deployment in Azure
    api_key: Azure OpenAI API key. Use for local development or when managed identity is not available.
    token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Cognitive Services scope:

        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://cognitiveservices.azure.com/.default"
        )

        Note: this scope is different from the Azure AI Search scope (https://search.azure.com/.default); each service requires its own token_provider.
    model: Model identifier (default: text-embedding-3-large)
    api_version: Azure OpenAI API version (default: 2024-02-01)
    max_retries: Maximum number of retry attempts (default: 3)
    retry_delay: Initial delay between retries in seconds (default: 1.0)
    ssl_cert_path: Optional path to SSL certificate bundle for corporate networks. If None, uses default SSL verification.
    logger: Optional BasicLogger instance for observability
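The credential handling in the constructor comes down to a small decision rule, sketched below as a standalone function. `build_auth_kwargs` is a hypothetical helper written for this illustration; the constructor itself does this inline. The keyword names `api_key` and `azure_ad_token_provider` are the ones the listing passes to `AzureOpenAI`.

```python
from typing import Any, Callable, Dict, Optional


def build_auth_kwargs(
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
) -> Dict[str, Any]:
    """Pick the client keyword arguments for whichever credential is supplied."""
    if not api_key and not token_provider:
        raise ValueError("Either api_key or token_provider must be supplied.")
    if token_provider:
        # Checked first, so it wins when both credentials are supplied.
        return {"azure_ad_token_provider": token_provider}
    return {"api_key": api_key}


def fake_token_provider() -> str:
    """Hypothetical stand-in for azure.identity's bearer token provider."""
    return "fake-token"
```

Because the token provider is checked first, supplying both credentials silently drops the API key. Callers who want a hard failure in that case would need their own check before constructing the provider.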
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_text(text)

        if self.logger:
            self.logger.debug(f"Embedding single text (length={len(text)})")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=text,
                model=self.deployment_name,
            )
            return response.data[0].embedding

        embedding = self._call_with_retry(_embed)

        if self._dimension is None:
            self._dimension = len(embedding)

        if self.logger:
            self.logger.debug(f"Generated embedding with dimension={len(embedding)}")

        return embedding
Generate embeddings for a single text string.
Args: text: The input text to embed
Returns: A list of floats representing the embedding vector
Raises:
    ValueError: If text is invalid
    Exception: Azure OpenAI API errors
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        Azure OpenAI supports batch requests which are more efficient than
        individual calls. This method automatically handles the batch API.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_batch(texts)

        if self.logger:
            self.logger.info(f"Embedding batch of {len(texts)} texts")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=texts,
                model=self.deployment_name,
            )
            # Ensure results are in the correct order
            sorted_data = sorted(response.data, key=lambda x: x.index)
            return [item.embedding for item in sorted_data]

        embeddings = self._call_with_retry(_embed)

        if self._dimension is None and embeddings:
            self._dimension = len(embeddings[0])

        if self.logger:
            self.logger.info(
                f"Generated {len(embeddings)} embeddings successfully"
            )

        return embeddings
Generate embeddings for multiple texts in a batch.
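Azure OpenAI supports batch requests, which are more efficient than individual calls. This method sends the whole list in a single request and sorts the returned items by their index field, so the output order matches the input order.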
Args: texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
    ValueError: If texts is invalid
    Exception: Azure OpenAI API errors
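The ordering step inside `embed_batch` can be shown on its own. In the sketch below, `EmbeddingItem` is a hypothetical stand-in for the items in the API response's `data` list, reduced to the two fields the method reads (`index` and `embedding`), and `embeddings_in_input_order` is a name chosen for this example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EmbeddingItem:
    """Stand-in for one response item: its input position and its vector."""
    index: int
    embedding: List[float]


def embeddings_in_input_order(items: List[EmbeddingItem]) -> List[List[float]]:
    """Sort response items by index so vectors line up with the input texts."""
    return [item.embedding for item in sorted(items, key=lambda item: item.index)]


# Items arriving out of order, as a defensive example.
shuffled = [
    EmbeddingItem(index=2, embedding=[0.3]),
    EmbeddingItem(index=0, embedding=[0.1]),
    EmbeddingItem(index=1, embedding=[0.2]),
]
ordered = embeddings_in_input_order(shuffled)
```

Sorting by `index` is a cheap safeguard: callers typically zip the returned vectors with their input texts, and a misaligned pair would go unnoticed.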
    def get_embedding_dimension(self) -> Optional[int]:
        """
        Get the dimensionality of the embedding vectors.

        Returns the actual dimension observed from the first API response.
        Returns None if no embedding has been generated yet.
        """
        return self._dimension
Get the dimensionality of the embedding vectors.
Returns the actual dimension observed from the first API response. Returns None if no embedding has been generated yet.
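Because the dimension is only known after the first embedding call, code that needs it up front (for example, to size a vector index) has to handle the `None` case. The sketch below illustrates the lazy-detection pattern with a hypothetical `LazyDimension` class. The `require` method is an addition for this example and is not part of the provider.

```python
from typing import List, Optional


class LazyDimension:
    """Records the vector size the first time an embedding is observed."""

    def __init__(self) -> None:
        self._dimension: Optional[int] = None

    def observe(self, embedding: List[float]) -> None:
        # Only the first observation is recorded, mirroring the provider.
        if self._dimension is None:
            self._dimension = len(embedding)

    def get(self) -> Optional[int]:
        """Return the observed dimension, or None before any embedding."""
        return self._dimension

    def require(self) -> int:
        """Return the dimension, or fail loudly if none has been observed."""
        if self._dimension is None:
            raise RuntimeError("Dimension unknown: embed at least one text first")
        return self._dimension
```

Note that this method's return type is `Optional[int]`, whereas the abstract base class declares `int`, so callers typed against `EmbeddingProvider` may not expect `None`.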
from typing import Callable, List, Optional

# EmbeddingProvider and BasicLogger are imported from elsewhere in the
# codebase; their module paths are not shown in this listing.


class BatchEmbeddings(EmbeddingProvider):
    """
    Wrapper for efficient batch processing of embeddings.

    This class wraps any EmbeddingProvider and adds:
    - Automatic batching with configurable batch size
    - Progress tracking for large document collections
    - Per-batch retries when the wrapped provider implements them
      (this wrapper adds no retry logic of its own)
    - Memory-efficient processing

    Example:
        >>> base_embedder = AzureOpenAIEmbeddings(...)
        >>> batch_embedder = BatchEmbeddings(
        ...     provider=base_embedder,
        ...     batch_size=100,
        ...     show_progress=True
        ... )
        >>> # Process 10,000 documents efficiently
        >>> embeddings = batch_embedder.embed_batch(large_text_list)
    """

    def __init__(
        self,
        provider: EmbeddingProvider,
        batch_size: int = 100,
        show_progress: bool = True,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        logger: Optional['BasicLogger'] = None,
    ):
        """
        Initialize batch embedding wrapper.

        Args:
            provider: The underlying EmbeddingProvider to wrap
            batch_size: Number of texts to process per batch (default: 100)
            show_progress: Whether to log progress messages through ``logger``
                (default: True); has no effect when no logger is supplied
            progress_callback: Optional callback function(current, total) for
                custom progress tracking
            logger: Optional BasicLogger instance for observability
        """
        self.provider = provider
        self.batch_size = batch_size
        self.show_progress = show_progress
        self.progress_callback = progress_callback
        self.logger = logger

        if self.logger:
            self.logger.info(
                f"Initialized BatchEmbeddings wrapper with batch_size={batch_size}"
            )

    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text (delegates to wrapped provider).

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector
        """
        return self.provider.embed_text(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts with automatic batching.

        This method splits large text lists into smaller batches and processes
        them sequentially to avoid hitting provider limits and manage memory.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is invalid
            Exception: Provider-specific errors
        """
        self.validate_batch(texts)

        total_texts = len(texts)

        if self.logger:
            self.logger.info(
                f"Processing {total_texts} texts in batches of {self.batch_size}"
            )

        # If batch size is larger than input, just use the provider directly
        if total_texts <= self.batch_size:
            if self.logger:
                self.logger.debug("Batch size larger than input, processing in one call")
            return self.provider.embed_batch(texts)

        # Process in batches
        all_embeddings = []
        num_batches = (total_texts + self.batch_size - 1) // self.batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min(start_idx + self.batch_size, total_texts)
            batch_texts = texts[start_idx:end_idx]

            if self.show_progress and self.logger:
                progress = (batch_idx + 1) / num_batches * 100
                self.logger.info(
                    f"Processing batch {batch_idx + 1}/{num_batches}",
                    progress_pct=round(progress, 1),
                    texts_done=end_idx,
                    texts_total=total_texts,
                )

            if self.progress_callback:
                self.progress_callback(end_idx, total_texts)

            if self.logger:
                self.logger.debug(
                    f"Processing batch {batch_idx + 1}/{num_batches}: "
                    f"texts[{start_idx}:{end_idx}]"
                )

            # Embed the batch
            batch_embeddings = self.provider.embed_batch(batch_texts)
            all_embeddings.extend(batch_embeddings)

        if self.show_progress and self.logger:
            self.logger.info(f"Completed embedding {total_texts} texts")

        if self.logger:
            self.logger.info(
                f"Successfully processed all {total_texts} texts in {num_batches} batches"
            )

        return all_embeddings

    def get_embedding_dimension(self) -> int:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            The dimension from the wrapped provider
        """
        return self.provider.get_embedding_dimension()

    def get_model_name(self) -> str:
        """
        Get the name of the embedding model.

        Returns:
            The model name from the wrapped provider
        """
        return self.provider.get_model_name()

    def set_batch_size(self, batch_size: int) -> None:
        """
        Update the batch size for processing.

        Args:
            batch_size: New batch size to use

        Raises:
            ValueError: If batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_size}")

        self.batch_size = batch_size

        if self.logger:
            self.logger.info(f"Updated batch size to {batch_size}")

    def get_batch_size(self) -> int:
        """
        Get the current batch size.

        Returns:
            Current batch size setting
        """
        return self.batch_size
Wrapper for efficient batch processing of embeddings.
This class wraps any EmbeddingProvider and adds:
- Automatic batching with configurable batch size
- Progress tracking for large document collections
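- One provider call per batch (retries, if any, are left to the wrapped provider)
- Bounded request size: only one batch is sent to the provider at a time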
Example:
    base_embedder = AzureOpenAIEmbeddings(...)
    batch_embedder = BatchEmbeddings(
        provider=base_embedder,
        batch_size=100,
        show_progress=True
    )
    # Process 10,000 documents efficiently
    embeddings = batch_embedder.embed_batch(large_text_list)
def __init__(
    self,
    provider: EmbeddingProvider,
    batch_size: int = 100,
    show_progress: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    logger: Optional['BasicLogger'] = None,
)
Initialize batch embedding wrapper.
Args:
    provider: The underlying EmbeddingProvider to wrap
    batch_size: Number of texts to process per batch (default: 100)
    show_progress: Whether to log progress messages through logger; has no effect when no logger is set (default: True)
    progress_callback: Optional callback function(current, total) for custom progress tracking
    logger: Optional BasicLogger instance for observability
def embed_text(self, text: str) -> List[float]
Generate embeddings for a single text (delegates to wrapped provider).
Args: text: The input text to embed
Returns: A list of floats representing the embedding vector
def embed_batch(self, texts: List[str]) -> List[List[float]]
Generate embeddings for multiple texts with automatic batching.
This method splits large text lists into smaller batches and processes them sequentially to avoid hitting provider limits and manage memory.
Args: texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
    ValueError: If texts is invalid
    Exception: Provider-specific errors
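The batching arithmetic described for `embed_batch` can be tried in isolation. The following is a minimal sketch, not the package's code: `FakeProvider` and `embed_in_batches` are hypothetical stand-ins introduced only for illustration, and the dummy "embedding" is just the text length. Unlike the source, this sketch invokes the callback after each batch has been embedded.

```python
from typing import Callable, List, Optional, Tuple


class FakeProvider:
    """Hypothetical stand-in for an EmbeddingProvider; records each call size."""

    def __init__(self) -> None:
        self.call_sizes: List[int] = []

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        self.call_sizes.append(len(texts))
        # One-dimensional dummy "embedding": the length of the text.
        return [[float(len(t))] for t in texts]


def embed_in_batches(
    provider: FakeProvider,
    texts: List[str],
    batch_size: int,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[List[float]]:
    """Send consecutive slices of at most batch_size texts to the provider."""
    if batch_size <= 0:
        raise ValueError(f"Batch size must be positive, got {batch_size}")
    total = len(texts)
    num_batches = (total + batch_size - 1) // batch_size  # ceiling division
    results: List[List[float]] = []
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, total)
        results.extend(provider.embed_batch(texts[start:end]))
        if progress_callback:
            progress_callback(end, total)
    return results


provider = FakeProvider()
progress: List[Tuple[int, int]] = []
texts = [f"doc {i}" for i in range(250)]
vectors = embed_in_batches(
    provider,
    texts,
    batch_size=100,
    progress_callback=lambda done, total: progress.append((done, total)),
)
print(provider.call_sizes)  # [100, 100, 50]
print(len(vectors))         # 250
```

With 250 texts and a batch size of 100, the ceiling division yields three provider calls, and the last one carries the 50-text remainder.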
def get_embedding_dimension(self) -> int
Get the dimensionality of the embedding vectors.
Returns: The dimension from the wrapped provider
def get_model_name(self) -> str
Get the name of the embedding model.
Returns: The model name from the wrapped provider
def set_batch_size(self, batch_size: int) -> None
Update the batch size for processing.
Args: batch_size: New batch size to use
Raises: ValueError: If batch_size is not positive
@dataclass
class Chunk:
    """
    Represents a chunk of text with associated metadata.

    Attributes:
        text: The actual text content of the chunk
        metadata: Dictionary of metadata (source, page number, etc.)
        start_pos: Character position where chunk starts in original text
        end_pos: Character position where chunk ends in original text
        chunk_id: Unique identifier for the chunk
    """
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    start_pos: int = 0
    end_pos: int = 0
    chunk_id: str = ""

    def __post_init__(self):
        """Validate chunk after initialization."""
        if not self.text:
            raise ValueError("Chunk text cannot be empty")
        if self.end_pos < self.start_pos:
            raise ValueError(f"end_pos ({self.end_pos}) cannot be less than start_pos ({self.start_pos})")

    def __len__(self) -> int:
        """Return the length of the chunk text."""
        return len(self.text)

    def __str__(self) -> str:
        """Return a string representation of the chunk."""
        preview = self.text[:50] + "..." if len(self.text) > 50 else self.text
        return f"Chunk(id={self.chunk_id}, len={len(self.text)}, text='{preview}')"
Represents a chunk of text with associated metadata.
Attributes:
    text: The actual text content of the chunk
    metadata: Dictionary of metadata (source, page number, etc.)
    start_pos: Character position where chunk starts in original text
    end_pos: Character position where chunk ends in original text
    chunk_id: Unique identifier for the chunk
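A short sketch may help show how these attributes fit together. The `Chunk` below is a simplified stand-in that mirrors the documented fields and validation so the snippet runs on its own; the sample text and the `demo.txt` source name are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Chunk:
    """Simplified stand-in mirroring the documented fields and checks."""
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    start_pos: int = 0
    end_pos: int = 0
    chunk_id: str = ""

    def __post_init__(self) -> None:
        if not self.text:
            raise ValueError("Chunk text cannot be empty")
        if self.end_pos < self.start_pos:
            raise ValueError("end_pos cannot be less than start_pos")


source = "Alpha beta. Gamma delta."
chunk = Chunk(
    text=source[12:24],
    metadata={"source": "demo.txt"},
    start_pos=12,
    end_pos=24,
    chunk_id="demo.txt_chunk_1",
)

# When the offsets are exact, slicing the source reproduces the chunk text.
print(source[chunk.start_pos:chunk.end_pos] == chunk.text)  # True

# Both validation rules reject bad input at construction time.
errors: List[str] = []
for kwargs in ({"text": ""}, {"text": "x", "start_pos": 5, "end_pos": 2}):
    try:
        Chunk(**kwargs)
    except ValueError as exc:
        errors.append(str(exc))
print(errors)
```

Treating `start_pos` and `end_pos` as a half-open slice into the original text is an assumption of this sketch; the docstring only says they are character positions.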
class BaseChunker(ABC):
    """
    Abstract base class for all text chunking strategies.

    All chunking implementations must inherit from this class and implement
    the chunk() method.
    """

    def __init__(self, logger=None):
        """
        Initialize the base chunker.

        Args:
            logger: Optional BasicLogger instance for structured logging
        """
        self.logger = logger

    @abstractmethod
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split text into chunks according to the chunking strategy.

        Args:
            text: The input text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is empty or None
        """
        pass

    def validate_text(self, text: str) -> None:
        """
        Validate input text before chunking.

        Args:
            text: The text to validate

        Raises:
            ValueError: If text is None, empty, or not a string
        """
        if text is None:
            raise ValueError("Text cannot be None")
        if not isinstance(text, str):
            raise ValueError(f"Text must be a string, got {type(text)}")
        if not text.strip():
            raise ValueError("Text cannot be empty or whitespace only")

    def _generate_chunk_id(self, index: int, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate a unique chunk ID.

        Args:
            index: The index of the chunk in the sequence
            metadata: Optional metadata that may contain source information

        Returns:
            A unique chunk identifier
        """
        if metadata and "source" in metadata:
            return f"{metadata['source']}_chunk_{index}"
        return f"chunk_{index}"
Abstract base class for all text chunking strategies.
All chunking implementations must inherit from this class and implement the chunk() method.
def __init__(self, logger=None)
Initialize the base chunker.
Args: logger: Optional BasicLogger instance for structured logging
@abstractmethod
def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]
Split text into chunks according to the chunking strategy.
Args:
    text: The input text to chunk
    metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is empty or None
def validate_text(self, text: str) -> None
Validate input text before chunking.
Args: text: The text to validate
Raises: ValueError: If text is None, empty, or not a string
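To illustrate the contract a subclass has to meet, here is a minimal sketch of a custom strategy. `ParagraphChunker` is hypothetical and not part of the package. `Chunk` and `BaseChunker` are redeclared as simplified stand-ins so the snippet runs by itself; in real use they would presumably be imported from `gmf_forge_ai_data`.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Chunk:  # simplified stand-in for the package's Chunk
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    start_pos: int = 0
    end_pos: int = 0
    chunk_id: str = ""


class BaseChunker(ABC):  # simplified stand-in for the package's BaseChunker
    @abstractmethod
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        ...

    def validate_text(self, text: str) -> None:
        if text is None:
            raise ValueError("Text cannot be None")
        if not isinstance(text, str):
            raise ValueError(f"Text must be a string, got {type(text)}")
        if not text.strip():
            raise ValueError("Text cannot be empty or whitespace only")


class ParagraphChunker(BaseChunker):
    """Hypothetical strategy: one chunk per blank-line-separated paragraph."""

    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        self.validate_text(text)
        metadata = metadata or {}
        chunks: List[Chunk] = []
        pos = 0
        for para in text.split("\n\n"):
            start = text.index(para, pos)  # locate the paragraph in the original
            pos = start + len(para)
            if para.strip():  # skip empty or whitespace-only paragraphs
                chunks.append(Chunk(
                    text=para,
                    metadata=dict(metadata),  # copy so chunks don't share state
                    start_pos=start,
                    end_pos=pos,
                    chunk_id=f"chunk_{len(chunks)}",
                ))
        return chunks


doc = "First paragraph.\n\nSecond paragraph."
parts = ParagraphChunker().chunk(doc, {"source": "demo"})
for part in parts:
    print(part.chunk_id, part.start_pos, part.end_pos, repr(part.text))
```

Calling `validate_text` first and copying the metadata per chunk follows the pattern the built-in chunkers use.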
class FixedSizeChunker(BaseChunker):
    """
    Chunks text into fixed-size token-based segments with optional overlap.

    This is one of the most common chunking strategies for LLM applications,
    ensuring chunks don't exceed model context windows or embedding limits.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        encoding_name: str = "cl100k_base",
        logger=None
    ):
        """
        Initialize the fixed-size chunker.

        Args:
            chunk_size: Maximum number of tokens per chunk (default: 512)
            chunk_overlap: Number of tokens to overlap between chunks (default: 50)
            encoding_name: tiktoken encoding name (default: "cl100k_base" for GPT-3.5/4)
                Options: "cl100k_base" (GPT-3.5, GPT-4)
                         "p50k_base" (Codex, GPT-3)
                         "r50k_base" (GPT-2)
            logger: Optional BasicLogger instance

        Raises:
            ValueError: If chunk_size <= 0, chunk_overlap < 0,
                chunk_overlap >= chunk_size, or the encoding cannot be loaded
        """
        super().__init__(logger)

        if chunk_size <= 0:
            raise ValueError("chunk_size must be greater than 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be non-negative")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be less than chunk_size")

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.encoding_name = encoding_name

        try:
            self.encoding = tiktoken.get_encoding(encoding_name)
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            raise ValueError(f"Failed to load tiktoken encoding '{encoding_name}': {e}") from e

        if self.logger:
            self.logger.info(
                "Initialized FixedSizeChunker",
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                encoding=encoding_name
            )

    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split text into fixed-size token-based chunks with overlap.

        Args:
            text: The input text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Chunking text of length {len(text)} characters")

        # Encode text to tokens
        tokens = self.encoding.encode(text)
        total_tokens = len(tokens)

        if self.logger:
            self.logger.debug(f"Text tokenized to {total_tokens} tokens")

        chunks = []
        start_idx = 0
        chunk_index = 0

        while start_idx < total_tokens:
            # Calculate end index for this chunk
            end_idx = min(start_idx + self.chunk_size, total_tokens)

            # Extract token slice
            chunk_tokens = tokens[start_idx:end_idx]

            # Decode tokens back to text
            chunk_text = self.encoding.decode(chunk_tokens)

            # Find character positions in original text
            # This is approximate since token boundaries don't always align with char boundaries
            char_start = len(self.encoding.decode(tokens[:start_idx]))
            char_end = char_start + len(chunk_text)

            # Create chunk
            chunk = Chunk(
                text=chunk_text,
                metadata=metadata.copy(),
                start_pos=char_start,
                end_pos=char_end,
                chunk_id=self._generate_chunk_id(chunk_index, metadata)
            )

            # Add token count to metadata
            chunk.metadata["token_count"] = len(chunk_tokens)
            chunk.metadata["chunking_strategy"] = "fixed_size"

            chunks.append(chunk)
            chunk_index += 1

            # Move start index forward, accounting for overlap
            start_idx += self.chunk_size - self.chunk_overlap

            # Break if we've reached the end
            if end_idx >= total_tokens:
                break

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} chunks",
                total_tokens=total_tokens,
                avg_tokens_per_chunk=total_tokens / len(chunks) if chunks else 0
            )

        return chunks

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text string.

        Args:
            text: The text to count tokens for

        Returns:
            Number of tokens
        """
        return len(self.encoding.encode(text))
Chunks text into fixed-size token-based segments with optional overlap.
This is one of the most common chunking strategies for LLM applications, ensuring chunks don't exceed model context windows or embedding limits.
def __init__(
    self,
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    encoding_name: str = "cl100k_base",
    logger=None
)
Initialize the fixed-size chunker.
Args:
    chunk_size: Maximum number of tokens per chunk (default: 512)
    chunk_overlap: Number of tokens to overlap between chunks (default: 50)
    encoding_name: tiktoken encoding name (default: "cl100k_base" for GPT-3.5/4). Options: "cl100k_base" (GPT-3.5, GPT-4), "p50k_base" (Codex, GPT-3), "r50k_base" (GPT-2)
    logger: Optional BasicLogger instance
Raises: ValueError: If chunk_size <= 0, chunk_overlap < 0, chunk_overlap >= chunk_size, or the tiktoken encoding cannot be loaded
def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]
Split text into fixed-size token-based chunks with overlap.
Args:
    text: The input text to chunk
    metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
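The window arithmetic behind fixed-size chunking with overlap can be sketched without a tokenizer. `token_windows` below is a hypothetical helper, not a function of the package: it works on a token count only and returns the `(start, end)` index pairs that the loop in `chunk()` would slice.

```python
from typing import List, Tuple


def token_windows(total_tokens: int, chunk_size: int, chunk_overlap: int) -> List[Tuple[int, int]]:
    """Return (start, end) token index pairs for fixed-size windows with overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be greater than 0")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    windows: List[Tuple[int, int]] = []
    start = 0
    while start < total_tokens:
        end = min(start + chunk_size, total_tokens)
        windows.append((start, end))
        if end >= total_tokens:
            break  # the last window reached the end of the text
        # Each new window starts chunk_overlap tokens before the previous end.
        start += chunk_size - chunk_overlap
    return windows


print(token_windows(1200, 512, 50))
# [(0, 512), (462, 974), (924, 1200)]
```

Each window advances by `chunk_size - chunk_overlap` tokens, which is why `chunk_overlap` has to be strictly smaller than `chunk_size`: otherwise the start index would never move forward.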
def count_tokens(self, text: str) -> int
Count the number of tokens in a text string.
Args: text: The text to count tokens for
Returns: Number of tokens
class SemanticChunker(BaseChunker):
    """
    Chunks text based on semantic similarity and sentence boundaries.

    This chunker respects sentence boundaries and packs consecutive sentences
    into chunks of at most max_chunk_size characters where possible. The
    similarity_threshold parameter is validated and stored, but the basic
    implementation shown here does not use embeddings to group sentences.

    A sentence_tokenizer callable is required; there is no built-in default.
    nltk.sent_tokenize is the recommended choice.
    """

    def __init__(
        self,
        sentence_tokenizer: Callable[[str], List[str]],
        max_chunk_size: int = 1000,
        min_chunk_size: int = 100,
        similarity_threshold: float = 0.5,
        logger=None
    ):
        """
        Initialize the semantic chunker.

        Args:
            sentence_tokenizer: REQUIRED callable that splits text into sentences.
                Use nltk.sent_tokenize (recommended), spacy, or custom function.
                Example: nltk.sent_tokenize
                Must return List[str] of sentences.
            max_chunk_size: Maximum characters per chunk (default: 1000)
            min_chunk_size: Minimum characters per chunk (default: 100)
                Validated and stored; not enforced by chunk() in this implementation
            similarity_threshold: Threshold for semantic similarity (0.0-1.0, default: 0.5)
                Not used in basic implementation
            logger: Optional BasicLogger instance

        Raises:
            ValueError: If sentence_tokenizer is None, or if the size or
                threshold arguments are out of range
        """
        super().__init__(logger)

        if sentence_tokenizer is None:
            raise ValueError(
                "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
                "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
                "Example: SemanticChunker(sentence_tokenizer=nltk.sent_tokenize)"
            )

        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be greater than 0")
        if min_chunk_size < 0:
            raise ValueError("min_chunk_size must be non-negative")
        if min_chunk_size >= max_chunk_size:
            raise ValueError("min_chunk_size must be less than max_chunk_size")
        if not (0.0 <= similarity_threshold <= 1.0):
            raise ValueError("similarity_threshold must be between 0.0 and 1.0")

        self.sentence_tokenizer = sentence_tokenizer
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = min_chunk_size
        self.similarity_threshold = similarity_threshold

        if self.logger:
            self.logger.info(
                "Initialized SemanticChunker",
                max_chunk_size=max_chunk_size,
                min_chunk_size=min_chunk_size,
                tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom'
            )

    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split text into chunks along sentence boundaries, starting a new chunk
        whenever the next sentence would push the current one past max_chunk_size.

        Args:
            text: The input text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Chunking text of length {len(text)} characters")

        # Split into sentences
        sentences = self._split_into_sentences(text)

        if self.logger:
            self.logger.debug(f"Split text into {len(sentences)} sentences")

        # Group sentences into chunks
        chunks = []
        current_chunk_sentences = []
        current_chunk_size = 0
        chunk_index = 0
        chunk_start = 0    # start offset of the chunk being built
        char_position = 0  # end offset of the last sentence added

        for sentence_text, start, end in sentences:
            sentence_len = len(sentence_text)

            # Check if adding this sentence would exceed max size
            if current_chunk_size + sentence_len > self.max_chunk_size and current_chunk_sentences:
                # Create chunk from accumulated sentences
                chunk_text = " ".join(current_chunk_sentences)

                chunk = Chunk(
                    text=chunk_text,
                    metadata=metadata.copy(),
                    start_pos=chunk_start,
                    end_pos=char_position,
                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
                )
                chunk.metadata["sentence_count"] = len(current_chunk_sentences)
                chunk.metadata["chunking_strategy"] = "semantic"

                chunks.append(chunk)
                chunk_index += 1

                # Start new chunk
                current_chunk_sentences = []
                current_chunk_size = 0

            # Record where a new chunk begins in the original text. Deriving the
            # start as char_position - current_chunk_size is off by one (and -1
            # for the first chunk), because the size counts a trailing space.
            if not current_chunk_sentences:
                chunk_start = start

            # Add sentence to current chunk
            current_chunk_sentences.append(sentence_text)
            current_chunk_size += sentence_len + 1  # +1 for space
            char_position = end

        # Create final chunk if there are remaining sentences
        if current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)

            chunk = Chunk(
                text=chunk_text,
                metadata=metadata.copy(),
                start_pos=chunk_start,
                end_pos=char_position,
                chunk_id=self._generate_chunk_id(chunk_index, metadata)
            )
            chunk.metadata["sentence_count"] = len(current_chunk_sentences)
            chunk.metadata["chunking_strategy"] = "semantic"

            chunks.append(chunk)

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} semantic chunks",
                total_sentences=len(sentences),
                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
            )

        return chunks

    def _split_into_sentences(self, text: str) -> List[tuple]:
        """
        Split text into sentences using provided tokenizer.

        Args:
            text: The text to split

        Returns:
            List of tuples (sentence_text, start_pos, end_pos)

        Raises:
            RuntimeError: If sentence tokenizer fails
        """
        try:
            sentence_texts = self.sentence_tokenizer(text)

            # Calculate positions for each sentence
            sentences = []
            pos = 0
            for sent_text in sentence_texts:
                # Find the sentence in the original text
                idx = text.find(sent_text, pos)
                if idx != -1:
                    start = idx
                    end = idx + len(sent_text)
                    sentences.append((sent_text, start, end))
                    pos = end
                else:
                    # Fallback: estimate position
                    sentences.append((sent_text, pos, pos + len(sent_text)))
                    pos += len(sent_text)

            return sentences if sentences else [(text.strip(), 0, len(text))]

        except Exception as e:
            raise RuntimeError(
                f"Sentence tokenizer failed: {e}\n"
                f"Please ensure your tokenizer function is working correctly."
            ) from e
Chunks text based on semantic similarity and sentence boundaries.
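This chunker respects sentence boundaries and packs consecutive sentences into chunks of at most max_chunk_size characters where possible. The similarity_threshold parameter is validated and stored, but the basic implementation shown here does not use embeddings to group sentences.
A sentence_tokenizer callable is required; there is no built-in default. nltk.sent_tokenize is the recommended choice.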
def __init__(
    self,
    sentence_tokenizer: Callable[[str], List[str]],
    max_chunk_size: int = 1000,
    min_chunk_size: int = 100,
    similarity_threshold: float = 0.5,
    logger=None
)
Initialize the semantic chunker.
Args:
    sentence_tokenizer: REQUIRED callable that splits text into sentences. Use nltk.sent_tokenize (recommended), spacy, or a custom function. Must return List[str] of sentences.
    max_chunk_size: Maximum characters per chunk (default: 1000)
    min_chunk_size: Minimum characters per chunk (default: 100). Validated and stored; not enforced by chunk() in this implementation.
    similarity_threshold: Threshold for semantic similarity (0.0-1.0, default: 0.5). Not used in the basic implementation.
    logger: Optional BasicLogger instance
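Raises: ValueError: If sentence_tokenizer is None, max_chunk_size <= 0, min_chunk_size < 0, min_chunk_size >= max_chunk_size, or similarity_threshold is outside 0.0-1.0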
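Any callable that maps a string to a list of sentence strings satisfies the `sentence_tokenizer` contract. The sketch below shows a naive regex-based one for environments where NLTK is unavailable. `simple_sent_tokenize` is a hypothetical helper, not part of the package, and it will split incorrectly on abbreviations such as "e.g. this".

```python
import re
from typing import List

# Split on whitespace that directly follows sentence-ending punctuation.
_SENTENCE_END = re.compile(r"(?<=[.!?])\s+")


def simple_sent_tokenize(text: str) -> List[str]:
    """Naive splitter: break after '.', '!' or '?' followed by whitespace."""
    return [s for s in _SENTENCE_END.split(text.strip()) if s]


sentences = simple_sent_tokenize("RAG needs chunks. How big? It depends!")
print(sentences)  # ['RAG needs chunks.', 'How big?', 'It depends!']
```

Such a function would be passed as `SemanticChunker(sentence_tokenizer=simple_sent_tokenize)`. Because each returned sentence is an exact substring of the input, the position lookup in `_split_into_sentences` can find it with `str.find`.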
def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]
Split text into chunks based on sentence boundaries and semantic similarity.
Args:
- text: The input text to chunk
- metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
16class RecursiveChunker(BaseChunker): 17 """ 18 Recursively chunks text using a hierarchy of separators. 19 20 This chunker attempts to split at natural boundaries in order of preference: 21 1. Double newlines (paragraphs) 22 2. Single newlines (lines) 23 3. Sentence boundaries 24 4. Word boundaries 25 5. Character boundaries (last resort) 26 27 This preserves document structure while ensuring chunks don't exceed the maximum size. 28 """ 29 30 def __init__( 31 self, 32 chunk_size: int = 1000, 33 chunk_overlap: int = 100, 34 separators: Optional[List[str]] = None, 35 logger=None 36 ): 37 """ 38 Initialize the recursive chunker. 39 40 Args: 41 chunk_size: Target maximum characters per chunk (default: 1000) 42 chunk_overlap: Characters to overlap between chunks (default: 100) 43 separators: List of separators in priority order (default: standard hierarchy) 44 logger: Optional BasicLogger instance 45 """ 46 super().__init__(logger) 47 48 if chunk_size <= 0: 49 raise ValueError("chunk_size must be greater than 0") 50 if chunk_overlap < 0: 51 raise ValueError("chunk_overlap must be non-negative") 52 if chunk_overlap >= chunk_size: 53 raise ValueError("chunk_overlap must be less than chunk_size") 54 55 self.chunk_size = chunk_size 56 self.chunk_overlap = chunk_overlap 57 58 # Default separator hierarchy if not provided 59 if separators is None: 60 self.separators = [ 61 "\n\n", # Paragraphs 62 "\n", # Lines 63 ". ", # Sentences 64 " ", # Words 65 "" # Characters (no separator) 66 ] 67 else: 68 self.separators = separators 69 70 if self.logger: 71 self.logger.info( 72 "Initialized RecursiveChunker", 73 chunk_size=chunk_size, 74 chunk_overlap=chunk_overlap, 75 num_separators=len(self.separators) 76 ) 77 78 def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]: 79 """ 80 Recursively split text into chunks using hierarchical separators. 
81 82 Args: 83 text: The input text to chunk 84 metadata: Optional metadata to attach to each chunk 85 86 Returns: 87 List of Chunk objects 88 89 Raises: 90 ValueError: If text is invalid 91 """ 92 self.validate_text(text) 93 94 if metadata is None: 95 metadata = {} 96 97 if self.logger: 98 self.logger.debug(f"Recursively chunking text of length {len(text)} characters") 99 100 # Perform recursive splitting 101 text_chunks = self._split_text_recursively(text) 102 103 # Convert text chunks to Chunk objects with metadata 104 chunks = [] 105 char_position = 0 106 107 for i, chunk_text in enumerate(text_chunks): 108 chunk = Chunk( 109 text=chunk_text, 110 metadata=metadata.copy(), 111 start_pos=char_position, 112 end_pos=char_position + len(chunk_text), 113 chunk_id=self._generate_chunk_id(i, metadata) 114 ) 115 chunk.metadata["chunking_strategy"] = "recursive" 116 chunks.append(chunk) 117 118 # Update position accounting for overlap 119 char_position += len(chunk_text) - self.chunk_overlap 120 121 if self.logger: 122 self.logger.info( 123 f"Created {len(chunks)} recursive chunks", 124 avg_chunk_size=sum(len(c.text) for c in chunks) / len(chunks) if chunks else 0 125 ) 126 127 return chunks 128 129 def _split_text_recursively(self, text: str) -> List[str]: 130 """ 131 Recursively split text using the separator hierarchy. 132 133 Args: 134 text: Text to split 135 136 Returns: 137 List of text chunks 138 """ 139 return self._split_text(text, self.separators) 140 141 def _split_text(self, text: str, separators: List[str]) -> List[str]: 142 """ 143 Split text using the given separators recursively. 
144 145 Args: 146 text: Text to split 147 separators: Remaining separators to try 148 149 Returns: 150 List of text chunks 151 """ 152 final_chunks = [] 153 154 # Choose separator (last one if list is exhausted) 155 separator = separators[-1] if separators else "" 156 157 # Split by current separator 158 if separator: 159 splits = text.split(separator) 160 else: 161 # No separator: split by characters 162 splits = list(text) 163 164 # Process each split segment 165 current_chunk = [] 166 for split in splits: 167 # Add separator back (except for character-level splitting) 168 if separator and current_chunk: 169 split = separator + split 170 171 # If this split is small enough, accumulate it 172 current_size = sum(len(s) for s in current_chunk) 173 174 if current_size + len(split) <= self.chunk_size: 175 current_chunk.append(split) 176 else: 177 # Current chunk is ready 178 if current_chunk: 179 merged = "".join(current_chunk) if not separator else separator.join(current_chunk) 180 if separator == " ": 181 merged = " ".join(c.strip() for c in current_chunk if c.strip()) 182 183 if merged.strip(): 184 final_chunks.append(merged) 185 current_chunk = [] 186 187 # Check if this split itself needs to be broken down 188 if len(split) > self.chunk_size: 189 if len(separators) > 1: 190 # Try next separator in hierarchy 191 sub_chunks = self._split_text(split, separators[1:]) 192 final_chunks.extend(sub_chunks) 193 else: 194 # Force split at chunk_size boundaries 195 for i in range(0, len(split), self.chunk_size): 196 sub_chunk = split[i:i + self.chunk_size] 197 if sub_chunk.strip(): 198 final_chunks.append(sub_chunk) 199 else: 200 current_chunk.append(split) 201 202 # Add remaining chunk 203 if current_chunk: 204 merged = "".join(current_chunk) if not separator else separator.join(current_chunk) 205 if separator == " ": 206 merged = " ".join(c.strip() for c in current_chunk if c.strip()) 207 208 if merged.strip(): 209 final_chunks.append(merged) 210 211 return final_chunks
Recursively chunks text using a hierarchy of separators.
This chunker attempts to split at natural boundaries in order of preference:
1. Double newlines (paragraphs)
2. Single newlines (lines)
3. Sentence boundaries
4. Word boundaries
5. Character boundaries (last resort)
This preserves document structure while ensuring chunks don't exceed the maximum size.
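The separator hierarchy described above can be sketched as a small standalone function. This is an illustrative sketch, not the package's own `_split_text`: the names `recursive_split` and `DEFAULT_SEPARATORS` are hypothetical, and overlap handling is deliberately left out. It tries the coarsest separator first and only falls back to a finer one for pieces that are still too large.

```python
from typing import List, Optional

# Hypothetical constant mirroring the documented default hierarchy.
DEFAULT_SEPARATORS = ["\n\n", "\n", ". ", " ", ""]


def recursive_split(
    text: str,
    chunk_size: int,
    separators: Optional[List[str]] = None,
) -> List[str]:
    """Split text so no piece exceeds chunk_size, preferring coarse separators."""
    if separators is None:
        separators = DEFAULT_SEPARATORS
    if len(text) <= chunk_size:
        return [text] if text.strip() else []
    if not separators or separators[0] == "":
        # Last resort: hard cut at chunk_size boundaries.
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    sep, finer = separators[0], separators[1:]
    chunks: List[str] = []
    buffer = ""
    for piece in text.split(sep):
        candidate = piece if not buffer else buffer + sep + piece
        if len(candidate) <= chunk_size:
            buffer = candidate  # keep merging small neighbouring pieces
            continue
        if buffer:
            chunks.append(buffer)
        if len(piece) <= chunk_size:
            buffer = piece
        else:
            # Still too large: retry this piece with the next, finer separator.
            chunks.extend(recursive_split(piece, chunk_size, finer))
            buffer = ""
    if buffer:
        chunks.append(buffer)
    return chunks
```

For example, `recursive_split("aaaa bbbb\n\ncccc dddd eeee", 10)` keeps the first paragraph whole and splits only the second one at word boundaries.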
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 100,
        separators: Optional[List[str]] = None,
        logger=None
    ):
        """
        Initialize the recursive chunker.

        Args:
            chunk_size: Target maximum characters per chunk (default: 1000)
            chunk_overlap: Characters to overlap between chunks (default: 100)
            separators: List of separators in priority order (default: standard hierarchy)
            logger: Optional BasicLogger instance
        """
        super().__init__(logger)

        if chunk_size <= 0:
            raise ValueError("chunk_size must be greater than 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be non-negative")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be less than chunk_size")

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Default separator hierarchy if not provided
        if separators is None:
            self.separators = [
                "\n\n",  # Paragraphs
                "\n",    # Lines
                ". ",    # Sentences
                " ",     # Words
                ""       # Characters (no separator)
            ]
        else:
            self.separators = separators

        if self.logger:
            self.logger.info(
                "Initialized RecursiveChunker",
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                num_separators=len(self.separators)
            )
Initialize the recursive chunker.
Args:
- chunk_size: Target maximum characters per chunk (default: 1000)
- chunk_overlap: Characters to overlap between chunks (default: 100)
- separators: List of separators in priority order (default: standard hierarchy)
- logger: Optional BasicLogger instance
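To make the relationship between `chunk_size` and `chunk_overlap` concrete, here is a minimal sketch using plain fixed-size windows. The function `overlapping_windows` is hypothetical and not part of the package. It assumes the same validation rules as the constructor, and shows why the overlap must be strictly smaller than the chunk size: the window advances by `chunk_size - chunk_overlap`, which has to be at least 1.

```python
from typing import List, Tuple


def overlapping_windows(
    text: str, chunk_size: int = 1000, chunk_overlap: int = 100
) -> List[Tuple[int, int, str]]:
    """Return (start, end, text) windows; neighbours share chunk_overlap characters."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be greater than 0")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be non-negative and less than chunk_size")

    step = chunk_size - chunk_overlap  # always >= 1, so the loop makes progress
    windows = []
    for start in range(0, len(text), step):
        end = min(start + chunk_size, len(text))
        windows.append((start, end, text[start:end]))
        if end == len(text):
            break  # the last window already reaches the end of the text
    return windows
```

With `chunk_size=4` and `chunk_overlap=1`, the text `"abcdefghij"` yields windows starting at offsets 0, 3 and 6, each sharing one character with its predecessor.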
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Recursively split text into chunks using hierarchical separators.

        Args:
            text: The input text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Recursively chunking text of length {len(text)} characters")

        # Perform recursive splitting
        text_chunks = self._split_text_recursively(text)

        # Convert text chunks to Chunk objects with metadata
        chunks = []
        char_position = 0

        for i, chunk_text in enumerate(text_chunks):
            chunk = Chunk(
                text=chunk_text,
                metadata=metadata.copy(),
                start_pos=char_position,
                end_pos=char_position + len(chunk_text),
                chunk_id=self._generate_chunk_id(i, metadata)
            )
            chunk.metadata["chunking_strategy"] = "recursive"
            chunks.append(chunk)

            # Update position accounting for overlap
            char_position += len(chunk_text) - self.chunk_overlap

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} recursive chunks",
                avg_chunk_size=sum(len(c.text) for c in chunks) / len(chunks) if chunks else 0
            )

        return chunks
Recursively split text into chunks using hierarchical separators.
Args:
- text: The input text to chunk
- metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
15class SentenceChunker(BaseChunker): 16 """ 17 Chunks text by grouping sentences together. 18 19 This chunker preserves sentence boundaries while grouping multiple 20 sentences into chunks based on character count or sentence count limits. 21 22 By default uses simple regex for sentence detection. For better accuracy, 23 pass a custom sentence_tokenizer (e.g., nltk.sent_tokenize). 24 """ 25 26 def __init__( 27 self, 28 sentence_tokenizer: Callable[[str], List[str]], 29 max_chunk_size: int = 1000, 30 sentences_per_chunk: Optional[int] = None, 31 logger=None 32 ): 33 """ 34 Initialize the sentence chunker. 35 36 Args: 37 sentence_tokenizer: REQUIRED callable that splits text into sentences. 38 Use nltk.sent_tokenize (recommended), spacy, or custom function. 39 Example: nltk.sent_tokenize 40 Must return List[str] of sentences. 41 max_chunk_size: Maximum characters per chunk (default: 1000) 42 sentences_per_chunk: Optional fixed number of sentences per chunk 43 If set, this takes priority over max_chunk_size 44 logger: Optional BasicLogger instance 45 46 Raises: 47 ValueError: If sentence_tokenizer is not provided 48 """ 49 super().__init__(logger) 50 51 if sentence_tokenizer is None: 52 raise ValueError( 53 "sentence_tokenizer is required. 
Please provide a sentence tokenization function.\n" 54 "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n" 55 "Example: SentenceChunker(sentence_tokenizer=nltk.sent_tokenize)" 56 ) 57 58 if max_chunk_size <= 0: 59 raise ValueError("max_chunk_size must be greater than 0") 60 if sentences_per_chunk is not None and sentences_per_chunk <= 0: 61 raise ValueError("sentences_per_chunk must be greater than 0") 62 63 self.sentence_tokenizer = sentence_tokenizer 64 self.max_chunk_size = max_chunk_size 65 self.sentences_per_chunk = sentences_per_chunk 66 67 if self.logger: 68 self.logger.info( 69 "Initialized SentenceChunker", 70 max_chunk_size=max_chunk_size, 71 sentences_per_chunk=sentences_per_chunk, 72 tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom' 73 ) 74 75 def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]: 76 """ 77 Split text into chunks based on sentence boundaries. 78 79 Args: 80 text: The input text to chunk 81 metadata: Optional metadata to attach to each chunk 82 83 Returns: 84 List of Chunk objects 85 86 Raises: 87 ValueError: If text is invalid 88 """ 89 self.validate_text(text) 90 91 if metadata is None: 92 metadata = {} 93 94 if self.logger: 95 self.logger.debug(f"Chunking text of length {len(text)} characters by sentences") 96 97 # Split text into sentences 98 sentences = self._split_sentences(text) 99 100 if self.logger: 101 self.logger.debug(f"Found {len(sentences)} sentences") 102 103 chunks = [] 104 current_sentences = [] 105 current_size = 0 106 chunk_index = 0 107 char_position = 0 108 109 for sentence, start, end in sentences: 110 sentence_len = len(sentence) 111 112 # Check if we should create a new chunk 113 should_chunk = False 114 115 if self.sentences_per_chunk: 116 # Fixed sentence count mode 117 should_chunk = len(current_sentences) >= self.sentences_per_chunk 118 else: 119 # Size-based mode 120 should_chunk = ( 121 current_size + 
sentence_len > self.max_chunk_size 122 and current_sentences # Don't create empty chunks 123 ) 124 125 if should_chunk: 126 # Create chunk from accumulated sentences 127 chunk_text = " ".join(current_sentences) 128 chunk_start = char_position - current_size 129 130 chunk = Chunk( 131 text=chunk_text, 132 metadata=metadata.copy(), 133 start_pos=chunk_start, 134 end_pos=char_position, 135 chunk_id=self._generate_chunk_id(chunk_index, metadata) 136 ) 137 chunk.metadata["sentence_count"] = len(current_sentences) 138 chunk.metadata["chunking_strategy"] = "sentence" 139 140 chunks.append(chunk) 141 chunk_index += 1 142 143 # Reset for new chunk 144 current_sentences = [] 145 current_size = 0 146 147 # Add sentence to current chunk 148 current_sentences.append(sentence) 149 current_size += sentence_len + 1 # +1 for space 150 char_position = end 151 152 # Create final chunk if there are remaining sentences 153 if current_sentences: 154 chunk_text = " ".join(current_sentences) 155 chunk_start = char_position - current_size 156 157 chunk = Chunk( 158 text=chunk_text, 159 metadata=metadata.copy(), 160 start_pos=chunk_start, 161 end_pos=char_position, 162 chunk_id=self._generate_chunk_id(chunk_index, metadata) 163 ) 164 chunk.metadata["sentence_count"] = len(current_sentences) 165 chunk.metadata["chunking_strategy"] = "sentence" 166 167 chunks.append(chunk) 168 169 if self.logger: 170 self.logger.info( 171 f"Created {len(chunks)} sentence-based chunks", 172 total_sentences=len(sentences), 173 avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0 174 ) 175 176 return chunks 177 178 def _split_sentences(self, text: str) -> List[tuple]: 179 """ 180 Split text into sentences with position tracking using provided tokenizer. 
181 182 Args: 183 text: Text to split into sentences 184 185 Returns: 186 List of tuples (sentence_text, start_pos, end_pos) 187 188 Raises: 189 RuntimeError: If sentence tokenizer fails 190 """ 191 try: 192 sentence_texts = self.sentence_tokenizer(text) 193 194 # Calculate positions for each sentence 195 sentences = [] 196 pos = 0 197 for sent_text in sentence_texts: 198 # Find the sentence in the original text 199 idx = text.find(sent_text, pos) 200 if idx != -1: 201 start = idx 202 end = idx + len(sent_text) 203 sentences.append((sent_text, start, end)) 204 pos = end 205 else: 206 # Fallback: estimate position 207 sentences.append((sent_text, pos, pos + len(sent_text))) 208 pos += len(sent_text) 209 210 return sentences if sentences else [(text.strip(), 0, len(text))] 211 212 except Exception as e: 213 raise RuntimeError( 214 f"Sentence tokenizer failed: {e}\n" 215 f"Please ensure your tokenizer function is working correctly." 216 ) from e
Chunks text by grouping sentences together.
This chunker preserves sentence boundaries while grouping multiple sentences into chunks based on character count or sentence count limits.
A sentence tokenizer is required; there is no built-in default. Pass a callable such as nltk.sent_tokenize that takes the text and returns a list of sentence strings.
    def __init__(
        self,
        sentence_tokenizer: Callable[[str], List[str]],
        max_chunk_size: int = 1000,
        sentences_per_chunk: Optional[int] = None,
        logger=None
    ):
        """
        Initialize the sentence chunker.

        Args:
            sentence_tokenizer: REQUIRED callable that splits text into sentences.
                Use nltk.sent_tokenize (recommended), spacy, or custom function.
                Example: nltk.sent_tokenize
                Must return List[str] of sentences.
            max_chunk_size: Maximum characters per chunk (default: 1000)
            sentences_per_chunk: Optional fixed number of sentences per chunk
                If set, this takes priority over max_chunk_size
            logger: Optional BasicLogger instance

        Raises:
            ValueError: If sentence_tokenizer is not provided
        """
        super().__init__(logger)

        if sentence_tokenizer is None:
            raise ValueError(
                "sentence_tokenizer is required. Please provide a sentence tokenization function.\n"
                "Recommended: import nltk; nltk.download('punkt'); use nltk.sent_tokenize\n"
                "Example: SentenceChunker(sentence_tokenizer=nltk.sent_tokenize)"
            )

        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be greater than 0")
        if sentences_per_chunk is not None and sentences_per_chunk <= 0:
            raise ValueError("sentences_per_chunk must be greater than 0")

        self.sentence_tokenizer = sentence_tokenizer
        self.max_chunk_size = max_chunk_size
        self.sentences_per_chunk = sentences_per_chunk

        if self.logger:
            self.logger.info(
                "Initialized SentenceChunker",
                max_chunk_size=max_chunk_size,
                sentences_per_chunk=sentences_per_chunk,
                tokenizer=sentence_tokenizer.__name__ if hasattr(sentence_tokenizer, '__name__') else 'custom'
            )
Initialize the sentence chunker.
Args:
- sentence_tokenizer: Required callable that splits text into sentences. Use nltk.sent_tokenize (recommended), spaCy, or a custom function. Must return a List[str] of sentences.
- max_chunk_size: Maximum characters per chunk (default: 1000)
- sentences_per_chunk: Optional fixed number of sentences per chunk. If set, this takes priority over max_chunk_size.
- logger: Optional BasicLogger instance
Raises: ValueError: If sentence_tokenizer is not provided
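Any callable with the shape `Callable[[str], List[str]]` can serve as the tokenizer. The sketch below is an assumption-laden stand-in for something like nltk.sent_tokenize: `simple_sentence_tokenizer` is a hypothetical, naive regex splitter (it will mis-split abbreviations such as "e.g."), and `locate_sentences` illustrates the forward-search approach for recovering character offsets from the tokenizer's output.

```python
import re
from typing import List, Tuple


def simple_sentence_tokenizer(text: str) -> List[str]:
    """Naive splitter: break on whitespace that follows '.', '!' or '?'."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def locate_sentences(text: str, sentences: List[str]) -> List[Tuple[str, int, int]]:
    """Attach (start, end) offsets by searching forward from the previous match."""
    located = []
    pos = 0
    for sentence in sentences:
        idx = text.find(sentence, pos)
        if idx == -1:
            idx = pos  # tokenizer altered the text; fall back to an estimate
        located.append((sentence, idx, idx + len(sentence)))
        pos = idx + len(sentence)
    return located
```

Searching from the end of the previous match keeps repeated sentences from all resolving to the first occurrence.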
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split text into chunks based on sentence boundaries.

        Args:
            text: The input text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Chunking text of length {len(text)} characters by sentences")

        # Split text into sentences
        sentences = self._split_sentences(text)

        if self.logger:
            self.logger.debug(f"Found {len(sentences)} sentences")

        chunks = []
        current_sentences = []
        current_size = 0
        chunk_index = 0
        chunk_start = 0
        char_position = 0

        for sentence, start, end in sentences:
            sentence_len = len(sentence)

            # Check if we should create a new chunk
            should_chunk = False

            if self.sentences_per_chunk:
                # Fixed sentence count mode
                should_chunk = len(current_sentences) >= self.sentences_per_chunk
            else:
                # Size-based mode
                should_chunk = (
                    current_size + sentence_len > self.max_chunk_size
                    and current_sentences  # Don't create empty chunks
                )

            if should_chunk:
                # Create chunk from accumulated sentences
                chunk_text = " ".join(current_sentences)

                chunk = Chunk(
                    text=chunk_text,
                    metadata=metadata.copy(),
                    start_pos=chunk_start,
                    end_pos=char_position,
                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
                )
                chunk.metadata["sentence_count"] = len(current_sentences)
                chunk.metadata["chunking_strategy"] = "sentence"

                chunks.append(chunk)
                chunk_index += 1

                # Reset for new chunk
                current_sentences = []
                current_size = 0

            # Take the chunk start from the first sentence's own offset.
            # Deriving it from the accumulated size (which counts one joining
            # space per sentence) drifts and can go negative.
            if not current_sentences:
                chunk_start = start

            # Add sentence to current chunk
            current_sentences.append(sentence)
            current_size += sentence_len + 1  # +1 for space
            char_position = end

        # Create final chunk if there are remaining sentences
        if current_sentences:
            chunk_text = " ".join(current_sentences)

            chunk = Chunk(
                text=chunk_text,
                metadata=metadata.copy(),
                start_pos=chunk_start,
                end_pos=char_position,
                chunk_id=self._generate_chunk_id(chunk_index, metadata)
            )
            chunk.metadata["sentence_count"] = len(current_sentences)
            chunk.metadata["chunking_strategy"] = "sentence"

            chunks.append(chunk)

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} sentence-based chunks",
                total_sentences=len(sentences),
                avg_sentences_per_chunk=len(sentences) / len(chunks) if chunks else 0
            )

        return chunks
Split text into chunks based on sentence boundaries.
Args:
- text: The input text to chunk
- metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
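The two grouping modes (fixed sentence count versus character budget) can be illustrated in isolation. This is a simplified sketch with a hypothetical `group_sentences` helper that works on plain strings and returns lists of sentences rather than Chunk objects; it assumes, as documented, that `sentences_per_chunk` takes priority whenever it is set.

```python
from typing import List, Optional


def group_sentences(
    sentences: List[str],
    max_chunk_size: int = 1000,
    sentences_per_chunk: Optional[int] = None,
) -> List[List[str]]:
    """Group sentences by fixed count if given, otherwise by character budget."""
    groups: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for sentence in sentences:
        if sentences_per_chunk:
            # Count mode: the size budget is ignored entirely.
            flush = len(current) >= sentences_per_chunk
        else:
            # Size mode: never flush an empty group, so an oversized
            # sentence still ends up in a group of its own.
            flush = bool(current) and current_size + len(sentence) > max_chunk_size
        if flush:
            groups.append(current)
            current, current_size = [], 0
        current.append(sentence)
        current_size += len(sentence) + 1  # +1 for the joining space
    if current:
        groups.append(current)
    return groups
```

Note that in size mode a single sentence longer than `max_chunk_size` is kept whole, so the limit is a target rather than a hard cap.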
15class MarkdownChunker(BaseChunker): 16 """ 17 Chunks markdown text while respecting document structure. 18 19 This chunker identifies markdown headers (# ## ###, etc.) and uses them 20 as natural boundaries for chunking, preserving the document hierarchy. 21 22 Uses regex-based parsing which works well for standard markdown. For complex 23 markdown with extensions, consider pre-processing with mistune or markdown-it-py. 24 """ 25 26 def __init__( 27 self, 28 max_chunk_size: int = 1500, 29 combine_headers: bool = True, 30 min_header_level: int = 1, 31 logger=None 32 ): 33 """ 34 Initialize the markdown chunker. 35 36 Args: 37 max_chunk_size: Maximum characters per chunk (default: 1500) 38 combine_headers: Whether to combine small sections under headers (default: True) 39 min_header_level: Minimum header level to split at (1-6, default: 1) 40 logger: Optional BasicLogger instance 41 """ 42 super().__init__(logger) 43 44 if max_chunk_size <= 0: 45 raise ValueError("max_chunk_size must be greater than 0") 46 if not (1 <= min_header_level <= 6): 47 raise ValueError("min_header_level must be between 1 and 6") 48 49 self.max_chunk_size = max_chunk_size 50 self.combine_headers = combine_headers 51 self.min_header_level = min_header_level 52 53 # Regex pattern for markdown headers (both ATX and Setext styles) 54 self.header_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE) 55 56 if self.logger: 57 self.logger.info( 58 "Initialized MarkdownChunker", 59 max_chunk_size=max_chunk_size, 60 combine_headers=combine_headers, 61 min_header_level=min_header_level 62 ) 63 64 def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]: 65 """ 66 Split markdown text into chunks respecting header boundaries. 
67 68 Args: 69 text: The markdown text to chunk 70 metadata: Optional metadata to attach to each chunk 71 72 Returns: 73 List of Chunk objects 74 75 Raises: 76 ValueError: If text is invalid 77 """ 78 self.validate_text(text) 79 80 if metadata is None: 81 metadata = {} 82 83 if self.logger: 84 self.logger.debug(f"Chunking markdown text of length {len(text)} characters") 85 86 # Find all headers and their positions 87 sections = self._parse_sections(text) 88 89 if self.logger: 90 self.logger.debug(f"Found {len(sections)} markdown sections") 91 92 # Create chunks from sections 93 chunks = [] 94 current_chunk_parts = [] 95 current_size = 0 96 chunk_index = 0 97 current_headers = [] 98 99 for level, header_text, content, start, end in sections: 100 section_text = content 101 section_size = len(section_text) 102 103 # Check if we should start a new chunk 104 if (current_size + section_size > self.max_chunk_size 105 and current_chunk_parts 106 and level <= self.min_header_level): 107 108 # Create chunk from accumulated sections 109 chunk_text = "\n\n".join(current_chunk_parts) 110 111 chunk = Chunk( 112 text=chunk_text, 113 metadata=metadata.copy(), 114 start_pos=start - current_size, 115 end_pos=start, 116 chunk_id=self._generate_chunk_id(chunk_index, metadata) 117 ) 118 chunk.metadata["headers"] = current_headers.copy() 119 chunk.metadata["chunking_strategy"] = "markdown" 120 121 chunks.append(chunk) 122 chunk_index += 1 123 124 # Reset for new chunk 125 current_chunk_parts = [] 126 current_size = 0 127 current_headers = [] 128 129 # Add section to current chunk 130 current_chunk_parts.append(section_text) 131 current_size += section_size + 2 # +2 for \n\n 132 133 # Track header hierarchy 134 if header_text: 135 current_headers.append({ 136 "level": level, 137 "text": header_text 138 }) 139 140 # Create final chunk 141 if current_chunk_parts: 142 chunk_text = "\n\n".join(current_chunk_parts) 143 144 chunk = Chunk( 145 text=chunk_text, 146 metadata=metadata.copy(), 147 
start_pos=len(text) - current_size, 148 end_pos=len(text), 149 chunk_id=self._generate_chunk_id(chunk_index, metadata) 150 ) 151 chunk.metadata["headers"] = current_headers 152 chunk.metadata["chunking_strategy"] = "markdown" 153 154 chunks.append(chunk) 155 156 if self.logger: 157 self.logger.info( 158 f"Created {len(chunks)} markdown chunks", 159 total_sections=len(sections), 160 avg_sections_per_chunk=len(sections) / len(chunks) if chunks else 0 161 ) 162 163 return chunks 164 165 def _parse_sections(self, text: str) -> List[Tuple[int, str, str, int, int]]: 166 """ 167 Parse markdown text into sections based on headers. 168 169 Args: 170 text: Markdown text to parse 171 172 Returns: 173 List of tuples (header_level, header_text, content, start_pos, end_pos) 174 """ 175 sections = [] 176 lines = text.split('\n') 177 current_content = [] 178 current_header_level = 0 179 current_header_text = "" 180 section_start = 0 181 char_position = 0 182 183 for i, line in enumerate(lines): 184 # Check if line is a header 185 header_match = self.header_pattern.match(line) 186 187 if header_match: 188 # Save previous section if it exists 189 if current_content or current_header_text: 190 content = '\n'.join(current_content) 191 sections.append(( 192 current_header_level, 193 current_header_text, 194 content, 195 section_start, 196 char_position 197 )) 198 199 # Start new section 200 current_header_level = len(header_match.group(1)) 201 current_header_text = header_match.group(2).strip() 202 current_content = [line] # Include header in content 203 section_start = char_position 204 else: 205 current_content.append(line) 206 207 char_position += len(line) + 1 # +1 for newline 208 209 # Add final section 210 if current_content: 211 content = '\n'.join(current_content) 212 sections.append(( 213 current_header_level, 214 current_header_text, 215 content, 216 section_start, 217 char_position 218 )) 219 220 # If no sections found, treat entire text as one section 221 if not sections: 
222 sections = [(0, "", text, 0, len(text))] 223 224 return sections
Chunks markdown text while respecting document structure.
This chunker identifies markdown headers (# ## ###, etc.) and uses them as natural boundaries for chunking, preserving the document hierarchy.
Uses regex-based parsing which works well for standard markdown. For complex markdown with extensions, consider pre-processing with mistune or markdown-it-py.
    def __init__(
        self,
        max_chunk_size: int = 1500,
        combine_headers: bool = True,
        min_header_level: int = 1,
        logger=None
    ):
        """
        Initialize the markdown chunker.

        Args:
            max_chunk_size: Maximum characters per chunk (default: 1500)
            combine_headers: Whether to combine small sections under headers (default: True)
            min_header_level: Minimum header level to split at (1-6, default: 1)
            logger: Optional BasicLogger instance
        """
        super().__init__(logger)

        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be greater than 0")
        if not (1 <= min_header_level <= 6):
            raise ValueError("min_header_level must be between 1 and 6")

        self.max_chunk_size = max_chunk_size
        self.combine_headers = combine_headers
        self.min_header_level = min_header_level

        # Regex pattern for ATX-style markdown headers ("# Title" to "###### Title").
        # Setext-style headers (underlined with === or ---) are not matched.
        self.header_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)

        if self.logger:
            self.logger.info(
                "Initialized MarkdownChunker",
                max_chunk_size=max_chunk_size,
                combine_headers=combine_headers,
                min_header_level=min_header_level
            )
Initialize the markdown chunker.
Args:
- max_chunk_size: Maximum characters per chunk (default: 1500)
- combine_headers: Whether to combine small sections under headers (default: True)
- min_header_level: Minimum header level to split at (1-6, default: 1)
- logger: Optional BasicLogger instance
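Header-based sectioning can be sketched with the same regular expression the chunker compiles. The snippet below is a simplified, standalone illustration: `HEADER_RE` and `parse_sections` are hypothetical names, and position tracking is omitted. It recognises ATX headers only, and text before the first header is reported as a level-0 section.

```python
import re
from typing import List, Tuple

# ATX headers only ("# Title" to "###### Title"); Setext underlines are not matched.
HEADER_RE = re.compile(r"^(#{1,6})\s+(.+)$")


def parse_sections(text: str) -> List[Tuple[int, str, str]]:
    """Split markdown into (header_level, header_text, content) sections."""
    sections: List[Tuple[int, str, str]] = []
    level, title, lines = 0, "", []
    for line in text.split("\n"):
        match = HEADER_RE.match(line)
        if match:
            if lines:
                sections.append((level, title, "\n".join(lines)))
            # Start a new section; the header line stays part of its content.
            level, title, lines = len(match.group(1)), match.group(2).strip(), [line]
        else:
            lines.append(line)
    if lines:
        sections.append((level, title, "\n".join(lines)))
    return sections
```

A line-based regex like this also treats `#` comment lines inside fenced code blocks as headers, which is one reason to pre-process complex markdown with a real parser.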
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split markdown text into chunks respecting header boundaries.

        Args:
            text: The markdown text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Chunking markdown text of length {len(text)} characters")

        # Find all headers and their positions
        sections = self._parse_sections(text)

        if self.logger:
            self.logger.debug(f"Found {len(sections)} markdown sections")

        # Create chunks from sections
        chunks = []
        current_chunk_parts = []
        current_size = 0
        chunk_index = 0
        chunk_start = 0
        current_headers = []

        for level, header_text, content, start, end in sections:
            section_text = content
            section_size = len(section_text)

            # Check if we should start a new chunk
            if (current_size + section_size > self.max_chunk_size
                    and current_chunk_parts
                    and level <= self.min_header_level):

                # Create chunk from accumulated sections
                chunk_text = "\n\n".join(current_chunk_parts)

                chunk = Chunk(
                    text=chunk_text,
                    metadata=metadata.copy(),
                    start_pos=chunk_start,
                    end_pos=start,
                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
                )
                chunk.metadata["headers"] = current_headers.copy()
                chunk.metadata["chunking_strategy"] = "markdown"

                chunks.append(chunk)
                chunk_index += 1

                # Reset for new chunk
                current_chunk_parts = []
                current_size = 0
                current_headers = []

            # Take the chunk start from the first section's own offset.
            # Deriving it from the accumulated size (which counts two joining
            # characters per section) drifts from the real position.
            if not current_chunk_parts:
                chunk_start = start

            # Add section to current chunk
            current_chunk_parts.append(section_text)
            current_size += section_size + 2  # +2 for \n\n

            # Track header hierarchy
            if header_text:
                current_headers.append({
                    "level": level,
                    "text": header_text
                })

        # Create final chunk
        if current_chunk_parts:
            chunk_text = "\n\n".join(current_chunk_parts)

            chunk = Chunk(
                text=chunk_text,
                metadata=metadata.copy(),
                start_pos=chunk_start,
                end_pos=len(text),
                chunk_id=self._generate_chunk_id(chunk_index, metadata)
            )
            chunk.metadata["headers"] = current_headers
            chunk.metadata["chunking_strategy"] = "markdown"

            chunks.append(chunk)

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} markdown chunks",
                total_sections=len(sections),
                avg_sections_per_chunk=len(sections) / len(chunks) if chunks else 0
            )

        return chunks
Split markdown text into chunks respecting header boundaries.
Args:
- text: The markdown text to chunk
- metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
15class CodeChunker(BaseChunker): 16 """ 17 Chunks code while respecting function and class boundaries. 18 19 This chunker identifies code structures (functions, classes, methods) 20 and uses them as natural boundaries for chunking, preserving code context. 21 Supports Python, JavaScript, TypeScript, Java, C#, and similar languages. 22 """ 23 24 def __init__( 25 self, 26 max_chunk_size: int = 2000, 27 language: str = "python", 28 include_imports: bool = True, 29 logger=None 30 ): 31 """ 32 Initialize the code chunker. 33 34 Args: 35 max_chunk_size: Maximum characters per chunk (default: 2000) 36 language: Programming language ("python", "javascript", "java", etc.) 37 include_imports: Whether to include imports/using statements in chunks 38 logger: Optional BasicLogger instance 39 """ 40 super().__init__(logger) 41 42 if max_chunk_size <= 0: 43 raise ValueError("max_chunk_size must be greater than 0") 44 45 self.max_chunk_size = max_chunk_size 46 self.language = language.lower() 47 self.include_imports = include_imports 48 49 # Define patterns for different languages 50 self._setup_patterns() 51 52 if self.logger: 53 self.logger.info( 54 "Initialized CodeChunker", 55 max_chunk_size=max_chunk_size, 56 language=language 57 ) 58 59 def _setup_patterns(self): 60 """Setup regex patterns based on language.""" 61 if self.language == "python": 62 self.function_pattern = re.compile( 63 r'^(async\s+)?def\s+\w+\s*\([^)]*\)\s*(->\s*[^:]+)?:', 64 re.MULTILINE 65 ) 66 self.class_pattern = re.compile( 67 r'^class\s+\w+(\([^)]*\))?:\s*$', 68 re.MULTILINE 69 ) 70 self.import_pattern = re.compile( 71 r'^(import\s+\S+|from\s+\S+\s+import\s+.+)$', 72 re.MULTILINE 73 ) 74 elif self.language in ["javascript", "typescript", "java", "csharp", "c++"]: 75 self.function_pattern = re.compile( 76 r'(public|private|protected|static|async)?\s*(function|void|int|string|bool|var|let|const)?\s+\w+\s*\([^)]*\)\s*\{', 77 re.MULTILINE 78 ) 79 self.class_pattern = re.compile( 80 
r'(export\s+)?(public\s+)?class\s+\w+(\s+extends\s+\w+)?(\s+implements\s+[\w,\s]+)?\s*\{', 81 re.MULTILINE 82 ) 83 self.import_pattern = re.compile( 84 r'^(import\s+.*from\s+["\'].*["\'];?|using\s+\S+;|#include\s+<.*>)$', 85 re.MULTILINE 86 ) 87 else: 88 # Generic patterns for unknown languages 89 self.function_pattern = re.compile(r'^\s*(def|function|func|fun)\s+\w+', re.MULTILINE) 90 self.class_pattern = re.compile(r'^\s*class\s+\w+', re.MULTILINE) 91 self.import_pattern = re.compile(r'^(import|using|include)\s+', re.MULTILINE) 92 93 def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]: 94 """ 95 Split code into chunks respecting function and class boundaries. 96 97 Args: 98 text: The code text to chunk 99 metadata: Optional metadata to attach to each chunk 100 101 Returns: 102 List of Chunk objects 103 104 Raises: 105 ValueError: If text is invalid 106 """ 107 self.validate_text(text) 108 109 if metadata is None: 110 metadata = {} 111 112 if self.logger: 113 self.logger.debug(f"Chunking code of length {len(text)} characters") 114 115 # Extract imports if needed 116 imports_text = "" 117 if self.include_imports: 118 imports_text = self._extract_imports(text) 119 120 # Find all code blocks (functions and classes) 121 code_blocks = self._parse_code_blocks(text) 122 123 if self.logger: 124 self.logger.debug(f"Found {len(code_blocks)} code blocks") 125 126 # Create chunks from code blocks 127 chunks = [] 128 current_chunk_parts = [] 129 current_size = len(imports_text) 130 chunk_index = 0 131 132 for block_type, block_name, block_content, start, end in code_blocks: 133 block_size = len(block_content) 134 135 # Check if we should start a new chunk 136 if current_size + block_size > self.max_chunk_size and current_chunk_parts: 137 # Create chunk from accumulated blocks 138 chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts) 139 140 chunk = Chunk( 141 text=chunk_text.strip(), 142 metadata=metadata.copy(), 143 
start_pos=start - current_size, 144 end_pos=start, 145 chunk_id=self._generate_chunk_id(chunk_index, metadata) 146 ) 147 chunk.metadata["code_blocks"] = len(current_chunk_parts) 148 chunk.metadata["chunking_strategy"] = "code" 149 chunk.metadata["language"] = self.language 150 151 chunks.append(chunk) 152 chunk_index += 1 153 154 # Reset for new chunk 155 current_chunk_parts = [] 156 current_size = len(imports_text) 157 158 # Add block to current chunk 159 current_chunk_parts.append(block_content) 160 current_size += block_size + 2 # +2 for \n\n 161 162 # Create final chunk 163 if current_chunk_parts: 164 chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts) 165 166 chunk = Chunk( 167 text=chunk_text.strip(), 168 metadata=metadata.copy(), 169 start_pos=len(text) - current_size, 170 end_pos=len(text), 171 chunk_id=self._generate_chunk_id(chunk_index, metadata) 172 ) 173 chunk.metadata["code_blocks"] = len(current_chunk_parts) 174 chunk.metadata["chunking_strategy"] = "code" 175 chunk.metadata["language"] = self.language 176 177 chunks.append(chunk) 178 179 if self.logger: 180 self.logger.info( 181 f"Created {len(chunks)} code chunks", 182 total_blocks=len(code_blocks), 183 avg_blocks_per_chunk=len(code_blocks) / len(chunks) if chunks else 0 184 ) 185 186 return chunks 187 188 def _extract_imports(self, text: str) -> str: 189 """Extract import/using statements from the code.""" 190 imports = [] 191 for match in self.import_pattern.finditer(text): 192 imports.append(match.group().strip()) 193 194 return "\n".join(imports) if imports else "" 195 196 def _parse_code_blocks(self, text: str) -> List[Tuple[str, str, str, int, int]]: 197 """ 198 Parse code into blocks (functions, classes, etc.). 
199 200 Args: 201 text: Code text to parse 202 203 Returns: 204 List of tuples (block_type, block_name, content, start_pos, end_pos) 205 """ 206 blocks = [] 207 lines = text.split('\n') 208 209 # Find all function and class definitions 210 all_matches = [] 211 212 for match in self.function_pattern.finditer(text): 213 all_matches.append(('function', match.start(), match.group())) 214 215 for match in self.class_pattern.finditer(text): 216 all_matches.append(('class', match.start(), match.group())) 217 218 # Sort by position 219 all_matches.sort(key=lambda x: x[1]) 220 221 # Extract code blocks with their content 222 for i, (block_type, start, match_text) in enumerate(all_matches): 223 # Extract block name 224 block_name = self._extract_name(match_text) 225 226 # Find end of block (next block start or end of file) 227 if i + 1 < len(all_matches): 228 end = all_matches[i + 1][1] 229 else: 230 end = len(text) 231 232 # Extract block content 233 block_content = text[start:end].strip() 234 235 blocks.append((block_type, block_name, block_content, start, end)) 236 237 # If no blocks found, treat entire text as one block 238 if not blocks: 239 blocks = [('code', 'main', text, 0, len(text))] 240 241 return blocks 242 243 def _extract_name(self, definition: str) -> str: 244 """Extract function or class name from definition.""" 245 # Try to find the name after 'def', 'function', 'class', etc. 246 name_match = re.search(r'\b(def|function|class|func|fun)\s+(\w+)', definition) 247 if name_match: 248 return name_match.group(2) 249 250 # Fallback: try to find any word after space 251 words = definition.split() 252 for word in words: 253 if re.match(r'^\w+$', word) and word not in ['def', 'function', 'class', 'public', 'private', 'static', 'async']: 254 return word 255 256 return 'unknown'
Chunks code while respecting function and class boundaries.
This chunker identifies code structures (functions, classes, methods) and uses them as natural boundaries for chunking, preserving code context. Supports Python, JavaScript, TypeScript, Java, C#, and similar languages.
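The boundary detection described above can be sketched in a few lines. This is a minimal illustration, not the package's implementation: the two regular expressions are copied from the Python branch of `_setup_patterns` in the source listing, while `split_top_level_blocks` and `SAMPLE` are hypothetical names introduced here.

```python
import re
from typing import List, Tuple

# Patterns copied from the Python branch of _setup_patterns in the source listing.
FUNCTION_PATTERN = re.compile(
    r'^(async\s+)?def\s+\w+\s*\([^)]*\)\s*(->\s*[^:]+)?:', re.MULTILINE
)
CLASS_PATTERN = re.compile(r'^class\s+\w+(\([^)]*\))?:\s*$', re.MULTILINE)


def split_top_level_blocks(code: str) -> List[Tuple[str, str]]:
    """Return (kind, text) pairs, one per top-level def or class (hypothetical helper)."""
    starts = [("function", m.start()) for m in FUNCTION_PATTERN.finditer(code)]
    starts += [("class", m.start()) for m in CLASS_PATTERN.finditer(code)]
    starts.sort(key=lambda item: item[1])

    blocks = []
    for i, (kind, start) in enumerate(starts):
        # A block runs until the next definition starts, or to the end of the text.
        end = starts[i + 1][1] if i + 1 < len(starts) else len(code)
        blocks.append((kind, code[start:end].strip()))
    return blocks


SAMPLE = '''import os

def load(path):
    return open(path).read()

class Parser:
    def parse(self, text):
        return text.split()
'''

blocks = split_top_level_blocks(SAMPLE)
for kind, text in blocks:
    print(kind, "->", text.splitlines()[0])
```

Because both patterns are anchored at the start of a line, the indented `parse` method is not treated as a separate block and stays inside the `Parser` class block. Text that precedes the first definition (here the `import os` line) does not belong to any block in this sketch.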
    def __init__(
        self,
        max_chunk_size: int = 2000,
        language: str = "python",
        include_imports: bool = True,
        logger=None
    ):
        """
        Initialize the code chunker.

        Args:
            max_chunk_size: Maximum characters per chunk (default: 2000)
            language: Programming language ("python", "javascript", "java", etc.)
            include_imports: Whether to include imports/using statements in chunks
            logger: Optional BasicLogger instance
        """
        super().__init__(logger)

        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be greater than 0")

        self.max_chunk_size = max_chunk_size
        self.language = language.lower()
        self.include_imports = include_imports

        # Define patterns for different languages
        self._setup_patterns()

        if self.logger:
            self.logger.info(
                "Initialized CodeChunker",
                max_chunk_size=max_chunk_size,
                language=language
            )
Initialize the code chunker.
Args:
    max_chunk_size: Maximum characters per chunk (default: 2000)
    language: Programming language ("python", "javascript", "java", etc.)
    include_imports: Whether to include imports/using statements in chunks
    logger: Optional BasicLogger instance
    def chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Chunk]:
        """
        Split code into chunks respecting function and class boundaries.

        Args:
            text: The code text to chunk
            metadata: Optional metadata to attach to each chunk

        Returns:
            List of Chunk objects

        Raises:
            ValueError: If text is invalid
        """
        self.validate_text(text)

        if metadata is None:
            metadata = {}

        if self.logger:
            self.logger.debug(f"Chunking code of length {len(text)} characters")

        # Extract imports if needed
        imports_text = ""
        if self.include_imports:
            imports_text = self._extract_imports(text)

        # Find all code blocks (functions and classes)
        code_blocks = self._parse_code_blocks(text)

        if self.logger:
            self.logger.debug(f"Found {len(code_blocks)} code blocks")

        # Create chunks from code blocks
        chunks = []
        current_chunk_parts = []
        current_size = len(imports_text)
        chunk_index = 0

        for block_type, block_name, block_content, start, end in code_blocks:
            block_size = len(block_content)

            # Check if we should start a new chunk
            if current_size + block_size > self.max_chunk_size and current_chunk_parts:
                # Create chunk from accumulated blocks
                chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)

                chunk = Chunk(
                    text=chunk_text.strip(),
                    metadata=metadata.copy(),
                    start_pos=start - current_size,
                    end_pos=start,
                    chunk_id=self._generate_chunk_id(chunk_index, metadata)
                )
                chunk.metadata["code_blocks"] = len(current_chunk_parts)
                chunk.metadata["chunking_strategy"] = "code"
                chunk.metadata["language"] = self.language

                chunks.append(chunk)
                chunk_index += 1

                # Reset for new chunk
                current_chunk_parts = []
                current_size = len(imports_text)

            # Add block to current chunk
            current_chunk_parts.append(block_content)
            current_size += block_size + 2  # +2 for \n\n

        # Create final chunk
        if current_chunk_parts:
            chunk_text = imports_text + "\n\n" + "\n\n".join(current_chunk_parts)

            chunk = Chunk(
                text=chunk_text.strip(),
                metadata=metadata.copy(),
                start_pos=len(text) - current_size,
                end_pos=len(text),
                chunk_id=self._generate_chunk_id(chunk_index, metadata)
            )
            chunk.metadata["code_blocks"] = len(current_chunk_parts)
            chunk.metadata["chunking_strategy"] = "code"
            chunk.metadata["language"] = self.language

            chunks.append(chunk)

        if self.logger:
            self.logger.info(
                f"Created {len(chunks)} code chunks",
                total_blocks=len(code_blocks),
                avg_blocks_per_chunk=len(code_blocks) / len(chunks) if chunks else 0
            )

        return chunks
Split code into chunks respecting function and class boundaries.
Args:
    text: The code text to chunk
    metadata: Optional metadata to attach to each chunk
Returns: List of Chunk objects
Raises: ValueError: If text is invalid
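The grouping step of `chunk` can be illustrated on its own. The following is a simplified sketch that follows the same greedy logic as the source listing but works on plain strings instead of `Chunk` objects; `pack_blocks` is a hypothetical name, not part of the package.

```python
from typing import List


def pack_blocks(blocks: List[str], max_chunk_size: int, imports_text: str = "") -> List[str]:
    """Greedily group blocks so that each chunk stays near max_chunk_size characters."""
    chunks: List[str] = []
    current: List[str] = []
    current_size = len(imports_text)

    for block in blocks:
        # Flush the accumulated blocks if adding this one would exceed the budget.
        if current and current_size + len(block) > max_chunk_size:
            chunks.append((imports_text + "\n\n" + "\n\n".join(current)).strip())
            current = []
            current_size = len(imports_text)
        current.append(block)
        current_size += len(block) + 2  # account for the "\n\n" separator

    # Emit whatever is left as the final chunk.
    if current:
        chunks.append((imports_text + "\n\n" + "\n\n".join(current)).strip())
    return chunks


chunks = pack_blocks(
    ["a" * 40, "b" * 40, "c" * 100],
    max_chunk_size=90,
    imports_text="import os",
)
print([len(c) for c in chunks])
```

Under this logic `max_chunk_size` behaves as a soft limit: a single block that is larger than the budget (the 100-character block above) is kept whole rather than split, and the import text is repeated at the top of every chunk.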
181class BaseVectorStore(ABC): 182 """ 183 Abstract base class for vector store implementations. 184 185 Vector stores provide storage and retrieval of document embeddings, 186 supporting various search strategies (vector similarity, keyword, hybrid). 187 188 Implementations must provide: 189 - Document storage with embeddings 190 - Vector similarity search 191 - Optional keyword/hybrid search 192 - Document lifecycle management (add, update, delete) 193 """ 194 195 @abstractmethod 196 def add_documents( 197 self, 198 documents: List[Document], 199 generate_embeddings: bool = True 200 ) -> List[str]: 201 """ 202 Add documents to the vector store. 203 204 Args: 205 documents: List of Document objects to add 206 generate_embeddings: If True and documents lack embeddings, 207 generate them automatically 208 209 Returns: 210 List of document IDs that were added 211 212 Raises: 213 ValueError: If documents are invalid or embeddings cannot be generated 214 """ 215 pass 216 217 @abstractmethod 218 def search( 219 self, 220 query: Optional[str] = None, 221 query_embedding: Optional[List[float]] = None, 222 top_k: int = 5, 223 filters: Optional[Dict[str, Any]] = None, 224 search_type: str = "vector" 225 ) -> List[SearchResult]: 226 """ 227 Search the vector store. 228 229 Args: 230 query: Text query (required for keyword/hybrid search) 231 query_embedding: Vector embedding of the query (required for vector search) 232 top_k: Number of results to return 233 filters: Metadata filters to apply (e.g., {"source": "docs.pdf"}) 234 search_type: Type of search - "vector", "keyword", or "hybrid" 235 236 Returns: 237 List of SearchResult objects, ordered by relevance 238 239 Raises: 240 ValueError: If required parameters are missing for the search type 241 """ 242 pass 243 244 @abstractmethod 245 def delete_documents(self, document_ids: List[str]) -> int: 246 """ 247 Delete documents from the vector store. 
248 249 Args: 250 document_ids: List of document IDs to delete 251 252 Returns: 253 Number of documents successfully deleted 254 """ 255 pass 256 257 @abstractmethod 258 def update_document( 259 self, 260 document_id: str, 261 content: Optional[str] = None, 262 embedding: Optional[List[float]] = None, 263 metadata: Optional[Dict[str, Any]] = None 264 ) -> bool: 265 """ 266 Update an existing document. 267 268 Args: 269 document_id: ID of the document to update 270 content: New content (if provided) 271 embedding: New embedding (if provided) 272 metadata: New metadata (merged with existing) 273 274 Returns: 275 True if update was successful, False if document not found 276 """ 277 pass 278 279 @abstractmethod 280 def get_document(self, document_id: str) -> Optional[Document]: 281 """ 282 Retrieve a document by ID. 283 284 Args: 285 document_id: ID of the document to retrieve 286 287 Returns: 288 Document if found, None otherwise 289 """ 290 pass 291 292 @abstractmethod 293 def count(self) -> int: 294 """ 295 Get the total number of documents in the vector store. 296 297 Returns: 298 Total document count 299 """ 300 pass 301 302 @abstractmethod 303 def clear(self) -> None: 304 """ 305 Remove all documents from the vector store. 306 307 Warning: This operation is irreversible. 308 """ 309 pass
Abstract base class for vector store implementations.
Vector stores provide storage and retrieval of document embeddings, supporting various search strategies (vector similarity, keyword, hybrid).
Implementations must provide:
- Document storage with embeddings
- Vector similarity search
- Optional keyword/hybrid search
- Document lifecycle management (add, update, delete)
    @abstractmethod
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to the vector store.

        Args:
            documents: List of Document objects to add
            generate_embeddings: If True and documents lack embeddings,
                generate them automatically

        Returns:
            List of document IDs that were added

        Raises:
            ValueError: If documents are invalid or embeddings cannot be generated
        """
        pass
Add documents to the vector store.
Args:
    documents: List of Document objects to add
    generate_embeddings: If True and documents lack embeddings, generate them automatically
Returns: List of document IDs that were added
Raises: ValueError: If documents are invalid or embeddings cannot be generated
    @abstractmethod
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search the vector store.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding of the query (required for vector search)
            top_k: Number of results to return
            filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
            search_type: Type of search - "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance

        Raises:
            ValueError: If required parameters are missing for the search type
        """
        pass
Search the vector store.
Args:
    query: Text query (required for keyword/hybrid search)
    query_embedding: Vector embedding of the query (required for vector search)
    top_k: Number of results to return
    filters: Metadata filters to apply (e.g., {"source": "docs.pdf"})
    search_type: Type of search - "vector", "keyword", or "hybrid"
Returns: List of SearchResult objects, ordered by relevance
Raises: ValueError: If required parameters are missing for the search type
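The relationship between `search_type` and the required arguments can be sketched as a standalone check. This mirrors the validation in the `InMemoryVectorStore` listing on this page, which requires `query_embedding` for both vector and hybrid search; other implementations may differ. `validate_search_args` is a hypothetical helper, not part of the package.

```python
from typing import List, Optional

VALID_SEARCH_TYPES = ("vector", "keyword", "hybrid")


def validate_search_args(
    search_type: str,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
) -> None:
    """Raise ValueError when the inputs a search type needs are missing."""
    if search_type not in VALID_SEARCH_TYPES:
        raise ValueError(f"Invalid search_type: {search_type!r}")
    # Vector and hybrid search both need an embedding of the query.
    if search_type in ("vector", "hybrid") and query_embedding is None:
        raise ValueError(f"{search_type} search requires query_embedding")
    # Keyword and hybrid search both need the raw query text.
    if search_type in ("keyword", "hybrid") and query is None:
        raise ValueError(f"{search_type} search requires query text")


validate_search_args("vector", query_embedding=[0.1, 0.2])
validate_search_args("keyword", query="laptop")
validate_search_args("hybrid", query="laptop", query_embedding=[0.1, 0.2])
```

Running the check before touching the index keeps the failure mode uniform across search types: callers see a `ValueError` that names the missing argument.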
    @abstractmethod
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from the vector store.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        pass
Delete documents from the vector store.
Args: document_ids: List of document IDs to delete
Returns: Number of documents successfully deleted
    @abstractmethod
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)

        Returns:
            True if update was successful, False if document not found
        """
        pass
Update an existing document.
Args:
    document_id: ID of the document to update
    content: New content (if provided)
    embedding: New embedding (if provided)
    metadata: New metadata (merged with existing)
Returns: True if update was successful, False if document not found
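The partial-update contract (only the supplied fields change, metadata is merged rather than replaced, and an unknown ID returns `False`) can be sketched with a plain dictionary standing in for a real store. The `store` layout and this free-standing `update_document` function are illustrative assumptions, not the package's storage format.

```python
from typing import Any, Dict, List, Optional

# Hypothetical stand-in for a vector store: one record per document ID.
store: Dict[str, Dict[str, Any]] = {
    "doc_1": {
        "content": "old text",
        "embedding": None,
        "metadata": {"source": "a.pdf", "lang": "en"},
    },
}


def update_document(
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    record = store.get(document_id)
    if record is None:
        return False  # unknown ID: nothing is changed
    if content is not None:
        record["content"] = content
    if embedding is not None:
        record["embedding"] = embedding
    if metadata is not None:
        record["metadata"].update(metadata)  # merge into existing keys, do not replace
    return True


updated = update_document("doc_1", metadata={"lang": "de"})
missing = update_document("doc_404", content="new text")
```

After the first call the `source` key is still present and only `lang` has changed; the second call leaves the store untouched.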
    @abstractmethod
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise
        """
        pass
Retrieve a document by ID.
Args: document_id: ID of the document to retrieve
Returns: Document if found, None otherwise
19@dataclass 20class Document: 21 """ 22 Represents a document or chunk with its content and metadata. 23 24 This class is designed to be extensible - you can subclass it to add 25 custom fields specific to your application domain. 26 27 Base Attributes: 28 id: Unique identifier for the document 29 content: The text content of the document 30 embedding: Optional vector embedding (list of floats) 31 timestamp: Optional timestamp for time-weighted retrieval 32 metadata: Flexible dictionary for custom properties 33 34 Design Philosophy: 35 The base Document keeps only truly universal fields. Domain-specific 36 fields (source, page_number, category, etc.) should be added via 37 inheritance to match your application's data model. 38 39 Example - Using base Document with metadata: 40 ```python 41 doc = Document( 42 id="doc_1", 43 content="AI is transforming software development...", 44 timestamp=datetime.now(), 45 metadata={ 46 "source": "blog.pdf", 47 "page_number": 5, 48 "author": "John Doe", 49 "category": "AI" 50 } 51 ) 52 ``` 53 54 Example - Custom Document schema via inheritance: 55 ```python 56 from dataclasses import dataclass 57 from gmf_forge_ai_data.vector_stores import Document 58 59 @dataclass 60 class LegalDocument(Document): 61 case_number: str 62 court: str 63 decision_date: datetime 64 jurisdiction: str 65 source: str = "" # Add source as typed field 66 page_number: Optional[int] = None 67 68 legal_doc = LegalDocument( 69 id="case_123", 70 content="In the matter of...", 71 case_number="2024-CV-12345", 72 court="Supreme Court", 73 decision_date=datetime(2024, 1, 15), 74 jurisdiction="Federal", 75 source="court_records.pdf", 76 page_number=42 77 ) 78 ``` 79 80 Example - E-commerce Product Document: 81 ```python 82 @dataclass 83 class ProductDocument(Document): 84 sku: str 85 category: str 86 price: float 87 in_stock: bool 88 brand: str 89 # No source/page_number - not applicable to products 90 91 product = ProductDocument( 92 id="prod_456", 93 
content="High-performance laptop with...", 94 sku="LAP-001", 95 category="Electronics", 96 price=1299.99, 97 in_stock=True, 98 brand="TechCorp" 99 ) 100 ``` 101 """ 102 id: str 103 content: str 104 embedding: Optional[List[float]] = None 105 timestamp: Optional[datetime] = None 106 metadata: Dict[str, Any] = field(default_factory=dict) 107 108 def to_dict(self) -> Dict[str, Any]: 109 """ 110 Convert Document to dictionary for serialization. 111 112 Returns: 113 Dictionary representation of the document 114 """ 115 doc_dict = asdict(self) 116 # Convert datetime to ISO string for serialization 117 if self.timestamp: 118 doc_dict['timestamp'] = self.timestamp.isoformat() 119 return doc_dict 120 121 @classmethod 122 def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: 123 """ 124 Create Document from dictionary. 125 126 Args: 127 data: Dictionary containing document fields 128 129 Returns: 130 Document instance (or subclass instance) 131 """ 132 # Handle datetime deserialization 133 if 'timestamp' in data and isinstance(data['timestamp'], str): 134 data['timestamp'] = datetime.fromisoformat(data['timestamp']) 135 136 # Filter to only fields that exist in the dataclass 137 import inspect 138 valid_fields = set(inspect.signature(cls).parameters.keys()) 139 filtered_data = {k: v for k, v in data.items() if k in valid_fields} 140 141 return cls(**filtered_data) 142 143 def update_metadata(self, **kwargs) -> None: 144 """ 145 Update metadata fields. 146 147 Args: 148 **kwargs: Key-value pairs to add/update in metadata 149 """ 150 self.metadata.update(kwargs) 151 152 def get_metadata(self, key: str, default: Any = None) -> Any: 153 """ 154 Get a metadata value. 155 156 Args: 157 key: Metadata key to retrieve 158 default: Default value if key not found 159 160 Returns: 161 Metadata value or default 162 """ 163 return self.metadata.get(key, default)
Represents a document or chunk with its content and metadata.
This class is designed to be extensible - you can subclass it to add custom fields specific to your application domain.
Base Attributes:
    id: Unique identifier for the document
    content: The text content of the document
    embedding: Optional vector embedding (list of floats)
    timestamp: Optional timestamp for time-weighted retrieval
    metadata: Flexible dictionary for custom properties
Design Philosophy: The base Document keeps only truly universal fields. Domain-specific fields (source, page_number, category, etc.) should be added via inheritance to match your application's data model.
Example - Using base Document with metadata:
doc = Document(
    id="doc_1",
    content="AI is transforming software development...",
    timestamp=datetime.now(),
    metadata={
        "source": "blog.pdf",
        "page_number": 5,
        "author": "John Doe",
        "category": "AI"
    }
)
Example - Custom Document schema via inheritance:
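from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from gmf_forge_ai_data.vector_stores import Document

# kw_only=True (Python 3.10+) is needed here: Document already defines fields
# with defaults, so a plain @dataclass subclass that adds required fields
# raises "non-default argument follows default argument".
@dataclass(kw_only=True)
class LegalDocument(Document):
    case_number: str
    court: str
    decision_date: datetime
    jurisdiction: str
    source: str = ""  # Add source as typed field
    page_number: Optional[int] = None

legal_doc = LegalDocument(
    id="case_123",
    content="In the matter of...",
    case_number="2024-CV-12345",
    court="Supreme Court",
    decision_date=datetime(2024, 1, 15),
    jurisdiction="Federal",
    source="court_records.pdf",
    page_number=42
)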
Example - E-commerce Product Document:
# kw_only=True (Python 3.10+) allows required fields after the base class defaults
@dataclass(kw_only=True)
class ProductDocument(Document):
    sku: str
    category: str
    price: float
    in_stock: bool
    brand: str
    # No source/page_number - not applicable to products

product = ProductDocument(
    id="prod_456",
    content="High-performance laptop with...",
    sku="LAP-001",
    category="Electronics",
    price=1299.99,
    in_stock=True,
    brand="TechCorp"
)
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert Document to dictionary for serialization.

        Returns:
            Dictionary representation of the document
        """
        doc_dict = asdict(self)
        # Convert datetime to ISO string for serialization
        if self.timestamp:
            doc_dict['timestamp'] = self.timestamp.isoformat()
        return doc_dict
Convert Document to dictionary for serialization.
Returns: Dictionary representation of the document
    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """
        Create Document from dictionary.

        Args:
            data: Dictionary containing document fields

        Returns:
            Document instance (or subclass instance)
        """
        # Handle datetime deserialization
        if 'timestamp' in data and isinstance(data['timestamp'], str):
            data['timestamp'] = datetime.fromisoformat(data['timestamp'])

        # Filter to only fields that exist in the dataclass
        import inspect
        valid_fields = set(inspect.signature(cls).parameters.keys())
        filtered_data = {k: v for k, v in data.items() if k in valid_fields}

        return cls(**filtered_data)
Create Document from dictionary.
Args: data: Dictionary containing document fields
Returns: Document instance (or subclass instance)
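A round trip through `to_dict` and `from_dict` can be sketched with a stand-in dataclass that has the same base fields as `Document`. `SketchDocument` is a hypothetical name, and the sketch makes two deliberate variations on the listing above: it copies the input dictionary so the caller's data is not modified, and it uses `dataclasses.fields` instead of `inspect.signature` to find the valid field names.

```python
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime
from typing import Any, Dict, List, Optional


@dataclass
class SketchDocument:
    """Hypothetical stand-in with the same base fields as Document."""
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        if self.timestamp:
            # datetime is not JSON-serializable, so store it as an ISO 8601 string
            data["timestamp"] = self.timestamp.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SketchDocument":
        data = dict(data)  # shallow copy so the caller's dict is left untouched
        if isinstance(data.get("timestamp"), str):
            data["timestamp"] = datetime.fromisoformat(data["timestamp"])
        valid = {f.name for f in fields(cls)}
        # Unknown keys are dropped instead of raising TypeError
        return cls(**{k: v for k, v in data.items() if k in valid})


original = SketchDocument(id="doc_1", content="hello", timestamp=datetime(2024, 1, 15, 9, 30))
payload = original.to_dict()
payload["unexpected"] = "ignored"
restored = SketchDocument.from_dict(payload)
```

Filtering on the known field names is what lets a payload with extra keys, such as one written by a newer schema, still load into the dataclass.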
    def update_metadata(self, **kwargs) -> None:
        """
        Update metadata fields.

        Args:
            **kwargs: Key-value pairs to add/update in metadata
        """
        self.metadata.update(kwargs)
Update metadata fields.
Args: **kwargs: Key-value pairs to add/update in metadata
    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Get a metadata value.

        Args:
            key: Metadata key to retrieve
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)
Get a metadata value.
Args:
    key: Metadata key to retrieve
    default: Default value if key not found
Returns: Metadata value or default
@dataclass
class SearchResult:
    """
    Represents a search result from a vector store.

    Attributes:
        document: The retrieved document
        score: Similarity score (0.0 to 1.0, higher is better)
        rank: Position in the result list (0-indexed)
    """
    document: Document
    score: float
    rank: int
Represents a search result from a vector store.
Attributes:
    document: The retrieved document
    score: Similarity score (0.0 to 1.0, higher is better)
    rank: Position in the result list (0-indexed)
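One way a store might produce scores in the 0.0 to 1.0 range with 0-indexed ranks is to rescale cosine similarity from [-1, 1] to [0, 1], which is what the `InMemoryVectorStore` listing on this page does with numpy. The sketch below uses only the standard library; `cosine_score` and `rank_results` are hypothetical names, and returning 0.0 for zero-length vectors is an assumption of this sketch.

```python
import math
from typing import Dict, List, Tuple


def cosine_score(a: List[float], b: List[float]) -> float:
    """Cosine similarity rescaled from [-1, 1] to [0, 1]."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # assumption of this sketch: zero vectors score lowest
    cosine = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    return (cosine + 1) / 2


def rank_results(
    query: List[float],
    embeddings: Dict[str, List[float]],
    top_k: int = 5,
) -> List[Tuple[int, str, float]]:
    """Return (rank, doc_id, score) tuples, best match first, rank starting at 0."""
    scored = sorted(
        ((doc_id, cosine_score(query, vec)) for doc_id, vec in embeddings.items()),
        key=lambda item: item[1],
        reverse=True,
    )[:top_k]
    return [(rank, doc_id, score) for rank, (doc_id, score) in enumerate(scored)]


results = rank_results(
    [1.0, 0.0],
    {"same": [2.0, 0.0], "orthogonal": [0.0, 3.0], "opposite": [-1.0, 0.0]},
    top_k=2,
)
for rank, doc_id, score in results:
    print(rank, doc_id, round(score, 2))
```

With this rescaling, a vector pointing in the same direction as the query scores 1.0, an orthogonal one 0.5, and an opposite one 0.0, so a score of 0.5 means "unrelated" rather than "half relevant".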
17class InMemoryVectorStore(BaseVectorStore): 18 """ 19 In-memory vector store using numpy for vector operations. 20 21 Features: 22 - Fast cosine similarity search using numpy 23 - Keyword search using simple text matching 24 - Hybrid search combining vector and keyword scores 25 - Metadata filtering 26 - No external dependencies (Azure, etc.) 27 28 Ideal for: 29 - Unit testing 30 - Development and prototyping 31 - Small datasets (< 10,000 documents) 32 - CI/CD pipelines 33 34 Note: All data is lost when the process terminates. 35 """ 36 37 def __init__(self, embedding_dimension: int = 1536): 38 """ 39 Initialize the in-memory vector store. 40 41 Args: 42 embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002) 43 """ 44 self.embedding_dimension = embedding_dimension 45 self._documents: Dict[str, Document] = {} 46 self._embeddings: Optional[np.ndarray] = None 47 self._doc_ids: List[str] = [] 48 49 def add_documents( 50 self, 51 documents: List[Document], 52 generate_embeddings: bool = True 53 ) -> List[str]: 54 """ 55 Add documents to the in-memory store. 56 57 Args: 58 documents: List of Document objects to add 59 generate_embeddings: If True, validates embeddings are present 60 (actual generation should be done externally) 61 62 Returns: 63 List of document IDs that were added 64 """ 65 if not documents: 66 return [] 67 68 added_ids = [] 69 70 for doc in documents: 71 # Validate document 72 if not doc.id or not doc.content: 73 raise ValueError(f"Document must have id and content: {doc}") 74 75 if generate_embeddings and doc.embedding is None: 76 raise ValueError( 77 f"Document {doc.id} lacks embedding. " 78 "Generate embeddings externally before adding to store." 
79 ) 80 81 if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension: 82 raise ValueError( 83 f"Document {doc.id} has embedding dimension {len(doc.embedding)}, " 84 f"expected {self.embedding_dimension}" 85 ) 86 87 # Store document 88 self._documents[doc.id] = deepcopy(doc) 89 added_ids.append(doc.id) 90 91 # Rebuild embedding matrix 92 self._rebuild_embeddings() 93 94 return added_ids 95 96 def search( 97 self, 98 query: Optional[str] = None, 99 query_embedding: Optional[List[float]] = None, 100 top_k: int = 5, 101 filters: Optional[Dict[str, Any]] = None, 102 search_type: str = "vector" 103 ) -> List[SearchResult]: 104 """ 105 Search the in-memory vector store. 106 107 Args: 108 query: Text query (required for keyword/hybrid search) 109 query_embedding: Vector embedding of the query (required for vector/hybrid search) 110 top_k: Number of results to return 111 filters: Metadata filters (e.g., {"source": "doc.pdf"}) 112 search_type: "vector", "keyword", or "hybrid" 113 114 Returns: 115 List of SearchResult objects, ordered by relevance 116 """ 117 if search_type not in ["vector", "keyword", "hybrid"]: 118 raise ValueError(f"Invalid search_type: {search_type}. 
Must be 'vector', 'keyword', or 'hybrid'") 119 120 if search_type in ["vector", "hybrid"] and query_embedding is None: 121 raise ValueError(f"{search_type} search requires query_embedding") 122 123 if search_type in ["keyword", "hybrid"] and query is None: 124 raise ValueError(f"{search_type} search requires query text") 125 126 if not self._documents: 127 return [] 128 129 # Apply metadata filters 130 filtered_docs = self._apply_filters(filters) 131 if not filtered_docs: 132 return [] 133 134 # Calculate scores based on search type 135 if search_type == "vector": 136 scores = self._vector_search(query_embedding, filtered_docs) 137 elif search_type == "keyword": 138 scores = self._keyword_search(query, filtered_docs) 139 else: # hybrid 140 vector_scores = self._vector_search(query_embedding, filtered_docs) 141 keyword_scores = self._keyword_search(query, filtered_docs) 142 # Combine scores (50/50 weighting) 143 scores = { 144 doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0) 145 for doc_id in filtered_docs 146 } 147 148 # Sort by score and take top_k 149 sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] 150 151 # Build SearchResult objects 152 results = [] 153 for rank, (doc_id, score) in enumerate(sorted_results): 154 results.append(SearchResult( 155 document=deepcopy(self._documents[doc_id]), 156 score=float(score), 157 rank=rank 158 )) 159 160 return results 161 162 def delete_documents(self, document_ids: List[str]) -> int: 163 """ 164 Delete documents from the store. 
165 166 Args: 167 document_ids: List of document IDs to delete 168 169 Returns: 170 Number of documents successfully deleted 171 """ 172 deleted_count = 0 173 for doc_id in document_ids: 174 if doc_id in self._documents: 175 del self._documents[doc_id] 176 deleted_count += 1 177 178 # Rebuild embedding matrix 179 self._rebuild_embeddings() 180 181 return deleted_count 182 183 def update_document( 184 self, 185 document_id: str, 186 content: Optional[str] = None, 187 embedding: Optional[List[float]] = None, 188 metadata: Optional[Dict[str, Any]] = None 189 ) -> bool: 190 """ 191 Update an existing document. 192 193 Args: 194 document_id: ID of the document to update 195 content: New content (if provided) 196 embedding: New embedding (if provided) 197 metadata: New metadata (merged with existing) 198 199 Returns: 200 True if update was successful, False if document not found 201 """ 202 if document_id not in self._documents: 203 return False 204 205 doc = self._documents[document_id] 206 207 if content is not None: 208 doc.content = content 209 210 if embedding is not None: 211 if len(embedding) != self.embedding_dimension: 212 raise ValueError( 213 f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}" 214 ) 215 doc.embedding = embedding 216 self._rebuild_embeddings() 217 218 if metadata is not None: 219 doc.metadata.update(metadata) 220 221 return True 222 223 def get_document(self, document_id: str) -> Optional[Document]: 224 """ 225 Retrieve a document by ID. 
226 227 Args: 228 document_id: ID of the document to retrieve 229 230 Returns: 231 Document if found, None otherwise 232 """ 233 doc = self._documents.get(document_id) 234 return deepcopy(doc) if doc else None 235 236 def count(self) -> int: 237 """Get the total number of documents in the store.""" 238 return len(self._documents) 239 240 def clear(self) -> None: 241 """Remove all documents from the store.""" 242 self._documents.clear() 243 self._embeddings = None 244 self._doc_ids = [] 245 246 # Private helper methods 247 248 def _rebuild_embeddings(self) -> None: 249 """Rebuild the embedding matrix from stored documents.""" 250 docs_with_embeddings = [ 251 (doc_id, doc) for doc_id, doc in self._documents.items() 252 if doc.embedding is not None 253 ] 254 255 if not docs_with_embeddings: 256 self._embeddings = None 257 self._doc_ids = [] 258 return 259 260 self._doc_ids = [doc_id for doc_id, _ in docs_with_embeddings] 261 embeddings_list = [doc.embedding for _, doc in docs_with_embeddings] 262 self._embeddings = np.array(embeddings_list, dtype=np.float32) 263 264 def _apply_filters(self, filters: Optional[Dict[str, Any]]) -> List[str]: 265 """ 266 Apply metadata filters and return list of matching document IDs. 

        Supports filtering on: source, page_number, timestamp, and metadata fields
        Special operators: use tuples for range queries
        - (">=", value) for greater than or equal
        - ("<=", value) for less than or equal
        - ("range", min_val, max_val) for range queries
        """
        if filters is None:
            return list(self._documents.keys())

        matching_ids = []
        for doc_id, doc in self._documents.items():
            matches = True

            for key, value in filters.items():
                # Check document attributes (source, page_number, timestamp)
                doc_value = getattr(doc, key, None)

                # If not a document attribute, check metadata
                if doc_value is None and key in doc.metadata:
                    doc_value = doc.metadata[key]

                # A missing (or None) field means the document does not match
                if doc_value is None:
                    matches = False
                    break

                # Handle tuple operators for range queries
                if isinstance(value, tuple):
                    operator = value[0]
                    if operator == "range" and len(value) == 3:
                        # Inclusive on both ends
                        min_val, max_val = value[1], value[2]
                        if not (min_val <= doc_value <= max_val):
                            matches = False
                            break
                    # Note: "gt" is accepted as an alias of ">=" and is inclusive
                    elif operator in [">=", "gt"]:
                        if not (doc_value >= value[1]):
                            matches = False
                            break
                    # Note: "lt" is accepted as an alias of "<=" and is inclusive
                    elif operator in ["<=", "lt"]:
                        if not (doc_value <= value[1]):
                            matches = False
                            break
                # Simple equality check
                elif doc_value != value:
                    matches = False
                    break

            if matches:
                matching_ids.append(doc_id)

        return matching_ids

    def _vector_search(
        self,
        query_embedding: List[float],
        candidate_ids: List[str]
    ) -> Dict[str, float]:
        """
        Perform vector similarity search using cosine similarity.

        Returns:
            Dict mapping document IDs to similarity scores (0.0 to 1.0)
        """
        if self._embeddings is None or len(self._doc_ids) == 0:
            return {}

        query_vector = np.array(query_embedding, dtype=np.float32)

        # Filter embeddings to only candidate documents
        candidate_indices = [
            i for i, doc_id in enumerate(self._doc_ids)
            if doc_id in candidate_ids
        ]

        if not candidate_indices:
            return {}

        candidate_embeddings = self._embeddings[candidate_indices]
        candidate_doc_ids = [self._doc_ids[i] for i in candidate_indices]

        # Compute cosine similarity
        query_norm = np.linalg.norm(query_vector)
        if query_norm == 0:
            return {doc_id: 0.0 for doc_id in candidate_doc_ids}

        doc_norms = np.linalg.norm(candidate_embeddings, axis=1)
        dot_products = candidate_embeddings.dot(query_vector)

        # Avoid division by zero
        similarities = np.zeros(len(candidate_doc_ids))
        valid_mask = doc_norms > 0
        similarities[valid_mask] = dot_products[valid_mask] / (doc_norms[valid_mask] * query_norm)

        # Convert from [-1, 1] to [0, 1]
        similarities = (similarities + 1) / 2

        return {doc_id: float(sim) for doc_id, sim in zip(candidate_doc_ids, similarities)}

    def _keyword_search(
        self,
        query: str,
        candidate_ids: List[str]
    ) -> Dict[str, float]:
        """
        Perform simple keyword search using term matching.

        Returns:
            Dict mapping document IDs to relevance scores (0.0 to 1.0)
        """
        query_lower = query.lower()
        query_terms = set(query_lower.split())

        scores = {}
        for doc_id in candidate_ids:
            doc = self._documents[doc_id]
            content_lower = doc.content.lower()
            content_terms = set(content_lower.split())

            # Calculate Jaccard similarity as relevance score
            if not query_terms:
                scores[doc_id] = 0.0
            else:
                intersection = len(query_terms & content_terms)
                union = len(query_terms | content_terms)
                scores[doc_id] = intersection / union if union > 0 else 0.0

        return scores
In-memory vector store using numpy for vector operations.
Features:
- Fast cosine similarity search using numpy
- Keyword search using simple text matching
- Hybrid search combining vector and keyword scores
- Metadata filtering
- No external dependencies (Azure, etc.)
Ideal for:
- Unit testing
- Development and prototyping
- Small datasets (< 10,000 documents)
- CI/CD pipelines
Note: All data is lost when the process terminates.
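The metadata filtering listed above accepts plain values for equality and tuples for comparisons. The following is a minimal standalone sketch of that matching rule, assuming a plain dict stands in for a document's fields; the `matches` helper is hypothetical and is not part of the package.

```python
from typing import Any, Dict


def matches(fields: Dict[str, Any], filters: Dict[str, Any]) -> bool:
    """Return True when every filter is satisfied by the given fields."""
    for key, expected in filters.items():
        actual = fields.get(key)
        if actual is None:
            # A missing field never matches.
            return False
        if isinstance(expected, tuple):
            op = expected[0]
            if op == "range" and len(expected) == 3:
                # Inclusive range: min <= value <= max
                if not (expected[1] <= actual <= expected[2]):
                    return False
            elif op == ">=":
                if not (actual >= expected[1]):
                    return False
            elif op == "<=":
                if not (actual <= expected[1]):
                    return False
        elif actual != expected:
            # Plain values are compared for equality.
            return False
    return True


doc_fields = {"source": "doc.pdf", "page_number": 7}
print(matches(doc_fields, {"source": "doc.pdf", "page_number": ("range", 5, 10)}))
print(matches(doc_fields, {"page_number": (">=", 8)}))
```

All filters are combined with a logical AND, so a document is returned only when every key in the filter dict is satisfied.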
    def __init__(self, embedding_dimension: int = 1536):
        """
        Initialize the in-memory vector store.

        Args:
            embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
        """
        self.embedding_dimension = embedding_dimension
        self._documents: Dict[str, Document] = {}
        self._embeddings: Optional[np.ndarray] = None
        self._doc_ids: List[str] = []
Initialize the in-memory vector store.
Args: embedding_dimension: Expected dimension of embeddings (default 1536 for text-embedding-ada-002)
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to the in-memory store.

        Args:
            documents: List of Document objects to add
            generate_embeddings: If True, validates embeddings are present
                (actual generation should be done externally)

        Returns:
            List of document IDs that were added
        """
        if not documents:
            return []

        added_ids = []

        for doc in documents:
            # Validate document
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document {doc.id} lacks embedding. "
                    "Generate embeddings externally before adding to store."
                )

            if doc.embedding is not None and len(doc.embedding) != self.embedding_dimension:
                raise ValueError(
                    f"Document {doc.id} has embedding dimension {len(doc.embedding)}, "
                    f"expected {self.embedding_dimension}"
                )

            # Store document
            self._documents[doc.id] = deepcopy(doc)
            added_ids.append(doc.id)

        # Rebuild embedding matrix
        self._rebuild_embeddings()

        return added_ids
Add documents to the in-memory store.
Args:
    documents: List of Document objects to add
    generate_embeddings: If True, validates embeddings are present (actual generation should be done externally)
Returns: List of document IDs that were added
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search the in-memory vector store.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding of the query (required for vector/hybrid search)
            top_k: Number of results to return
            filters: Metadata filters (e.g., {"source": "doc.pdf"})
            search_type: "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}. Must be 'vector', 'keyword', or 'hybrid'")

        if search_type in ["vector", "hybrid"] and query_embedding is None:
            raise ValueError(f"{search_type} search requires query_embedding")

        if search_type in ["keyword", "hybrid"] and query is None:
            raise ValueError(f"{search_type} search requires query text")

        if not self._documents:
            return []

        # Apply metadata filters
        filtered_docs = self._apply_filters(filters)
        if not filtered_docs:
            return []

        # Calculate scores based on search type
        if search_type == "vector":
            scores = self._vector_search(query_embedding, filtered_docs)
        elif search_type == "keyword":
            scores = self._keyword_search(query, filtered_docs)
        else:  # hybrid
            vector_scores = self._vector_search(query_embedding, filtered_docs)
            keyword_scores = self._keyword_search(query, filtered_docs)
            # Combine scores (50/50 weighting)
            scores = {
                doc_id: 0.5 * vector_scores.get(doc_id, 0.0) + 0.5 * keyword_scores.get(doc_id, 0.0)
                for doc_id in filtered_docs
            }

        # Sort by score and take top_k
        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

        # Build SearchResult objects
        results = []
        for rank, (doc_id, score) in enumerate(sorted_results):
            results.append(SearchResult(
                document=deepcopy(self._documents[doc_id]),
                score=float(score),
                rank=rank
            ))

        return results
Search the in-memory vector store.
Args:
    query: Text query (required for keyword/hybrid search)
    query_embedding: Vector embedding of the query (required for vector/hybrid search)
    top_k: Number of results to return
    filters: Metadata filters (e.g., {"source": "doc.pdf"})
    search_type: "vector", "keyword", or "hybrid"
Returns: List of SearchResult objects, ordered by relevance
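To make the three search types concrete, here is a minimal standard-library sketch of the scoring they rely on: cosine similarity rescaled to [0, 1], Jaccard overlap of whitespace-separated terms, and an equal-weight blend for hybrid search. The function names are hypothetical, and returning 0.0 for any zero-length vector is a simplification of this sketch rather than a statement about the store's behavior.

```python
import math
from typing import List


def cosine_01(a: List[float], b: List[float]) -> float:
    """Cosine similarity rescaled from [-1, 1] to [0, 1]."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        # Simplification: treat zero-length vectors as "no similarity".
        return 0.0
    cosine = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    return (cosine + 1) / 2


def jaccard(query: str, content: str) -> float:
    """Jaccard overlap between lowercased, whitespace-split term sets."""
    query_terms = set(query.lower().split())
    content_terms = set(content.lower().split())
    if not query_terms:
        return 0.0
    return len(query_terms & content_terms) / len(query_terms | content_terms)


def hybrid(vector_score: float, keyword_score: float) -> float:
    """Equal-weight (50/50) combination of the two scores."""
    return 0.5 * vector_score + 0.5 * keyword_score


v = cosine_01([1.0, 0.0], [0.0, 1.0])      # orthogonal vectors
k = jaccard("red apple", "Red apple pie")  # 2 shared terms out of 3
print(v, k, hybrid(v, k))
```

Because Jaccard divides by the size of the union, long documents score low even when they contain every query term, which is one reason this keyword score suits tests and prototypes better than production ranking.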
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from the store.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        deleted_count = 0
        for doc_id in document_ids:
            if doc_id in self._documents:
                del self._documents[doc_id]
                deleted_count += 1

        # Rebuild embedding matrix
        self._rebuild_embeddings()

        return deleted_count
Delete documents from the store.
Args: document_ids: List of document IDs to delete
Returns: Number of documents successfully deleted
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)

        Returns:
            True if update was successful, False if document not found
        """
        if document_id not in self._documents:
            return False

        doc = self._documents[document_id]

        if content is not None:
            doc.content = content

        if embedding is not None:
            if len(embedding) != self.embedding_dimension:
                raise ValueError(
                    f"Embedding dimension {len(embedding)} != expected {self.embedding_dimension}"
                )
            doc.embedding = embedding
            self._rebuild_embeddings()

        if metadata is not None:
            doc.metadata.update(metadata)

        return True
Update an existing document.
Args:
    document_id: ID of the document to update
    content: New content (if provided)
    embedding: New embedding (if provided)
    metadata: New metadata (merged with existing)
Returns: True if update was successful, False if document not found
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise
        """
        doc = self._documents.get(document_id)
        return deepcopy(doc) if doc else None
Retrieve a document by ID.
Args: document_id: ID of the document to retrieve
Returns: Document if found, None otherwise
class AzureAISearchVectorStore(BaseVectorStore):
    """
    Azure AI Search vector store for production RAG pipelines.

    Features:
    - Scalable vector search with HNSW (Hierarchical Navigable Small World)
    - Hybrid search (vector + keyword BM25)
    - Semantic ranking powered by Microsoft Bing
    - Automatic indexing of all document fields (universal + custom)
    - Efficient filtering on all indexed fields
    - High availability and disaster recovery
    - Managed service (no infrastructure management)

    Prerequisites:
    - Azure AI Search service (Basic tier or higher for vector search)
    - Service endpoint and API key **or** a managed-identity token provider
    - Index created with vector field configuration

    Usage:

    **Simple Mode** (base Document):
    ```python
    # API key auth
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        api_key="...",
        embedding_dimension=1536
    )

    # Managed identity / token provider auth
    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://search.azure.com/.default"  # Azure AI Search scope
    )
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        token_provider=token_provider,
        embedding_dimension=1536
    )
    # Uses base Document with universal fields only
    ```

    **Custom Schema Mode** (domain-specific Document):
    ```python
    from dataclasses import dataclass
    from gmf_forge_ai_data.vector_stores import Document

    @dataclass
    class LegalDocument(Document):
        case_number: str = ""
        court: str = ""
        jurisdiction: str = ""
        decision_date: Optional[datetime] = None

    # Pass document_type - all fields automatically indexed.
    # If the index was provisioned externally (portal, Bicep, etc.) the
    # vector and content field names will likely differ from the defaults
    # ("embedding" / "content"). Always set vector_field_name and
    # content_field_name to match your actual index schema.
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="legal_documents",
        api_key="...",
        embedding_dimension=1536,
        document_type=LegalDocument,
        vector_field_name="contentVector",  # match your index schema
        content_field_name="chunkContent",  # match your index schema
    )

    # All fields indexed - filter on any field
    results = store.search(
        query_embedding=vector,
        filters={
            "jurisdiction": "Federal",
            "court": "Supreme Court",
            "decision_date": (">=", datetime(2024, 1, 1))
        }
    )
    ```
    """

    def __init__(
        self,
        endpoint: str,
        index_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        vector_field_name: str = "embedding",
        content_field_name: str = "content",
    ):
        """
        Initialize Azure AI Search vector store.

        The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
        before first use. The constructor only establishes client connections;
        it never creates or modifies indexes.

        Exactly one of ``api_key`` or ``token_provider`` must be supplied.

        Args:
            endpoint: Azure AI Search service endpoint
            index_name: Name of the search index
            api_key: Azure AI Search API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Azure AI Search** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://search.azure.com/.default"
                    )

                Note: this scope is different from Azure OpenAI / Cognitive Services
                (``https://cognitiveservices.azure.com/.default``); each service
                requires its own token_provider.
            embedding_dimension: Dimension of vector embeddings (default 1536)
            document_type: Document class (base or custom subclass). All fields will be indexed.
            vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
                Override this when the index was built externally with a different field name
                (e.g. "contentVector", "chunkVector").
            content_field_name: Name of the text content field in the Azure Search index (default "content").
                Override when the index uses a different name (e.g. "chunkContent", "text").
        """
        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
            )
        self.endpoint = endpoint
        self.index_name = index_name
        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_field_name = vector_field_name
        self.content_field_name = content_field_name

        if token_provider:
            credential = _TokenProviderCredential(token_provider)
        else:
            credential = AzureKeyCredential(api_key)

        # Initialize clients
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=credential
        )

    @staticmethod
    def _format_datetime_for_azure(dt: datetime) -> str:
        """
        Convert datetime to Azure Search DateTimeOffset format (ISO 8601 with timezone).

        Args:
            dt: datetime object (timezone-aware or naive)

        Returns:
            ISO 8601 string with timezone (e.g., '2024-01-15T10:30:00Z')
        """
        if dt.tzinfo is None:
            # Timezone-naive datetime: assume UTC and append 'Z'
            return dt.isoformat() + 'Z'
        else:
            # Timezone-aware datetime: use standard ISO format
            return dt.isoformat()

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to Azure AI Search.

        Args:
            documents: List of Document objects to add (can be base or custom subclasses)
            generate_embeddings: If True, validates embeddings are present

        Returns:
            List of document IDs that were added

        Note:
            Custom document fields are automatically serialized and stored.
            All fields remain accessible after retrieval.
        """
        if not documents:
            return []

        # Helper function to format datetime for Azure Search DateTimeOffset
        format_datetime_for_azure = self._format_datetime_for_azure

        # Helper function to serialize datetime objects in dictionaries
        def serialize_for_json(obj):
            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
            if isinstance(obj, datetime):
                return format_datetime_for_azure(obj)
            elif isinstance(obj, dict):
                return {k: serialize_for_json(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [serialize_for_json(item) for item in obj]
            else:
                return obj

        # Prepare documents for upload
        search_documents = []
        added_ids = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document {doc.id} lacks embedding. "
                    "Generate embeddings externally before adding."
                )

            # Serialize entire document (including custom fields) to JSON
            import json
            import dataclasses
            doc_dict = doc.to_dict()

            # Convert datetime objects for JSON serialization
            doc_dict_serializable = serialize_for_json(doc_dict)

            # Convert to Azure Search document format
            # Store base fields directly
            search_doc = {
                "id": doc.id,
                "content": doc.content,
                "embedding": doc.embedding if doc.embedding else [],
                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
            }

            # Add all custom fields as indexed fields
            if dataclasses.is_dataclass(doc):
                for field in dataclasses.fields(doc):
                    # Skip base Document fields (already added above)
                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                        continue

                    field_value = getattr(doc, field.name, None)

                    if field_value is not None:
                        # Serialize datetime fields with Azure DateTimeOffset format
                        if isinstance(field_value, datetime):
                            search_doc[field.name] = format_datetime_for_azure(field_value)
                        # Handle lists/dicts by converting to JSON string
                        elif isinstance(field_value, (list, dict)):
                            search_doc[field.name] = json.dumps(field_value)
                        else:
                            search_doc[field.name] = field_value

            search_documents.append(search_doc)
            added_ids.append(doc.id)

        # Upload to Azure Search
        result = self.search_client.upload_documents(documents=search_documents)

        # Check for failures
        failed_count = sum(1 for r in result if not r.succeeded)
        if failed_count > 0:
            logger.warning(f"{failed_count} documents failed to upload")

        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search Azure AI Search index.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding (required for vector/hybrid search)
            top_k: Number of results to return
            filters: Metadata filters (e.g., {"source": "doc.pdf"})
            search_type: "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        if search_type in ["vector", "hybrid"] and query_embedding is None:
            raise ValueError(f"{search_type} search requires query_embedding")

        if search_type in ["keyword", "hybrid"] and query is None:
            raise ValueError(f"{search_type} search requires query text")

        # Build filter expression
        filter_expr = self._build_filter_expression(filters)

        # Execute search based on type
        if search_type == "vector":
            # Pure vector search
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=None,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        elif search_type == "keyword":
            # Pure keyword search (BM25)
            results = self.search_client.search(
                search_text=query,
                filter=filter_expr,
                top=top_k,
                include_total_count=False
            )

        else:  # hybrid
            # Hybrid search (vector + keyword)
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        # Convert to SearchResult objects
        search_results = []
        for rank, result in enumerate(results):
            import json

            # Deserialize full document data
            doc_data = json.loads(result.get("document_data", "{}") or "{}")

            if not doc_data:
                # Index was built externally: fields are stored directly on the result.
                # Collect all non-Azure-internal fields, excluding the vector blob.
                doc_data = {
                    k: v for k, v in result.items()
                    if not k.startswith("@") and k != self.vector_field_name
                }
                # Remap content field when the index uses a non-standard name.
                if self.content_field_name != "content" and self.content_field_name in doc_data:
                    doc_data["content"] = doc_data.pop(self.content_field_name)

            # Reconstruct document using the configured document_type
            doc = self.document_type.from_dict(doc_data)

            search_results.append(SearchResult(
                document=doc,
                score=result.get("@search.score", 0.0),
                rank=rank
            ))

        return search_results

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from Azure AI Search.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        if not document_ids:
            return 0

        # Prepare delete operations
        delete_docs = [{"id": doc_id} for doc_id in document_ids]

        # Execute delete
        result = self.search_client.delete_documents(documents=delete_docs)

        # Count successful deletions
        deleted_count = sum(1 for r in result if r.succeeded)
        return deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        timestamp: Optional[datetime] = None,
        **custom_fields
    ) -> bool:
        """
        Update an existing document in Azure AI Search.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)
            timestamp: New timestamp (if provided)
            **custom_fields: Any custom fields to update

        Returns:
            True if update was successful, False if document not found
        """
        try:
            import json

            # Retrieve existing document
            existing_doc = self.search_client.get_document(key=document_id)
            doc_data = json.loads(existing_doc.get("document_data", "{}"))

            # Update fields in document data
            if content is not None:
                doc_data["content"] = content
                existing_doc["content"] = content

            if embedding is not None:
                doc_data["embedding"] = embedding
                existing_doc["embedding"] = embedding

            if timestamp is not None:
                doc_data["timestamp"] = self._format_datetime_for_azure(timestamp)
                existing_doc["timestamp"] = self._format_datetime_for_azure(timestamp)

            if metadata is not None:
                if "metadata" not in doc_data:
                    doc_data["metadata"] = {}
                doc_data["metadata"].update(metadata)

            # Update custom fields
            for key, value in custom_fields.items():
                doc_data[key] = value

            # Update document_data
            existing_doc["document_data"] = json.dumps(doc_data)

            # Merge/upload updated document
            result = self.search_client.merge_or_upload_documents(documents=[existing_doc])

            return result[0].succeeded

        except Exception as e:
            logger.error(f"Failed to update document {document_id}: {e}")
            return False

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID from Azure AI Search.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise (returns configured document_type)
        """
        try:
            import json
            result = self.search_client.get_document(key=document_id)

            # Deserialize full document data using configured document_type
            doc_data = json.loads(result.get("document_data", "{}"))
            return self.document_type.from_dict(doc_data)

        except Exception as e:
            logger.debug(f"Document {document_id} not found: {e}")
            return None

    def get_document_by_parent(self, document_id: str) -> List[Document]:
        """
        Retrieve all chunks belonging to a parent document.

        Args:
            document_id: The value of the ``documentId`` field shared across
                all chunks of the same parent document.

        Returns:
            List of Document objects sorted by ``pageNumber``, or an empty
            list if no matching chunks are found.
        """
        results = self.search(
            query="*",
            search_type="keyword",
            filters={"documentId": document_id},
            top_k=1000,
        )
        chunks = [r.document for r in results]
        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
        return chunks

    def count(self) -> int:
        """
        Get the total number of documents in the index.

        Returns:
            Total document count
        """
        # Execute a search to get total count
        results = self.search_client.search(
            search_text="*",
            include_total_count=True,
            top=0
        )

        # Get count from results
        count = getattr(results, 'get_count', lambda: 0)()
        return count if count is not None else 0

    def clear(self) -> None:
        """
        Remove all documents from the index.

        Warning: This operation is irreversible.
        """
        # Search for all document IDs
        results = self.search_client.search(
            search_text="*",
            select=["id"],
            top=1000  # Adjust batch size as needed
        )

        # Collect all IDs
        doc_ids = [r["id"] for r in results]

        # Delete in batches
        if doc_ids:
            self.delete_documents(doc_ids)
            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")

    def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> Optional[str]:
        """
        Build OData filter expression from filter dictionary.

        Supports filtering on all indexed fields (universal + custom fields from document_type).

        Args:
            filters: Dictionary of field: value pairs
                Special operators: use tuples for range queries
                - (">=", value) for greater than or equal
                - ("<=", value) for less than or equal
                - ("range", min_val, max_val) for range queries

        Returns:
            OData filter string or None

        Example:
            ```python
            # Filter on universal field
            filters = {"timestamp": (">=", datetime(2024, 1, 1))}

            # Filter on custom fields (if document_type has them)
            filters = {
                "court": "Supreme Court",
                "case_number": "2024-001",
                "decision_date": ("range", start_date, end_date)
            }
            ```
        """
        if not filters:
            return None

        # Get indexed field names from document_type
        import dataclasses
        indexed_fields = {'id', 'content', 'timestamp'}  # Universal fields

        if dataclasses.is_dataclass(self.document_type):
            for field in dataclasses.fields(self.document_type):
                if field.name not in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                    indexed_fields.add(field.name)

        filter_parts = []

        for key, value in filters.items():
            # Check if field is indexed
            if key not in indexed_fields:
                logger.warning(
                    f"Filtering on '{key}' - field not found in document schema. "
                    f"Indexed fields: {indexed_fields}"
                )
                continue

            # Handle tuple operators for range queries
            if isinstance(value, tuple):
                operator = value[0]
                if operator == "range" and len(value) == 3:
                    # Range query: field >= min and field <= max
                    min_val = value[1]
                    max_val = value[2]
                    if isinstance(min_val, datetime):
                        min_val = self._format_datetime_for_azure(min_val)
                    if isinstance(max_val, datetime):
                        max_val = self._format_datetime_for_azure(max_val)
                    filter_parts.append(f"{key} ge {min_val} and {key} le {max_val}")
                elif operator in [">=", "gt", ">="]:
                    op_str = "ge"
                    val = value[1]
                    if isinstance(val, datetime):
                        val = self._format_datetime_for_azure(val)
                    filter_parts.append(f"{key} {op_str} {val}")
                elif operator in ["<=", "lt", "<="]:
                    op_str = "le"
                    val = value[1]
                    if isinstance(val, datetime):
                        val = self._format_datetime_for_azure(val)
                    filter_parts.append(f"{key} {op_str} {val}")
            # Handle datetime equality
            elif isinstance(value, datetime):
                filter_parts.append(f"{key} eq {self._format_datetime_for_azure(value)}")
            # Handle string equality (needs quotes in OData)
            elif isinstance(value, str):
                # Escape single quotes in the value
                escaped_value = value.replace("'", "''")
                filter_parts.append(f"{key} eq '{escaped_value}'")
            # Handle boolean
            elif isinstance(value, bool):
                filter_parts.append(f"{key} eq {str(value).lower()}")
            # Handle numeric types
            elif isinstance(value, (int, float)):
                filter_parts.append(f"{key} eq {value}")

        return " and ".join(filter_parts) if filter_parts else None
Azure AI Search vector store for production RAG pipelines.
Features:
- Scalable vector search with HNSW (Hierarchical Navigable Small World)
- Hybrid search (vector + keyword BM25)
- Semantic ranking powered by Microsoft Bing
- Automatic indexing of all document fields (universal + custom)
- Efficient filtering on all indexed fields
- High availability and disaster recovery
- Managed service (no infrastructure management)
Prerequisites:
- Azure AI Search service (Basic tier or higher for vector search)
- Service endpoint and API key or a managed-identity token provider
- Index created with vector field configuration
Usage:
Simple Mode (base Document):
    # API key auth
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        api_key="...",
        embedding_dimension=1536
    )

    # Managed identity / token provider auth
    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://search.azure.com/.default"  # Azure AI Search scope
    )
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="documents",
        token_provider=token_provider,
        embedding_dimension=1536
    )
    # Uses base Document with universal fields only

Custom Schema Mode (domain-specific Document):
    from dataclasses import dataclass
    from datetime import datetime
    from typing import Optional
    from gmf_forge_ai_data.vector_stores import Document

    @dataclass
    class LegalDocument(Document):
        case_number: str = ""
        court: str = ""
        jurisdiction: str = ""
        decision_date: Optional[datetime] = None

    # Pass document_type - all fields automatically indexed.
    # If the index was provisioned externally (portal, Bicep, etc.) the
    # vector and content field names will likely differ from the defaults
    # ("embedding" / "content"). Always set vector_field_name and
    # content_field_name to match your actual index schema.
    store = AzureAISearchVectorStore(
        endpoint="https://my-search.search.windows.net",
        index_name="legal_documents",
        api_key="...",
        embedding_dimension=1536,
        document_type=LegalDocument,
        vector_field_name="contentVector",   # match your index schema
        content_field_name="chunkContent",   # match your index schema
    )

    # All fields indexed - filter on any field
    results = store.search(
        query_embedding=vector,
        filters={
            "jurisdiction": "Federal",
            "court": "Supreme Court",
            "decision_date": (">=", datetime(2024, 1, 1))
        }
    )
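The filter dictionary in the usage example is translated into an OData `$filter` string before it reaches the service. The following is a minimal, standalone sketch of that translation, not the package's own implementation: the names `build_odata_filter`, `format_value`, and `ODATA_OPS` are hypothetical, and naive datetimes are assumed to be UTC.

```python
from datetime import datetime, timezone
from typing import Any, Dict, Optional

# Hypothetical operator table for this sketch (not taken from the package).
ODATA_OPS = {">=": "ge", ">": "gt", "<=": "le", "<": "lt"}


def format_value(value: Any) -> str:
    """Render a Python value as an OData literal."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return str(value).lower()
    if isinstance(value, datetime):
        if value.tzinfo is None:  # assumption: naive datetimes are UTC
            value = value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    if isinstance(value, str):
        # OData escapes a single quote by doubling it
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def build_odata_filter(filters: Optional[Dict[str, Any]]) -> Optional[str]:
    """Join per-field conditions with 'and'; return None when nothing applies."""
    if not filters:
        return None
    parts = []
    for key, value in filters.items():
        if isinstance(value, tuple):
            if len(value) == 3 and value[0] == "range":
                low, high = format_value(value[1]), format_value(value[2])
                parts.append(f"{key} ge {low} and {key} le {high}")
            elif len(value) == 2 and value[0] in ODATA_OPS:
                parts.append(f"{key} {ODATA_OPS[value[0]]} {format_value(value[1])}")
        else:
            parts.append(f"{key} eq {format_value(value)}")
    return " and ".join(parts) if parts else None
```

Checking `bool` before the numeric branch matters because `isinstance(True, int)` is true in Python; without that ordering a boolean would be rendered as `True` rather than the OData literal `true`.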
    def __init__(
        self,
        endpoint: str,
        index_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        vector_field_name: str = "embedding",
        content_field_name: str = "content",
    ):
        """
        Initialize Azure AI Search vector store.

        The index must be pre-provisioned using ``AzureAISearchIndexBuilder``
        before first use. The constructor only establishes client connections;
        it never creates or modifies indexes.

        At least one of ``api_key`` or ``token_provider`` must be supplied.
        If both are given, ``token_provider`` takes precedence.

        Args:
            endpoint: Azure AI Search service endpoint
            index_name: Name of the search index
            api_key: Azure AI Search API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Azure AI Search** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://search.azure.com/.default"
                    )

                Note: this scope is different from Azure OpenAI / Cognitive Services
                (``https://cognitiveservices.azure.com/.default``); each service
                requires its own token_provider.
            embedding_dimension: Dimension of vector embeddings (default 1536)
            document_type: Document class (base or custom subclass). All fields will be indexed.
            vector_field_name: Name of the vector field in the Azure Search index (default "embedding").
                Override this when the index was built externally with a different field name
                (e.g. "contentVector", "chunkVector").
            content_field_name: Name of the text content field in the Azure Search index (default "content").
                Override when the index uses a different name (e.g. "chunkContent", "text").
        """
        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureAISearchVectorStore."
            )
        self.endpoint = endpoint
        self.index_name = index_name
        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_field_name = vector_field_name
        self.content_field_name = content_field_name

        # token_provider wins when both credentials are supplied
        if token_provider:
            credential = _TokenProviderCredential(token_provider)
        else:
            credential = AzureKeyCredential(api_key)

        # Initialize clients
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=credential
        )
Initialize Azure AI Search vector store.
The index must be pre-provisioned using AzureAISearchIndexBuilder
before first use. The constructor only establishes client connections
— it never creates or modifies indexes.
At least one of api_key or token_provider must be supplied. If both are given, token_provider takes precedence.
Args:
    endpoint: Azure AI Search service endpoint
    index_name: Name of the search index
    api_key: Azure AI Search API key. Use for local development or when managed identity is not available.
    token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Azure AI Search scope:
        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://search.azure.com/.default"
        )
        Note: this scope is different from Azure OpenAI / Cognitive Services (https://cognitiveservices.azure.com/.default); each service requires its own token_provider.
    embedding_dimension: Dimension of vector embeddings (default 1536)
    document_type: Document class (base or custom subclass). All fields will be indexed.
    vector_field_name: Name of the vector field in the Azure Search index (default "embedding"). Override this when the index was built externally with a different field name (e.g. "contentVector", "chunkVector").
    content_field_name: Name of the text content field in the Azure Search index (default "content"). Override when the index uses a different name (e.g. "chunkContent", "text").
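The constructor wraps `token_provider` in `_TokenProviderCredential`, whose source is not shown on this page. As a rough illustration of what such an adapter might look like, the sketch below exposes a `get_token()` method on top of a zero-argument callable. Everything here is an assumption: `TokenProviderCredential`, the `ttl_seconds` parameter, and the local `AccessToken` stand-in (modelled on the `token` / `expires_on` pair used by Azure credential objects) are hypothetical names for illustration only.

```python
import time
from typing import Callable, NamedTuple


class AccessToken(NamedTuple):
    """Local stand-in for an Azure access token: (token, expires_on)."""
    token: str
    expires_on: int


class TokenProviderCredential:
    """Hypothetical adapter exposing get_token() over a zero-arg callable."""

    def __init__(self, token_provider: Callable[[], str], ttl_seconds: int = 300):
        self._token_provider = token_provider
        # Assumption: the real expiry is unknown here, so a short TTL forces
        # callers to ask the provider again soon.
        self._ttl_seconds = ttl_seconds

    def get_token(self, *scopes: str, **kwargs) -> AccessToken:
        # The provider was already bound to its scope when it was created,
        # so the scopes passed in by the client are ignored.
        token = self._token_provider()
        return AccessToken(token, int(time.time()) + self._ttl_seconds)
```

Because the scope is fixed when the provider is built, passing a provider created for the Cognitive Services scope would still "work" at this layer and only fail at the service, which is why the docstring insists on one provider per service.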
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True
    ) -> List[str]:
        """
        Add documents to Azure AI Search.

        Args:
            documents: List of Document objects to add (can be base or custom subclasses)
            generate_embeddings: If True, validates embeddings are present

        Returns:
            List of document IDs that were added

        Note:
            Custom document fields are automatically serialized and stored.
            All fields remain accessible after retrieval.
        """
        import dataclasses
        import json

        if not documents:
            return []

        # Helper function to format datetime for Azure Search DateTimeOffset
        format_datetime_for_azure = self._format_datetime_for_azure

        # Helper function to serialize datetime objects in dictionaries
        def serialize_for_json(obj):
            """Recursively convert datetime objects to ISO format strings for JSON serialization."""
            if isinstance(obj, datetime):
                return format_datetime_for_azure(obj)
            elif isinstance(obj, dict):
                return {k: serialize_for_json(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [serialize_for_json(item) for item in obj]
            else:
                return obj

        # Prepare documents for upload
        search_documents = []
        added_ids = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document {doc.id} lacks embedding. "
                    "Generate embeddings externally before adding."
                )

            # Serialize entire document (including custom fields) to JSON
            doc_dict = doc.to_dict()

            # Convert datetime objects for JSON serialization
            doc_dict_serializable = serialize_for_json(doc_dict)

            # Convert to Azure Search document format.
            # Base fields are stored directly, using the configured field names
            # so that uploads match the names that search() queries.
            search_doc = {
                "id": doc.id,
                self.content_field_name: doc.content,
                self.vector_field_name: doc.embedding if doc.embedding else [],
                "timestamp": format_datetime_for_azure(doc.timestamp) if doc.timestamp else None,
                "document_data": json.dumps(doc_dict_serializable)  # Keep for metadata and backwards compatibility
            }

            # Add all custom fields as indexed fields
            if dataclasses.is_dataclass(doc):
                for field in dataclasses.fields(doc):
                    # Skip base Document fields (already added above)
                    if field.name in {'id', 'content', 'embedding', 'timestamp', 'metadata'}:
                        continue

                    field_value = getattr(doc, field.name, None)

                    if field_value is not None:
                        # Serialize datetime fields with Azure DateTimeOffset format
                        if isinstance(field_value, datetime):
                            search_doc[field.name] = format_datetime_for_azure(field_value)
                        # Handle lists/dicts by converting to JSON string
                        elif isinstance(field_value, (list, dict)):
                            search_doc[field.name] = json.dumps(field_value)
                        else:
                            search_doc[field.name] = field_value

            search_documents.append(search_doc)
            added_ids.append(doc.id)

        # Upload to Azure Search
        result = self.search_client.upload_documents(documents=search_documents)

        # Check for failures
        failed_count = sum(1 for r in result if not r.succeeded)
        if failed_count > 0:
            logger.warning(f"{failed_count} documents failed to upload")

        return added_ids
Add documents to Azure AI Search.
Args:
    documents: List of Document objects to add (can be base or custom subclasses)
    generate_embeddings: If True, validates that embeddings are present
Returns: List of document IDs that were added
Note: Custom document fields are automatically serialized and stored. All fields remain accessible after retrieval.
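To make the serialization rule concrete, here is a minimal sketch of how a custom dataclass document could be flattened into an upload payload: base fields go in directly, the whole document is kept as a JSON blob, and each custom field is promoted to a top-level key. The `Doc` and `LegalDoc` classes and the `to_search_document` helper are simplified stand-ins, not the package's classes, and `datetime.isoformat()` is used in place of the store's own Azure date formatter.

```python
import dataclasses
import json
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

BASE_FIELDS = {"id", "content", "embedding", "timestamp", "metadata"}


@dataclass
class Doc:  # simplified stand-in for the package's Document
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class LegalDoc(Doc):  # hypothetical domain-specific subclass
    court: str = ""
    tags: List[str] = field(default_factory=list)
    decision_date: Optional[datetime] = None


def to_json_safe(obj: Any) -> Any:
    """Recursively replace datetimes with ISO strings so json.dumps succeeds."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(v) for v in obj]
    return obj


def to_search_document(doc: Doc) -> Dict[str, Any]:
    search_doc: Dict[str, Any] = {
        "id": doc.id,
        "content": doc.content,
        "embedding": doc.embedding or [],
        "document_data": json.dumps(to_json_safe(dataclasses.asdict(doc))),
    }
    for f in dataclasses.fields(doc):
        if f.name in BASE_FIELDS:
            continue  # base fields were handled above
        value = getattr(doc, f.name)
        if value is None:
            continue  # unset custom fields are left out of the payload
        if isinstance(value, datetime):
            search_doc[f.name] = value.isoformat()
        elif isinstance(value, (list, dict)):
            search_doc[f.name] = json.dumps(value)  # collections stored as JSON text
        else:
            search_doc[f.name] = value
    return search_doc
```

Keeping the full JSON blob alongside the promoted fields is what lets a retrieved hit be rebuilt into the original subclass, while the top-level copies exist only so the service can filter on them.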
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector"
    ) -> List[SearchResult]:
        """
        Search Azure AI Search index.

        Args:
            query: Text query (required for keyword/hybrid search)
            query_embedding: Vector embedding (required for vector/hybrid search)
            top_k: Number of results to return
            filters: Metadata filters (e.g., {"source": "doc.pdf"})
            search_type: "vector", "keyword", or "hybrid"

        Returns:
            List of SearchResult objects, ordered by relevance
        """
        import json

        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        if search_type in ["vector", "hybrid"] and query_embedding is None:
            raise ValueError(f"{search_type} search requires query_embedding")

        if search_type in ["keyword", "hybrid"] and query is None:
            raise ValueError(f"{search_type} search requires query text")

        # Build filter expression
        filter_expr = self._build_filter_expression(filters)

        # Execute search based on type
        if search_type == "vector":
            # Pure vector search
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=None,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        elif search_type == "keyword":
            # Pure keyword search (BM25)
            results = self.search_client.search(
                search_text=query,
                filter=filter_expr,
                top=top_k,
                include_total_count=False
            )

        else:  # hybrid
            # Hybrid search (vector + keyword)
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields=self.vector_field_name
            )

            results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expr,
                top=top_k
            )

        # Convert to SearchResult objects
        search_results = []
        for rank, result in enumerate(results):
            # Deserialize full document data
            doc_data = json.loads(result.get("document_data", "{}") or "{}")

            if not doc_data:
                # Index was built externally: fields are stored directly on the result.
                # Collect all non-Azure-internal fields, excluding the vector blob.
                doc_data = {
                    k: v for k, v in result.items()
                    if not k.startswith("@") and k != self.vector_field_name
                }
                # Remap content field when the index uses a non-standard name.
                if self.content_field_name != "content" and self.content_field_name in doc_data:
                    doc_data["content"] = doc_data.pop(self.content_field_name)

            # Reconstruct document using the configured document_type
            doc = self.document_type.from_dict(doc_data)

            search_results.append(SearchResult(
                document=doc,
                score=result.get("@search.score", 0.0),
                rank=rank
            ))

        return search_results
Search Azure AI Search index.
Args:
    query: Text query (required for keyword/hybrid search)
    query_embedding: Vector embedding (required for vector/hybrid search)
    top_k: Number of results to return
    filters: Metadata filters (e.g., {"source": "doc.pdf"})
    search_type: "vector", "keyword", or "hybrid"
Returns: List of SearchResult objects, ordered by relevance
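When a hit carries no `document_data` blob (for example, an index provisioned outside this package), the search method rebuilds the document from the hit's own fields. The sketch below isolates that fallback as a plain function over a dictionary. The name `extract_doc_data` and the sample hit are hypothetical; only the field-handling rules mirror the source shown on this page.

```python
import json
from typing import Any, Dict


def extract_doc_data(
    result: Dict[str, Any],
    vector_field_name: str = "embedding",
    content_field_name: str = "content",
) -> Dict[str, Any]:
    """Return the dict a Document would be rebuilt from."""
    # Preferred path: the full document was stored as a JSON blob.
    doc_data = json.loads(result.get("document_data") or "{}")
    if doc_data:
        return doc_data

    # Fallback: keep the hit's own fields, dropping service metadata
    # (keys starting with "@") and the vector itself.
    doc_data = {
        k: v for k, v in result.items()
        if not k.startswith("@") and k != vector_field_name
    }
    # Normalise a custom content field name back to "content".
    if content_field_name != "content" and content_field_name in doc_data:
        doc_data["content"] = doc_data.pop(content_field_name)
    return doc_data


# Usage: a hit from an externally built index with custom field names.
hit = {"id": "7", "chunkContent": "hello", "contentVector": [0.1], "@search.score": 0.9}
doc_data = extract_doc_data(hit, "contentVector", "chunkContent")
```

Dropping the vector keeps each reconstructed document small, at the cost of `embedding` being unset on documents that come back through this fallback path.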
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents from Azure AI Search.

        Args:
            document_ids: List of document IDs to delete

        Returns:
            Number of documents successfully deleted
        """
        if not document_ids:
            return 0

        # Prepare delete operations
        delete_docs = [{"id": doc_id} for doc_id in document_ids]

        # Execute delete
        result = self.search_client.delete_documents(documents=delete_docs)

        # Count successful deletions
        deleted_count = sum(1 for r in result if r.succeeded)
        return deleted_count
Delete documents from Azure AI Search.
Args: document_ids: List of document IDs to delete
Returns: Number of documents successfully deleted
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        timestamp: Optional[datetime] = None,
        **custom_fields
    ) -> bool:
        """
        Update an existing document in Azure AI Search.

        Args:
            document_id: ID of the document to update
            content: New content (if provided)
            embedding: New embedding (if provided)
            metadata: New metadata (merged with existing)
            timestamp: New timestamp (if provided)
            **custom_fields: Any custom fields to update

        Returns:
            True if the update succeeded, False if the document was not
            found or the update failed
        """
        try:
            import json

            # Retrieve existing document
            existing_doc = self.search_client.get_document(key=document_id)
            # Tolerate a missing or null document_data field
            doc_data = json.loads(existing_doc.get("document_data") or "{}")

            # Update fields in document data
            if content is not None:
                doc_data["content"] = content
                existing_doc["content"] = content

            if embedding is not None:
                doc_data["embedding"] = embedding
                existing_doc["embedding"] = embedding

            if timestamp is not None:
                doc_data["timestamp"] = self._format_datetime_for_azure(timestamp)
                existing_doc["timestamp"] = self._format_datetime_for_azure(timestamp)

            if metadata is not None:
                if "metadata" not in doc_data:
                    doc_data["metadata"] = {}
                doc_data["metadata"].update(metadata)

            # Update custom fields
            for key, value in custom_fields.items():
                doc_data[key] = value

            # Update document_data
            existing_doc["document_data"] = json.dumps(doc_data)

            # Merge/upload updated document
            result = self.search_client.merge_or_upload_documents(documents=[existing_doc])

            return result[0].succeeded

        except Exception as e:
            logger.error(f"Failed to update document {document_id}: {e}")
            return False
Update an existing document in Azure AI Search.
Args:
    document_id: ID of the document to update
    content: New content (if provided)
    embedding: New embedding (if provided)
    metadata: New metadata (merged with existing)
    timestamp: New timestamp (if provided)
    **custom_fields: Any custom fields to update
Returns: True if the update succeeded, False if the document was not found or the update failed for any other reason
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a document by ID from Azure AI Search.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document if found, None otherwise (returns configured document_type)
        """
        try:
            import json
            result = self.search_client.get_document(key=document_id)

            # Deserialize full document data using configured document_type.
            # Tolerate a missing or null document_data field.
            doc_data = json.loads(result.get("document_data") or "{}")
            return self.document_type.from_dict(doc_data)

        except Exception as e:
            logger.debug(f"Document {document_id} not found: {e}")
            return None
Retrieve a document by ID from Azure AI Search.
Args: document_id: ID of the document to retrieve
Returns: Document if found, None otherwise (returns configured document_type)
    def get_document_by_parent(self, document_id: str) -> List[Document]:
        """
        Retrieve all chunks belonging to a parent document.

        Args:
            document_id: The value of the ``documentId`` field shared across
                all chunks of the same parent document.

        Returns:
            List of Document objects sorted by ``pageNumber``, or an empty
            list if no matching chunks are found.
        """
        results = self.search(
            query="*",
            search_type="keyword",
            filters={"documentId": document_id},
            top_k=1000,
        )
        chunks = [r.document for r in results]
        chunks.sort(key=lambda c: getattr(c, "pageNumber", 0))
        return chunks
Retrieve all chunks belonging to a parent document.
Args:
document_id: The value of the documentId field shared across
all chunks of the same parent document.
Returns:
List of Document objects sorted by pageNumber, or an empty
list if no matching chunks are found.
    def count(self) -> int:
        """
        Get the total number of documents in the index.

        Returns:
            Total document count
        """
        # Execute a search to get total count
        results = self.search_client.search(
            search_text="*",
            include_total_count=True,
            top=0
        )

        # Get count from results
        count = getattr(results, 'get_count', lambda: 0)()
        return count if count is not None else 0
Get the total number of documents in the index.
Returns: Total document count
    def clear(self) -> None:
        """
        Remove all documents from the index.

        Warning: This operation is irreversible.

        Note: each call requests at most 1000 document IDs (``top=1000``),
        so a larger index may need repeated calls; check ``count()`` afterwards.
        """
        # Search for document IDs (at most 1000 per call)
        results = self.search_client.search(
            search_text="*",
            select=["id"],
            top=1000  # Adjust batch size as needed
        )

        # Collect the returned IDs
        doc_ids = [r["id"] for r in results]

        # Delete the collected IDs in a single call
        if doc_ids:
            self.delete_documents(doc_ids)
            logger.info(f"Cleared {len(doc_ids)} documents from index '{self.index_name}'")
Remove all documents from the index.
Warning: This operation is irreversible.
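Since the implementation above requests at most 1,000 IDs per call, clearing or deleting a larger set of documents has to proceed in chunks. A minimal sketch of such chunking follows; the `batched` helper is hypothetical and not part of the package, and the default of 1,000 simply mirrors the `top=1000` value in the source.

```python
from typing import Iterator, List, Sequence


def batched(ids: Sequence[str], batch_size: int = 1000) -> Iterator[List[str]]:
    """Yield consecutive slices of ids, each holding at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(ids), batch_size):
        yield list(ids[start:start + batch_size])


# Usage: split 2,500 IDs into deletion batches.
all_ids = [str(i) for i in range(2500)]
batch_sizes = [len(batch) for batch in batched(all_ids)]
```

With a store instance at hand, each batch could then be passed to `delete_documents`, summing the returned counts to confirm how many deletions actually succeeded.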
class AzureCosmosDBVectorStore(BaseVectorStore):
    """
    Azure Cosmos DB NoSQL API vector store.

    Stores documents in a Cosmos DB NoSQL container and queries them using
    Cosmos DB's vector distance functions and SQL-like query language.

    Features:
    - Vector search using Cosmos DB vector indexing (cosine distance)
    - Keyword search via SQL CONTAINS queries
    - Hybrid search (vector + keyword, equal weighting)
    - Metadata filtering via SQL WHERE clauses
    - Custom ``document_type`` support (any Document subclass)
    - Works against a container pre-provisioned with a vector embedding
      policy (see ``CosmosDBIndexBuilder``)

    Usage::

        from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore

        store = AzureCosmosDBVectorStore(
            endpoint="https://your-account.documents.azure.com:443/",
            key="your-key",
            database_name="rag_db",
            container_name="documents",
            embedding_dimension=1536,
        )
        store.add_documents(docs_with_embeddings)
        results = store.search(query_embedding=embedding, top_k=5)
    """

    def __init__(
        self,
        endpoint: str,
        key: str,
        database_name: str,
        container_name: str,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the Cosmos DB NoSQL vector store.

        The database and container must be pre-provisioned using
        ``CosmosDBIndexBuilder`` before first use. The constructor only
        establishes client connections; it never creates or modifies
        databases or containers.

        Args:
            endpoint: Cosmos DB account endpoint
                (e.g. ``https://your-account.documents.azure.com:443/``).
            key: Cosmos DB account key or resource token.
            database_name: Name of the Cosmos DB database.
            container_name: Name of the container.
            embedding_dimension: Dimension of vector embeddings (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        try:
            from azure.cosmos import CosmosClient
        except ImportError as exc:
            raise ImportError(
                "azure-cosmos is required for AzureCosmosDBVectorStore. "
                "Install it with: pip install azure-cosmos"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self._container_name = container_name

        # Build client with optional custom CA bundle
        connection_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            connection_kwargs["connection_verify"] = ssl_cert_path

        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)

        # Database and container (pre-provisioned by developer)
        self._database = self._client.get_database_client(database_name)
        self._container = self._database.get_container_client(container_name)

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_value(obj: Any) -> Any:
        """Recursively convert datetime objects to ISO strings."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, dict):
            return {k: AzureCosmosDBVectorStore._serialize_value(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [AzureCosmosDBVectorStore._serialize_value(i) for i in obj]
        return obj

    def _to_cosmos(self, doc: Document) -> Dict[str, Any]:
        """Convert a Document to a Cosmos DB item dict."""
        doc_dict = doc.to_dict()
        doc_dict_serialisable = self._serialize_value(doc_dict)
        item = {
            "id": doc.id,
            "content": doc.content,
            "embedding": doc.embedding or [],
            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
            "metadata": doc.metadata,
            "document_data": json.dumps(doc_dict_serialisable),
        }
        # Promote custom fields (e.g. category, author, year) to top-level
        # so they are filterable via SQL WHERE clauses
        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
        for key, value in doc_dict_serialisable.items():
            if key not in base_keys:
                item[key] = value
        return item

    def _from_cosmos(self, cosmos_item: Dict[str, Any]) -> Document:
        """Reconstruct a Document (or subclass) from a Cosmos DB item."""
        doc_data = json.loads(cosmos_item.get("document_data", "{}"))
        return self.document_type.from_dict(doc_data)

    # ------------------------------------------------------------------
    # Filter helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_where_clause(filters: Optional[Dict[str, Any]]) -> str:
        """Translate a user-facing filter dict to a SQL WHERE clause."""
        if not filters:
            return ""

        def _quote(text: str) -> str:
            # Escape backslashes and single quotes inside the string literal
            escaped = text.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"

        def _literal(val: Any) -> str:
            # bool is checked first because it is a subclass of int
            if isinstance(val, bool):
                return "true" if val else "false"
            if isinstance(val, datetime):
                return _quote(val.isoformat())
            if isinstance(val, str):
                return _quote(val)
            return str(val)

        op_map = {">=": ">=", "<=": "<=", ">": ">", "<": "<", "!=": "!="}
        conditions: List[str] = []
        for key, value in filters.items():
            if isinstance(value, tuple) and len(value) == 2:
                op, val = value
                sql_op = op_map.get(op, "=")
                conditions.append(f"c.{key} {sql_op} {_literal(val)}")
            else:
                conditions.append(f"c.{key} = {_literal(value)}")
        return " AND ".join(conditions)

    # ------------------------------------------------------------------
    # BaseVectorStore interface
    # ------------------------------------------------------------------

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Upsert documents into the Cosmos DB container.

        Args:
            documents: Documents to add. Each must have ``id`` and ``content``
                set. If *generate_embeddings* is ``True`` (default) every
                document must also carry a pre-computed ``embedding``.
            generate_embeddings: When ``True`` the method validates that all
                documents already have embeddings rather than generating them
                itself (generation is the caller's responsibility).

        Returns:
            List of document IDs that were successfully upserted.
        """
        if not documents:
            return []

        added_ids: List[str] = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document '{doc.id}' has no embedding. "
                    "Generate embeddings before calling add_documents()."
                )

            cosmos_item = self._to_cosmos(doc)
            self._container.upsert_item(cosmos_item)
            added_ids.append(doc.id)

        logger.info(
            "Upserted %d document(s) into Cosmos DB container '%s'",
            len(added_ids),
            self._container_name,
        )
        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Search the Cosmos DB container.

        Args:
            query: Plain-text query string. Required for ``keyword`` and
                ``hybrid`` search types.
            query_embedding: Pre-computed query vector. Required for
                ``vector`` and ``hybrid`` search types.
            top_k: Maximum number of results to return.
            filters: Optional filter dict. Supports equality
                (``{"key": value}``) and range (``{"key": (">=", value)}``)
                operators.
            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

        Returns:
            Ranked list of :class:`SearchResult` objects.
        """
        if search_type not in ("vector", "keyword", "hybrid"):
            raise ValueError(
                f"Invalid search_type '{search_type}'. "
                "Must be 'vector', 'keyword', or 'hybrid'."
            )
        if search_type in ("vector", "hybrid") and query_embedding is None:
            raise ValueError(f"search_type='{search_type}' requires query_embedding")
        if search_type in ("keyword", "hybrid") and query is None:
            raise ValueError(f"search_type='{search_type}' requires query text")

        if search_type == "vector":
            return self._vector_search(query_embedding, top_k, filters)
        if search_type == "keyword":
            return self._keyword_search(query, top_k, filters)
        return self._hybrid_search(query, query_embedding, top_k, filters)

    # --- search helpers -----------------------------------------------

    def _vector_search(
        self,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        where_clause = self._build_where_clause(filters)
        where_sql = f"WHERE {where_clause}" if where_clause else ""

        query_text = (
            f"SELECT TOP @top_k c.id, c.document_data, "
            f"VectorDistance(c.embedding, @embedding) AS score "
            f"FROM c {where_sql} "
            f"ORDER BY VectorDistance(c.embedding, @embedding)"
        )
        parameters = [
            {"name": "@top_k", "value": top_k},
            {"name": "@embedding", "value": query_embedding},
        ]

        items = list(self._container.query_items(
            query=query_text,
            parameters=parameters,
            enable_cross_partition_query=True,
        ))

        results = []
        for rank, item in enumerate(items):
            doc = self._from_cosmos(item)
            # VectorDistance with cosine returns distance (0 = identical);
            # convert to similarity score (1 = identical)
            distance = float(item.get("score", 1.0))
            similarity = 1.0 - distance
            results.append(SearchResult(document=doc, score=similarity, rank=rank))
        return results

    def _keyword_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        terms = query.split()
        if not terms:
            # An empty query would otherwise produce an invalid "WHERE ()" clause
            return []

        # Build keyword conditions using CONTAINS for each query term.
        # Terms are passed as query parameters rather than interpolated.
        keyword_conditions = " OR ".join(
            f"CONTAINS(LOWER(c.content), @term{i})" for i in range(len(terms))
        )
        parameters = [
            {"name": f"@term{i}", "value": term.lower()}
            for i, term in enumerate(terms)
        ]
        where_clause = self._build_where_clause(filters)
        filter_sql = f" AND {where_clause}" if where_clause else ""

        # Fetch ALL matching documents; scoring and top_k slicing happen in Python.
        # Applying SELECT TOP before Python scoring would miss high-scoring documents
        # if Cosmos DB returns lower-scoring ones first (order without ORDER BY is undefined).
        query_text = (
            f"SELECT c.id, c.document_data, c.content "
            f"FROM c WHERE ({keyword_conditions}){filter_sql}"
        )

        items = list(self._container.query_items(
            query=query_text,
            parameters=parameters,
            enable_cross_partition_query=True,
        ))

        # Score by the fraction of query terms found in the content
        # (simple term overlap, not BM25)
        results = []
        for item in items:
            content_lower = item.get("content", "").lower()
            match_count = sum(1 for t in terms if t.lower() in content_lower)
            score = match_count / len(terms)
            results.append(SearchResult(
                document=self._from_cosmos(item),
                score=score,
                rank=0,
            ))

        results.sort(key=lambda r: r.score, reverse=True)
        for rank, r in enumerate(results):
            r.rank = rank
        return results[:top_k]

    def _hybrid_search(
        self,
        query: str,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Combine vector and keyword results with equal (50/50) weighting."""
        vector_results = self._vector_search(query_embedding, top_k, filters)
        keyword_results = self._keyword_search(query, top_k, filters)

        def _max(results: List[SearchResult]) -> float:
            return max((r.score for r in results), default=1.0) or 1.0

        v_max = _max(vector_results)
        k_max = _max(keyword_results)

        combined: Dict[str, float] = {}
        doc_map: Dict[str, Document] = {}

        for r in vector_results:
            combined[r.document.id] = 0.5 * (r.score / v_max)
            doc_map[r.document.id] = r.document

        for r in keyword_results:
            combined[r.document.id] = (
                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
            )
            doc_map.setdefault(r.document.id, r.document)

        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
        return [
            SearchResult(
                document=doc_map[doc_id],
                score=combined[doc_id],
                rank=rank,
            )
            for rank, doc_id in enumerate(sorted_ids)
        ]

    # ------------------------------------------------------------------
    # Document lifecycle
    # ------------------------------------------------------------------

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by ID.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            Number of documents actually deleted.
        """
        deleted = 0
        for doc_id in document_ids:
            try:
                self._container.delete_item(item=doc_id, partition_key=doc_id)
                deleted += 1
            except Exception:
                logger.debug("Document '%s' not found for deletion", doc_id)
        return deleted

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update.
            content: New text content (optional).
            embedding: New embedding vector (optional).
            metadata: Metadata key-value pairs to *merge* into existing metadata
                (optional).

        Returns:
            ``True`` if the document was found and updated, ``False`` otherwise.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
        except Exception:
            return False

        doc = self._from_cosmos(item)

        if content is not None:
            doc.content = content
        if embedding is not None:
            doc.embedding = embedding
        if metadata is not None:
            doc.metadata.update(metadata)

        self._container.upsert_item(self._to_cosmos(doc))
        return True

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a single document by ID.

        Args:
            document_id: The document's unique identifier.

        Returns:
            The :class:`Document` (or subclass) if found, otherwise ``None``.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
            return self._from_cosmos(item)
        except Exception:
            return None

    def count(self) -> int:
        """Return the total number of documents in the container."""
        query_text = "SELECT VALUE COUNT(1) FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        return items[0] if items else 0

    def clear(self) -> None:
        """
        Remove **all** documents from the container.

        Warning:
            This operation is irreversible.
        """
        query_text = "SELECT c.id FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        for item in items:
            try:
                self._container.delete_item(item=item["id"], partition_key=item["id"])
            except Exception:
                # Best-effort: skip items that cannot be deleted
                pass
        logger.warning(
            "Cleared all documents from Cosmos DB container '%s'",
            self._container_name,
        )

    def close(self) -> None:
        """Close the underlying Cosmos DB client."""
        # azure-cosmos CosmosClient doesn't require explicit close,
        # but we provide the method for interface consistency
        pass
Azure Cosmos DB NoSQL API vector store.
Stores documents in a Cosmos DB NoSQL container and queries them using Cosmos DB's vector distance functions and SQL-like query language.
Features:
- Vector search using Cosmos DB vector indexing (cosine distance)
- Full-text keyword search via SQL CONTAINS / LIKE queries
- Hybrid search (vector + keyword, equal weighting)
- Metadata filtering via SQL WHERE clauses
- Custom document_type support (any Document subclass)
- Automatic container creation with vector embedding policy
Usage::
from gmf_forge_ai_data.vector_stores import AzureCosmosDBVectorStore
store = AzureCosmosDBVectorStore(
endpoint="https://your-account.documents.azure.com:443/",
key="your-key",
database_name="rag_db",
container_name="documents",
embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
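The keyword mode listed under Features ranks the CONTAINS matches in Python rather than in the database. A minimal sketch of that scoring idea, assuming the score is simply the fraction of query terms found in the document text; `term_match_score` is a hypothetical helper name used only for illustration, not part of the package:

```python
def term_match_score(query: str, content: str) -> float:
    """Fraction of whitespace-separated query terms found in content.

    Matching is case-insensitive substring matching, not tokenised search.
    """
    terms = [t.lower() for t in query.split()]
    if not terms:
        return 0.0
    content_lower = content.lower()
    matched = sum(1 for t in terms if t in content_lower)
    return matched / len(terms)


# Illustrative data only: rank three short texts for one query.
texts = {
    "a": "Vector indexes speed up SEARCH",
    "b": "keyword only search",
    "c": "nothing relevant here",
}
ranked = sorted(
    texts,
    key=lambda doc_id: term_match_score("vector search", texts[doc_id]),
    reverse=True,
)
```

Because this is substring matching, a term such as "art" also matches "party"; a tokenised full-text index would behave differently.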
    def __init__(
        self,
        endpoint: str,
        key: str,
        database_name: str,
        container_name: str,
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the Cosmos DB NoSQL vector store.

        The database and container must be pre-provisioned using
        ``CosmosDBIndexBuilder`` before first use. The constructor only
        establishes client connections; it never creates or modifies
        databases or containers.

        Args:
            endpoint: Cosmos DB account endpoint
                (e.g. ``https://your-account.documents.azure.com:443/``).
            key: Cosmos DB account key or resource token.
            database_name: Name of the Cosmos DB database.
            container_name: Name of the container.
            embedding_dimension: Dimension of vector embeddings (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        # Import lazily so the package can be used without azure-cosmos installed
        try:
            from azure.cosmos import CosmosClient
        except ImportError as exc:
            raise ImportError(
                "azure-cosmos is required for AzureCosmosDBVectorStore. "
                "Install it with: pip install azure-cosmos"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self._container_name = container_name

        # Build client with optional custom CA bundle; the path is passed
        # straight to the client via connection_verify
        connection_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            connection_kwargs["connection_verify"] = ssl_cert_path

        self._client = CosmosClient(endpoint, credential=key, **connection_kwargs)

        # Database and container are pre-provisioned by the developer
        self._database = self._client.get_database_client(database_name)
        self._container = self._database.get_container_client(container_name)
Initialise the Cosmos DB NoSQL vector store.
The database and container must be pre-provisioned using
CosmosDBIndexBuilder before first use. The constructor only
establishes client connections — it never creates or modifies
databases or containers.
Args:
endpoint: Cosmos DB account endpoint
(e.g. https://your-account.documents.azure.com:443/).
key: Cosmos DB account key or resource token.
database_name: Name of the Cosmos DB database.
container_name: Name of the container.
embedding_dimension: Dimension of vector embeddings (default 1536).
document_type: Document class to use when deserialising results.
Pass a custom subclass to support additional typed fields.
ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
verification. Useful in corporate environments with custom
certificate authorities.
    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Upsert documents into the Cosmos DB container.

        Args:
            documents: Documents to add. Each must have ``id`` and ``content``
                set. If *generate_embeddings* is ``True`` (default) every
                document must also carry a pre-computed ``embedding``.
            generate_embeddings: When ``True`` the method validates that all
                documents already have embeddings rather than generating them
                itself (generation is the caller's responsibility).

        Returns:
            List of document IDs that were successfully upserted.
        """
        if not documents:
            return []

        added_ids: List[str] = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document '{doc.id}' has no embedding. "
                    "Generate embeddings before calling add_documents()."
                )

            cosmos_item = self._to_cosmos(doc)
            self._container.upsert_item(cosmos_item)
            added_ids.append(doc.id)

        logger.info(
            "Upserted %d document(s) into Cosmos DB container '%s'",
            len(added_ids),
            self._container_name,
        )
        return added_ids
Upsert documents into the Cosmos DB container.
Args:
documents: Documents to add. Each must have id and content
set. If generate_embeddings is True (default) every
document must also carry a pre-computed embedding.
generate_embeddings: When True the method validates that all
documents already have embeddings rather than generating them
itself (generation is the caller's responsibility).
Returns: List of document IDs that were successfully upserted.
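Despite its name, generate_embeddings does not trigger embedding generation; it only switches on a check that embeddings are already present. The sketch below illustrates that contract in isolation. `Doc` and `validate_for_upsert` are hypothetical stand-ins written for this example, not the package's Document class or any of its methods:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Doc:
    """Hypothetical stand-in for the package's Document."""
    id: str
    content: str
    embedding: Optional[List[float]] = None


def validate_for_upsert(docs: List[Doc], generate_embeddings: bool = True) -> List[str]:
    """Return the IDs of documents that pass the checks; raise ValueError otherwise."""
    ids: List[str] = []
    for doc in docs:
        # Both id and content must be non-empty
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")
        # With the flag on, a missing embedding is an error, not a trigger
        if generate_embeddings and doc.embedding is None:
            raise ValueError(f"Document '{doc.id}' has no embedding.")
        ids.append(doc.id)
    return ids


with_vector = Doc(id="1", content="hello", embedding=[0.1, 0.2])
without_vector = Doc(id="2", content="world")

ok_ids = validate_for_upsert([with_vector])
# Passing generate_embeddings=False skips the embedding check entirely
lenient_ids = validate_for_upsert([with_vector, without_vector], generate_embeddings=False)
```

Note that validation happens per document inside the loop, so in the store itself documents preceding an invalid one may already have been written when the error is raised.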
    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Search the Cosmos DB container.

        Args:
            query: Plain-text query string. Required for ``keyword`` and
                ``hybrid`` search types.
            query_embedding: Pre-computed query vector. Required for
                ``vector`` and ``hybrid`` search types.
            top_k: Maximum number of results to return.
            filters: Optional filter dict. Supports equality
                (``{"key": value}``) and range (``{"key": (">=", value)}``)
                operators.
            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

        Returns:
            Ranked list of :class:`SearchResult` objects.
        """
        if search_type not in ("vector", "keyword", "hybrid"):
            raise ValueError(
                f"Invalid search_type '{search_type}'. "
                "Must be 'vector', 'keyword', or 'hybrid'."
            )
        if search_type in ("vector", "hybrid") and query_embedding is None:
            raise ValueError(f"search_type='{search_type}' requires query_embedding")
        if search_type in ("keyword", "hybrid") and query is None:
            raise ValueError(f"search_type='{search_type}' requires query text")

        if search_type == "vector":
            return self._vector_search(query_embedding, top_k, filters)
        if search_type == "keyword":
            return self._keyword_search(query, top_k, filters)
        return self._hybrid_search(query, query_embedding, top_k, filters)
Search the Cosmos DB container.
Args:
query: Plain-text query string. Required for keyword and
hybrid search types.
query_embedding: Pre-computed query vector. Required for
vector and hybrid search types.
top_k: Maximum number of results to return.
filters: Optional filter dict. Supports equality
({"key": value}) and range ({"key": (">=", value)})
operators.
search_type: One of "vector", "keyword", or "hybrid".
Returns:
Ranked list of SearchResult objects.
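To make the filters format concrete, here is one way such a dict could be translated into a parameterised SQL WHERE fragment. This is only a sketch under stated assumptions: `build_where` is a hypothetical helper, the filterable fields are assumed to live at the top level of each item (`c.<key>`), and keys are assumed to be trusted identifiers. The store's own `_build_where_clause` is not shown in this section and may work differently.

```python
from typing import Any, Dict, List, Tuple

# Comparison operators accepted in (operator, value) tuples (assumed set)
ALLOWED_OPS = {"=", "!=", ">", ">=", "<", "<="}


def build_where(filters: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]:
    """Translate a filter dict into a SQL fragment plus name/value parameters."""
    clauses: List[str] = []
    params: List[Dict[str, Any]] = []
    for i, (key, value) in enumerate(filters.items()):
        if isinstance(value, tuple) and len(value) == 2:
            op, operand = value           # range form: {"key": (">=", value)}
            if op not in ALLOWED_OPS:
                raise ValueError(f"Unsupported operator: {op!r}")
        else:
            op, operand = "=", value      # equality form: {"key": value}
        name = f"@p{i}"
        clauses.append(f"c.{key} {op} {name}")
        params.append({"name": name, "value": operand})
    return " AND ".join(clauses), params


where_sql, where_params = build_where(
    {"category": "physics", "published_year": (">=", 2020)}
)
```

Values travel as parameters rather than being formatted into the SQL string, which avoids quoting problems when a filter value contains an apostrophe.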
    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by ID.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            Number of documents actually deleted.
        """
        deleted = 0
        for doc_id in document_ids:
            try:
                self._container.delete_item(item=doc_id, partition_key=doc_id)
                deleted += 1
            except Exception:
                logger.debug("Document '%s' not found for deletion", doc_id)
        return deleted
Delete documents by ID.
Args: document_ids: IDs of the documents to remove.
Returns: Number of documents actually deleted.
    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update.
            content: New text content (optional).
            embedding: New embedding vector (optional).
            metadata: Metadata key-value pairs to *merge* into existing metadata
                (optional).

        Returns:
            ``True`` if the document was found and updated, ``False`` otherwise.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
        except Exception:
            return False

        doc = self._from_cosmos(item)

        if content is not None:
            doc.content = content
        if embedding is not None:
            doc.embedding = embedding
        if metadata is not None:
            doc.metadata.update(metadata)

        self._container.upsert_item(self._to_cosmos(doc))
        return True
Update an existing document.
Args:
document_id: ID of the document to update.
content: New text content (optional).
embedding: New embedding vector (optional).
metadata: Metadata key-value pairs to merge into existing metadata
(optional).
Returns:
True if the document was found and updated, False otherwise.
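The metadata argument is merged rather than replaced: the source applies it with `dict.update`, so supplied keys overwrite existing ones and all other keys are kept. A small illustration with plain dicts (`merge_metadata` is a hypothetical name used only here):

```python
from typing import Any, Dict


def merge_metadata(existing: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of existing with patch applied via dict.update semantics."""
    merged = dict(existing)   # copy so the caller's dict is not mutated
    merged.update(patch)      # shallow merge: nested dicts are replaced, not merged
    return merged


stored = {"source": "wiki", "lang": "en"}
merged = merge_metadata(stored, {"lang": "de", "reviewed": True})
```

The merge is shallow, and there is no way to delete a key through this path; passing `None` as a value stores `None` under that key.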
    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a single document by ID.

        Args:
            document_id: The document's unique identifier.

        Returns:
            The :class:`Document` (or subclass) if found, otherwise ``None``.
        """
        try:
            item = self._container.read_item(item=document_id, partition_key=document_id)
            return self._from_cosmos(item)
        except Exception:
            return None
Retrieve a single document by ID.
Args: document_id: The document's unique identifier.
Returns:
The Document (or subclass) if found, otherwise None.
    def count(self) -> int:
        """Return the total number of documents in the container."""
        query_text = "SELECT VALUE COUNT(1) FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        return items[0] if items else 0
Return the total number of documents in the container.
    def clear(self) -> None:
        """
        Remove **all** documents from the container.

        Warning:
            This operation is irreversible.
        """
        query_text = "SELECT c.id FROM c"
        items = list(self._container.query_items(
            query=query_text,
            enable_cross_partition_query=True,
        ))
        for item in items:
            try:
                self._container.delete_item(item=item["id"], partition_key=item["id"])
            except Exception:
                pass
        logger.warning(
            "Cleared all documents from Cosmos DB container '%s'",
            self._container_name,
        )
Remove all documents from the container.
Warning: This operation is irreversible.
class MongoDBVectorStore(BaseVectorStore):
    """
    MongoDB Atlas Vector Search vector store.

    Stores documents in a MongoDB collection and queries them using the
    ``$vectorSearch`` aggregation stage. Hybrid search is implemented by
    running vector and text searches separately and combining their normalised
    scores with a 50 / 50 weight.

    Features:
    - Atlas Vector Search (approximate KNN, configurable candidates)
    - Full-text keyword search via ``$text`` index
    - Hybrid search (vector + keyword, equal weighting)
    - Metadata pre-filtering on vector search
    - Custom ``document_type`` support (any Document subclass)

    Note:
        The Atlas Vector Search index (``$vectorSearch``) must be created
        outside of this class, through the MongoDB Atlas UI, Atlas CLI, or
        Atlas Admin API. The ``vector_index_name`` parameter must match the
        name you gave that index.

    Usage::

        from gmf_forge_ai_data.vector_stores import MongoDBVectorStore

        store = MongoDBVectorStore(
            connection_string="mongodb+srv://...",
            database_name="rag_db",
            collection_name="documents",
            vector_index_name="vector_index",
            embedding_dimension=1536,
        )
        store.add_documents(docs_with_embeddings)
        results = store.search(query_embedding=embedding, top_k=5)
    """

    def __init__(
        self,
        connection_string: str,
        database_name: str,
        collection_name: str,
        vector_index_name: str = "vector_index",
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the MongoDB Atlas vector store.

        Args:
            connection_string: MongoDB Atlas connection string
                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
            database_name: Name of the MongoDB database.
            collection_name: Name of the collection.
            vector_index_name: Name of the Atlas Vector Search index defined
                on this collection. Defaults to ``"vector_index"``.
            embedding_dimension: Dimension of vector embeddings. Must match
                the dimension configured in the Atlas Vector Search index
                (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        try:
            import pymongo  # noqa: F401 (checked at construction time)
        except ImportError as exc:
            raise ImportError(
                "pymongo is required for MongoDBVectorStore. "
                "Install it with: pip install pymongo"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_index_name = vector_index_name
        self._collection_name = collection_name

        client_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            client_kwargs["tlsCAFile"] = ssl_cert_path

        import pymongo

        self._client: pymongo.MongoClient = pymongo.MongoClient(
            connection_string, **client_kwargs
        )
        self._db = self._client[database_name]
        self._collection = self._db[collection_name]

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_value(obj: Any) -> Any:
        """Recursively convert datetime objects to ISO strings."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, dict):
            return {k: MongoDBVectorStore._serialize_value(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [MongoDBVectorStore._serialize_value(i) for i in obj]
        return obj

    def _to_mongo(self, doc: Document) -> Dict[str, Any]:
        """Convert a Document to a MongoDB document dict."""
        doc_dict = doc.to_dict()
        doc_dict_serialisable = self._serialize_value(doc_dict)
        mongo_doc = {
            "_id": doc.id,
            "id": doc.id,
            "content": doc.content,
            "embedding": doc.embedding or [],
            "timestamp": doc.timestamp.isoformat() if doc.timestamp else None,
            "metadata": doc.metadata,
            "document_data": json.dumps(doc_dict_serialisable),
        }
        # Promote custom fields (e.g. category, published_year, field, institution)
        # to top level so they are filterable via $vectorSearch filter and $match.
        base_keys = {"id", "content", "embedding", "timestamp", "metadata"}
        for key, value in doc_dict_serialisable.items():
            if key not in base_keys:
                mongo_doc[key] = value
        return mongo_doc

    def _from_mongo(self, mongo_doc: Dict[str, Any]) -> Document:
        """Reconstruct a Document (or subclass) from a MongoDB document."""
        doc_data = json.loads(mongo_doc.get("document_data", "{}"))
        return self.document_type.from_dict(doc_data)

    # ------------------------------------------------------------------
    # Filter helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Translate a user-facing filter dict to a MongoDB query expression."""
        if not filters:
            return None
        op_map = {">=": "$gte", "<=": "$lte", ">": "$gt", "<": "$lt", "!=": "$ne"}
        match: Dict[str, Any] = {}
        for key, value in filters.items():
            if isinstance(value, tuple) and len(value) == 2:
                op, val = value
                match[key] = {op_map[op]: val} if op in op_map else val
            else:
                match[key] = value
        return match

    # ------------------------------------------------------------------
    # BaseVectorStore interface
    # ------------------------------------------------------------------

    def add_documents(
        self,
        documents: List[Document],
        generate_embeddings: bool = True,
    ) -> List[str]:
        """
        Upsert documents into the MongoDB collection.

        Args:
            documents: Documents to add. Each must have ``id`` and ``content``
                set. If *generate_embeddings* is ``True`` (default) every
                document must also carry a pre-computed ``embedding``.
            generate_embeddings: When ``True`` the method validates that all
                documents already have embeddings rather than generating them
                itself (generation is the caller's responsibility).

        Returns:
            List of document IDs that were successfully upserted.
        """
        if not documents:
            return []

        added_ids: List[str] = []

        for doc in documents:
            if not doc.id or not doc.content:
                raise ValueError(f"Document must have id and content: {doc}")

            if generate_embeddings and doc.embedding is None:
                raise ValueError(
                    f"Document '{doc.id}' has no embedding. "
                    "Generate embeddings before calling add_documents()."
                )

            mongo_doc = self._to_mongo(doc)
            self._collection.replace_one(
                {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
            )
            added_ids.append(doc.id)

        logger.info(
            "Upserted %d document(s) into MongoDB collection '%s'",
            len(added_ids),
            self._collection_name,
        )
        return added_ids

    def search(
        self,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        search_type: str = "vector",
    ) -> List[SearchResult]:
        """
        Search the MongoDB collection.

        Args:
            query: Plain-text query string. Required for ``keyword`` and
                ``hybrid`` search types.
            query_embedding: Pre-computed query vector. Required for
                ``vector`` and ``hybrid`` search types.
            top_k: Maximum number of results to return.
            filters: Optional metadata filter dict. Supports equality
                (``{"key": value}``) and range (``{"key": (">=", value)}``)
                operators.
            search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

        Returns:
            Ranked list of :class:`SearchResult` objects.
        """
        if search_type not in ("vector", "keyword", "hybrid"):
            raise ValueError(
                f"Invalid search_type '{search_type}'. "
                "Must be 'vector', 'keyword', or 'hybrid'."
            )
        if search_type in ("vector", "hybrid") and query_embedding is None:
            raise ValueError(f"search_type='{search_type}' requires query_embedding")
        if search_type in ("keyword", "hybrid") and query is None:
            raise ValueError(f"search_type='{search_type}' requires query text")

        if search_type == "vector":
            return self._vector_search(query_embedding, top_k, filters)
        if search_type == "keyword":
            return self._keyword_search(query, top_k, filters)
        return self._hybrid_search(query, query_embedding, top_k, filters)

    # --- search helpers -----------------------------------------------

    def _vector_search(
        self,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Execute an Atlas ``$vectorSearch`` aggregation pipeline."""
        vector_stage: Dict[str, Any] = {
            "index": self.vector_index_name,
            "path": "embedding",
            "queryVector": query_embedding,
            # numCandidates controls recall vs. latency trade-off.
            # A value of top_k * 10 (min 100) is a sensible default.
            "numCandidates": max(top_k * 10, 100),
            "limit": top_k,
        }

        pre_filter = self._build_filter(filters)
        if pre_filter:
            vector_stage["filter"] = pre_filter

        pipeline = [
            {"$vectorSearch": vector_stage},
            {
                "$project": {
                    "_id": 1,
                    "document_data": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        try:
            return [
                SearchResult(
                    document=self._from_mongo(raw),
                    score=float(raw.get("score", 0.0)),
                    rank=rank,
                )
                for rank, raw in enumerate(self._collection.aggregate(pipeline))
            ]
        except Exception as exc:
            msg = str(exc)
            if "localhost:28000" in msg or "HostUnreachable" in msg or "PlanExecutor" in msg:
                raise RuntimeError(
                    f"Atlas Vector Search index '{self.vector_index_name}' is not available. "
                    "Create it in the Atlas UI (collection → Search Indexes → Create Search Index → "
                    "Atlas Vector Search) with field='embedding', "
                    f"numDimensions={self.embedding_dimension}, similarity=cosine, "
                    f"indexName='{self.vector_index_name}'."
                ) from exc
            raise

    def _keyword_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Execute a full-text search via the MongoDB ``$text`` operator.

        Requires a text index on the ``content`` field to be created by the
        developer before use::

            db.collection.createIndex({ "content": "text" })

        or via the Atlas UI / CLI / Admin API.
        """
        mongo_filter: Dict[str, Any] = {"$text": {"$search": query}}
        pre_filter = self._build_filter(filters)
        if pre_filter:
            mongo_filter.update(pre_filter)

        cursor = (
            self._collection.find(
                mongo_filter,
                {"document_data": 1, "textScore": {"$meta": "textScore"}},
            )
            .sort([("textScore", {"$meta": "textScore"})])
            .limit(top_k)
        )

        try:
            return [
                SearchResult(
                    document=self._from_mongo(raw),
                    score=float(raw.get("textScore", 0.0)),
                    rank=rank,
                )
                for rank, raw in enumerate(cursor)
            ]
        except Exception as exc:
            msg = str(exc)
            if "text index required" in msg.lower() or "$text" in msg:
                raise RuntimeError(
                    "No text index found on the collection. "
                    "Create one before using keyword or hybrid search: "
                    f"db.{self._collection.name}.createIndex({{ 'content': 'text' }})"
                ) from exc
            raise

    def _hybrid_search(
        self,
        query: str,
        query_embedding: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchResult]:
        """Combine vector and keyword results with equal (50/50) weighting."""
        vector_results = self._vector_search(query_embedding, top_k, filters)
        keyword_results = self._keyword_search(query, top_k, filters)

        def _max(results: List[SearchResult]) -> float:
            return max((r.score for r in results), default=1.0) or 1.0

        v_max = _max(vector_results)
        k_max = _max(keyword_results)

        combined: Dict[str, float] = {}
        doc_map: Dict[str, Document] = {}

        for r in vector_results:
            combined[r.document.id] = 0.5 * (r.score / v_max)
            doc_map[r.document.id] = r.document

        for r in keyword_results:
            combined[r.document.id] = (
                combined.get(r.document.id, 0.0) + 0.5 * (r.score / k_max)
            )
            doc_map.setdefault(r.document.id, r.document)

        sorted_ids = sorted(combined, key=lambda x: combined[x], reverse=True)[:top_k]
        return [
            SearchResult(
                document=doc_map[doc_id],
                score=combined[doc_id],
                rank=rank,
            )
            for rank, doc_id in enumerate(sorted_ids)
        ]

    # ------------------------------------------------------------------
    # Document lifecycle
    # ------------------------------------------------------------------

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by ID.

        Args:
            document_ids: IDs of the documents to remove.

        Returns:
            Number of documents actually deleted.
        """
        result = self._collection.delete_many({"_id": {"$in": document_ids}})
        return result.deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update an existing document.

        Args:
            document_id: ID of the document to update.
            content: New text content (optional).
            embedding: New embedding vector (optional).
            metadata: Metadata key-value pairs to *merge* into existing metadata
                (optional).

        Returns:
            ``True`` if the document was found and updated, ``False`` otherwise.
        """
        raw = self._collection.find_one({"_id": document_id})
        if raw is None:
            return False

        doc = self._from_mongo(raw)

        if content is not None:
            doc.content = content
        if embedding is not None:
            doc.embedding = embedding
        if metadata is not None:
            doc.metadata.update(metadata)

        self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
        return True

    def get_document(self, document_id: str) -> Optional[Document]:
        """
        Retrieve a single document by ID.

        Args:
            document_id: The document's unique identifier.

        Returns:
            The :class:`Document` (or subclass) if found, otherwise ``None``.
        """
        raw = self._collection.find_one({"_id": document_id})
        return None if raw is None else self._from_mongo(raw)

    def count(self) -> int:
        """Return the total number of documents in the collection."""
        return self._collection.count_documents({})

    def clear(self) -> None:
        """
        Remove **all** documents from the collection.

        Warning:
            This operation is irreversible.
        """
        self._collection.delete_many({})
        logger.warning(
            "Cleared all documents from MongoDB collection '%s'",
            self._collection_name,
        )

    def close(self) -> None:
        """Close the underlying MongoDB client connection."""
        self._client.close()
MongoDB Atlas Vector Search vector store.
Stores documents in a MongoDB collection and queries them using the
$vectorSearch aggregation stage. Hybrid search is implemented by
running vector and text searches separately and combining their normalised
scores with a 50 / 50 weight.
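The normalise-then-average step can be shown on plain dictionaries. The sketch below mirrors the approach in the `_hybrid_search` source (divide each result list by its own maximum score, then weight both halves by 0.5); `fuse_scores` and the sample scores are illustrative names and data, not part of the package:

```python
from typing import Dict


def fuse_scores(
    vector: Dict[str, float],
    keyword: Dict[str, float],
    top_k: int,
) -> Dict[str, float]:
    """Combine two {doc_id: score} maps with max-normalisation and 50/50 weights."""

    def peak(scores: Dict[str, float]) -> float:
        # Fall back to 1.0 for empty input or an all-zero maximum
        return max(scores.values(), default=1.0) or 1.0

    v_max, k_max = peak(vector), peak(keyword)

    combined: Dict[str, float] = {}
    for doc_id, score in vector.items():
        combined[doc_id] = 0.5 * (score / v_max)
    for doc_id, score in keyword.items():
        combined[doc_id] = combined.get(doc_id, 0.0) + 0.5 * (score / k_max)

    ranked = sorted(combined, key=combined.get, reverse=True)[:top_k]
    return {doc_id: combined[doc_id] for doc_id in ranked}


# "b" appears in both lists, so it collects weight from each half.
fused = fuse_scores(
    vector={"a": 0.9, "b": 0.6},
    keyword={"b": 2.0, "c": 1.0},
    top_k=3,
)
```

A document found by only one of the two searches can score at most 0.5, so documents matched by both tend to rise to the top. Since each search is itself limited to top_k results, a document ranked just outside one list contributes nothing from that side.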
Features:
- Atlas Vector Search (approximate KNN, configurable candidates)
- Full-text keyword search via $text index
- Hybrid search (vector + keyword, equal weighting)
- Metadata pre-filtering on vector search
- Custom document_type support (any Document subclass)
Note:
The Atlas Vector Search index ($vectorSearch) must be created
outside of this class — through the MongoDB Atlas UI, Atlas CLI, or
Atlas Admin API. The vector_index_name parameter must match the
name you gave that index.
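For orientation, an Atlas Vector Search index definition matching this store's expectations would look roughly like the dictionary below. The field path `embedding` and cosine similarity are taken from the error message in `_vector_search`; the helper name `vector_index_definition` is hypothetical, and the resulting JSON is meant to be pasted into the Atlas UI or supplied to the Atlas CLI or Admin API rather than being created by this class:

```python
import json
from typing import Any, Dict


def vector_index_definition(
    num_dimensions: int = 1536,
    path: str = "embedding",
    similarity: str = "cosine",
) -> Dict[str, Any]:
    """Build an Atlas Vector Search index definition as a plain dict."""
    return {
        "fields": [
            {
                "type": "vector",
                "path": path,                     # document field holding the vector
                "numDimensions": num_dimensions,  # must equal embedding_dimension
                "similarity": similarity,
            }
        ]
    }


definition = vector_index_definition()
definition_json = json.dumps(definition, indent=2)
print(definition_json)
```

Fields used in the `filters` argument of a vector search generally need their own entries of type `filter` in the same `fields` list; which fields to add depends on your documents.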
Usage::
from gmf_forge_ai_data.vector_stores import MongoDBVectorStore
store = MongoDBVectorStore(
connection_string="mongodb+srv://...",
database_name="rag_db",
collection_name="documents",
vector_index_name="vector_index",
embedding_dimension=1536,
)
store.add_documents(docs_with_embeddings)
results = store.search(query_embedding=embedding, top_k=5)
    def __init__(
        self,
        connection_string: str,
        database_name: str,
        collection_name: str,
        vector_index_name: str = "vector_index",
        embedding_dimension: int = 1536,
        document_type: type = Document,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        """
        Initialise the MongoDB Atlas vector store.

        Args:
            connection_string: MongoDB Atlas connection string
                (e.g. ``mongodb+srv://user:pass@cluster.mongodb.net/``).
            database_name: Name of the MongoDB database.
            collection_name: Name of the collection.
            vector_index_name: Name of the Atlas Vector Search index defined
                on this collection. Defaults to ``"vector_index"``.
            embedding_dimension: Dimension of vector embeddings. Must match
                the dimension configured in the Atlas Vector Search index
                (default 1536).
            document_type: Document class to use when deserialising results.
                Pass a custom subclass to support additional typed fields.
            ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
                verification. Useful in corporate environments with custom
                certificate authorities.
        """
        try:
            import pymongo  # noqa: F401 (checked at construction time)
        except ImportError as exc:
            raise ImportError(
                "pymongo is required for MongoDBVectorStore. "
                "Install it with: pip install pymongo"
            ) from exc

        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_index_name = vector_index_name
        self._collection_name = collection_name

        client_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            client_kwargs["tlsCAFile"] = ssl_cert_path

        import pymongo

        self._client: pymongo.MongoClient = pymongo.MongoClient(
            connection_string, **client_kwargs
        )
        self._db = self._client[database_name]
        self._collection = self._db[collection_name]
Initialise the MongoDB Atlas vector store.
Args:
connection_string: MongoDB Atlas connection string
(e.g. mongodb+srv://user:pass@cluster.mongodb.net/).
database_name: Name of the MongoDB database.
collection_name: Name of the collection.
vector_index_name: Name of the Atlas Vector Search index defined
on this collection. Defaults to "vector_index".
embedding_dimension: Dimension of vector embeddings. Must match
the dimension configured in the Atlas Vector Search index
(default 1536).
document_type: Document class to use when deserialising results.
Pass a custom subclass to support additional typed fields.
ssl_cert_path: Path to a CA certificate bundle (PEM) for TLS
verification. Useful in corporate environments with custom
certificate authorities.
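Because `embedding_dimension` must agree with the Atlas Vector Search index, it may help to see what such an index definition can look like. The sketch below only builds the definition as a plain dictionary. The field path `embedding`, the `cosine` similarity setting, and the helper name `build_vector_index_definition` are assumptions for illustration; the source does not show how this package names its vector field or creates the index.

```python
from typing import Any, Dict


def build_vector_index_definition(
    embedding_dimension: int = 1536,
    path: str = "embedding",     # assumed name of the field holding the vector
    similarity: str = "cosine",  # assumed; Atlas also accepts "euclidean" and "dotProduct"
) -> Dict[str, Any]:
    """Return an Atlas Vector Search index definition as a plain dict."""
    if embedding_dimension < 1:
        raise ValueError("embedding_dimension must be positive")
    return {
        "fields": [
            {
                "type": "vector",
                "path": path,
                "numDimensions": embedding_dimension,
                "similarity": similarity,
            }
        ]
    }


definition = build_vector_index_definition(1536)
print(definition["fields"][0]["numDimensions"])
```

If the store is constructed with a different `embedding_dimension`, the `numDimensions` value in the index would need to change to match.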
def add_documents(
    self,
    documents: List[Document],
    generate_embeddings: bool = True,
) -> List[str]:
    """
    Upsert documents into the MongoDB collection.

    Args:
        documents: Documents to add. Each must have ``id`` and ``content``
            set. If *generate_embeddings* is ``True`` (default) every
            document must also carry a pre-computed ``embedding``.
        generate_embeddings: When ``True`` the method validates that all
            documents already have embeddings rather than generating them
            itself (generation is the caller's responsibility).

    Returns:
        List of document IDs that were successfully upserted.
    """
    if not documents:
        return []

    added_ids: List[str] = []

    for doc in documents:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")

        if generate_embeddings and doc.embedding is None:
            raise ValueError(
                f"Document '{doc.id}' has no embedding. "
                "Generate embeddings before calling add_documents()."
            )

        mongo_doc = self._to_mongo(doc)
        self._collection.replace_one(
            {"_id": mongo_doc["_id"]}, mongo_doc, upsert=True
        )
        added_ids.append(doc.id)

    logger.info(
        "Upserted %d document(s) into MongoDB collection '%s'",
        len(added_ids),
        self._collection_name,
    )
    return added_ids
Upsert documents into the MongoDB collection.
Args:
documents: Documents to add. Each must have id and content
set. If generate_embeddings is True (default) every
document must also carry a pre-computed embedding.
generate_embeddings: When True the method validates that all
documents already have embeddings rather than generating them
itself (generation is the caller's responsibility).
Returns: List of document IDs that were successfully upserted.
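The checks described above can be tried without a database. The following is a minimal sketch, assuming a stand-in `Doc` dataclass in place of the package's `Document` and a hypothetical helper `validate_for_upsert`; it mirrors the documented validation only and performs no upsert.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Doc:  # stand-in for the package's Document class (assumption)
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def validate_for_upsert(docs: List[Doc], require_embeddings: bool = True) -> List[str]:
    """Apply the documented checks and return the IDs that would be upserted."""
    ids: List[str] = []
    for doc in docs:
        if not doc.id or not doc.content:
            raise ValueError(f"Document must have id and content: {doc}")
        if require_embeddings and doc.embedding is None:
            raise ValueError(f"Document '{doc.id}' has no embedding.")
        ids.append(doc.id)
    return ids


docs = [Doc(id="a", content="alpha", embedding=[0.1, 0.2])]
print(validate_for_upsert(docs))
```

Note that the flag never triggers embedding generation; with `require_embeddings=False` (the analogue of `generate_embeddings=False`) documents without vectors simply pass the check.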
def search(
    self,
    query: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None,
    search_type: str = "vector",
) -> List[SearchResult]:
    """
    Search the MongoDB collection.

    Args:
        query: Plain-text query string. Required for ``keyword`` and
            ``hybrid`` search types.
        query_embedding: Pre-computed query vector. Required for
            ``vector`` and ``hybrid`` search types.
        top_k: Maximum number of results to return.
        filters: Optional metadata filter dict. Supports equality
            (``{"key": value}``) and range (``{"key": (">=", value)}``)
            operators.
        search_type: One of ``"vector"``, ``"keyword"``, or ``"hybrid"``.

    Returns:
        Ranked list of :class:`SearchResult` objects.
    """
    if search_type not in ("vector", "keyword", "hybrid"):
        raise ValueError(
            f"Invalid search_type '{search_type}'. "
            "Must be 'vector', 'keyword', or 'hybrid'."
        )
    if search_type in ("vector", "hybrid") and query_embedding is None:
        raise ValueError(f"search_type='{search_type}' requires query_embedding")
    if search_type in ("keyword", "hybrid") and query is None:
        raise ValueError(f"search_type='{search_type}' requires query text")

    if search_type == "vector":
        return self._vector_search(query_embedding, top_k, filters)
    if search_type == "keyword":
        return self._keyword_search(query, top_k, filters)
    return self._hybrid_search(query, query_embedding, top_k, filters)
Search the MongoDB collection.
Args:
query: Plain-text query string. Required for keyword and
hybrid search types.
query_embedding: Pre-computed query vector. Required for
vector and hybrid search types.
top_k: Maximum number of results to return.
filters: Optional metadata filter dict. Supports equality
({"key": value}) and range ({"key": (">=", value)})
operators.
search_type: One of "vector", "keyword", or "hybrid".
Returns:
Ranked list of SearchResult objects.
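The `filters` argument accepts equality values and `(operator, value)` tuples. How the store turns these into a MongoDB query is not shown in this section, so the following is only a sketch of one plausible translation. The helper name `to_mongo_filter`, the `metadata.` field prefix, and the set of supported comparison symbols are all assumptions.

```python
from typing import Any, Dict

# Comparison symbols mapped to MongoDB query operators.
_RANGE_OPERATORS = {">": "$gt", ">=": "$gte", "<": "$lt", "<=": "$lte"}


def to_mongo_filter(filters: Dict[str, Any], prefix: str = "metadata.") -> Dict[str, Any]:
    """Translate equality and (operator, value) range filters into a MongoDB query dict."""
    query: Dict[str, Any] = {}
    for key, condition in filters.items():
        field_path = f"{prefix}{key}"  # assumed location of metadata fields
        if isinstance(condition, tuple) and len(condition) == 2:
            symbol, value = condition
            if symbol not in _RANGE_OPERATORS:
                raise ValueError(f"Unsupported operator: {symbol!r}")
            query[field_path] = {_RANGE_OPERATORS[symbol]: value}
        else:
            # Anything that is not a 2-tuple is treated as an equality match.
            query[field_path] = condition
    return query


print(to_mongo_filter({"source": "wiki", "year": (">=", 2020)}))
```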
def delete_documents(self, document_ids: List[str]) -> int:
    """
    Delete documents by ID.

    Args:
        document_ids: IDs of the documents to remove.

    Returns:
        Number of documents actually deleted.
    """
    result = self._collection.delete_many({"_id": {"$in": document_ids}})
    return result.deleted_count
Delete documents by ID.
Args: document_ids: IDs of the documents to remove.
Returns: Number of documents actually deleted.
def update_document(
    self,
    document_id: str,
    content: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Update an existing document.

    Args:
        document_id: ID of the document to update.
        content: New text content (optional).
        embedding: New embedding vector (optional).
        metadata: Metadata key-value pairs to *merge* into existing metadata
            (optional).

    Returns:
        ``True`` if the document was found and updated, ``False`` otherwise.
    """
    raw = self._collection.find_one({"_id": document_id})
    if raw is None:
        return False

    doc = self._from_mongo(raw)

    if content is not None:
        doc.content = content
    if embedding is not None:
        doc.embedding = embedding
    if metadata is not None:
        doc.metadata.update(metadata)

    self._collection.replace_one({"_id": document_id}, self._to_mongo(doc))
    return True
Update an existing document.
Args:
document_id: ID of the document to update.
content: New text content (optional).
embedding: New embedding vector (optional).
metadata: Metadata key-value pairs to merge into existing metadata (optional).
Returns:
True if the document was found and updated, False otherwise.
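The source merges metadata with `dict.update`, which is a shallow merge: top-level keys are added or overwritten, and nested dictionaries are replaced rather than merged. A small illustration with made-up metadata values:

```python
existing = {"source": "wiki", "tags": {"lang": "en", "reviewed": False}}
patch = {"tags": {"reviewed": True}, "version": 2}

merged = dict(existing)  # copy so the original dict is left untouched
merged.update(patch)     # shallow merge: top-level keys are added or overwritten

print(merged)
```

Here the nested `lang` entry is lost because the whole `tags` value is replaced; callers who need to keep it would have to pass the complete nested dictionary.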
def get_document(self, document_id: str) -> Optional[Document]:
    """
    Retrieve a single document by ID.

    Args:
        document_id: The document's unique identifier.

    Returns:
        The :class:`Document` (or subclass) if found, otherwise ``None``.
    """
    raw = self._collection.find_one({"_id": document_id})
    return None if raw is None else self._from_mongo(raw)
Retrieve a single document by ID.
Args: document_id: The document's unique identifier.
Returns:
The Document (or subclass) if found, otherwise None.
def count(self) -> int:
    """Return the total number of documents in the collection."""
    return self._collection.count_documents({})
Return the total number of documents in the collection.
def clear(self) -> None:
    """
    Remove **all** documents from the collection.

    Warning:
        This operation is irreversible.
    """
    self._collection.delete_many({})
    logger.warning(
        "Cleared all documents from MongoDB collection '%s'",
        self._collection_name,
    )
Remove all documents from the collection.
Warning: This operation is irreversible.
class BaseRetriever(ABC):
Abstract base class for all retriever implementations.
Retrievers provide various strategies for finding relevant documents from a vector store or other data source.
All implementations must provide:
- retrieve(): Main retrieval method
@abstractmethod
def retrieve(
    self,
    query: RetrievalQuery
) -> List[SearchResult]:
    """
    Retrieve relevant documents based on the query.

    Args:
        query: RetrievalQuery containing query parameters

    Returns:
        List of SearchResult objects, ordered by relevance

    Raises:
        ValueError: If required query parameters are missing
    """
    pass
Retrieve relevant documents based on the query.
Args: query: RetrievalQuery containing query parameters
Returns: List of SearchResult objects, ordered by relevance
Raises: ValueError: If required query parameters are missing
def retrieve_text(
    self,
    text: str,
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[SearchResult]:
    """
    Convenience method for text-based retrieval.

    Args:
        text: Query text
        top_k: Number of results to retrieve
        filters: Optional metadata filters

    Returns:
        List of SearchResult objects
    """
    query = RetrievalQuery(
        text=text,
        top_k=top_k,
        filters=filters
    )
    return self.retrieve(query)
Convenience method for text-based retrieval.
Args:
text: Query text
top_k: Number of results to retrieve
filters: Optional metadata filters
Returns: List of SearchResult objects
def retrieve_embedding(
    self,
    embedding: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[SearchResult]:
    """
    Convenience method for embedding-based retrieval.

    Args:
        embedding: Query embedding vector
        top_k: Number of results to retrieve
        filters: Optional metadata filters

    Returns:
        List of SearchResult objects
    """
    query = RetrievalQuery(
        embedding=embedding,
        top_k=top_k,
        filters=filters
    )
    return self.retrieve(query)
Convenience method for embedding-based retrieval.
Args:
embedding: Query embedding vector
top_k: Number of results to retrieve
filters: Optional metadata filters
Returns: List of SearchResult objects
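A custom retriever only has to implement `retrieve()`; the convenience methods build the query object and delegate to it. The sketch below shows the pattern with stand-in classes so it runs without the package installed. `Query`, `Retriever`, and `SubstringRetriever` are hypothetical names, and the results are plain strings rather than `SearchResult` objects.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Query:  # stand-in for RetrievalQuery (assumption)
    text: Optional[str] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None


class Retriever(ABC):  # stand-in for BaseRetriever (assumption)
    @abstractmethod
    def retrieve(self, query: Query) -> List[str]:
        ...

    def retrieve_text(self, text: str, top_k: int = 5) -> List[str]:
        # The convenience wrapper only builds the query object and delegates.
        return self.retrieve(Query(text=text, top_k=top_k))


class SubstringRetriever(Retriever):
    """Hypothetical retriever returning documents that contain the query text."""

    def __init__(self, documents: List[str]) -> None:
        self.documents = documents

    def retrieve(self, query: Query) -> List[str]:
        if query.text is None:
            raise ValueError("SubstringRetriever requires query.text")
        needle = query.text.lower()
        matches = [d for d in self.documents if needle in d.lower()]
        return matches[: query.top_k]


retriever = SubstringRetriever(["Vector search", "Keyword search", "Chunking"])
print(retriever.retrieve_text("search", top_k=1))
```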
@dataclass
class RetrievalQuery:
    """
    Represents a retrieval query with optional query text and embedding.

    Attributes:
        text: Query text (required for keyword/hybrid search)
        embedding: Query embedding vector (required for vector search)
        top_k: Number of results to retrieve
        filters: Metadata filters to apply
        metadata: Additional query metadata
    """
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None
Represents a retrieval query with optional query text and embedding.
Attributes:
text: Query text (required for keyword/hybrid search)
embedding: Query embedding vector (required for vector search)
top_k: Number of results to retrieve
filters: Metadata filters to apply
metadata: Additional query metadata
class VectorRetriever(BaseRetriever):
Vector similarity retriever using cosine similarity.
Performs pure vector search using embeddings. Requires query embeddings to be provided externally (e.g., using an embeddings provider).
Example:
from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
# Setup
embedder = AzureOpenAIEmbeddings(...)
retriever = VectorRetriever(vector_store)
# Retrieve
query_text = "What is machine learning?"
query_embedding = embedder.embed_text(query_text)
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
def __init__(self, vector_store: BaseVectorStore):
    """
    Initialize vector retriever.

    Args:
        vector_store: Vector store to search
    """
    self.vector_store = vector_store
Initialize vector retriever.
Args: vector_store: Vector store to search
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
    """
    Retrieve documents using vector similarity search.

    Args:
        query: RetrievalQuery with embedding and parameters

    Returns:
        List of SearchResult objects ordered by similarity

    Raises:
        ValueError: If query.embedding is None
    """
    if query.embedding is None:
        raise ValueError("VectorRetriever requires query.embedding")

    return self.vector_store.search(
        query_embedding=query.embedding,
        top_k=query.top_k,
        filters=query.filters,
        search_type="vector"
    )
Retrieve documents using vector similarity search.
Args: query: RetrievalQuery with embedding and parameters
Returns: List of SearchResult objects ordered by similarity
Raises: ValueError: If query.embedding is None
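The retriever itself only delegates to the store's `search()`; the similarity computation happens inside the vector store. As a rough idea of what a cosine-similarity ranking involves, here is a pure-Python sketch. The function names and the in-memory dictionary of vectors are illustrative assumptions, not the package's implementation.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_cosine(
    query: List[float], vectors: Dict[str, List[float]], top_k: int = 5
) -> List[Tuple[str, float]]:
    """Score every stored vector against the query and keep the best top_k."""
    scored = [(doc_id, cosine_similarity(query, vec)) for doc_id, vec in vectors.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


vectors = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
ranking = top_k_by_cosine([1.0, 0.1], vectors, top_k=2)
print([doc_id for doc_id, _ in ranking])
```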
class KeywordRetriever(BaseRetriever):
Keyword/BM25 retriever using text matching.
Performs traditional keyword-based search without using embeddings. The scoring depends on the backing store: BM25 with Azure AI Search, Jaccard similarity with the in-memory store.
Example:
from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery
retriever = KeywordRetriever(vector_store)
results = retriever.retrieve_text(
    text="machine learning algorithms",
    top_k=5
)
def __init__(self, vector_store: BaseVectorStore):
    """
    Initialize keyword retriever.

    Args:
        vector_store: Vector store to search
    """
    self.vector_store = vector_store
Initialize keyword retriever.
Args: vector_store: Vector store to search
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
    """
    Retrieve documents using keyword search.

    Args:
        query: RetrievalQuery with text and parameters

    Returns:
        List of SearchResult objects ordered by keyword relevance

    Raises:
        ValueError: If query.text is None
    """
    if query.text is None:
        raise ValueError("KeywordRetriever requires query.text")

    return self.vector_store.search(
        query=query.text,
        top_k=query.top_k,
        filters=query.filters,
        search_type="keyword"
    )
Retrieve documents using keyword search.
Args: query: RetrievalQuery with text and parameters
Returns: List of SearchResult objects ordered by keyword relevance
Raises: ValueError: If query.text is None
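For the in-memory case the class description mentions Jaccard similarity. The sketch below shows the general idea on whitespace-separated, lower-cased tokens; the tokenization and the function names are assumptions, since the in-memory store's actual implementation is not shown here.

```python
from typing import Set


def tokenize(text: str) -> Set[str]:
    """Lower-case the text and split on whitespace (a deliberately simple tokenizer)."""
    return set(text.lower().split())


def jaccard_similarity(query: str, document: str) -> float:
    """Size of the token intersection divided by the size of the token union."""
    q, d = tokenize(query), tokenize(document)
    if not q and not d:
        return 0.0
    return len(q & d) / len(q | d)


score = jaccard_similarity("machine learning algorithms", "machine learning basics")
print(score)
```

In this example two of the four distinct tokens are shared, so the score is 0.5.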
class HybridRetriever(BaseRetriever):
Hybrid retriever combining vector and keyword search.
Combines vector similarity with keyword matching for comprehensive retrieval. Scores are normalized and combined (typically 50/50 weighting).
Example:
from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
embedder = AzureOpenAIEmbeddings(...)
retriever = HybridRetriever(vector_store)
query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)
# Hybrid search using both text and embedding
query = RetrievalQuery(
    text=query_text,
    embedding=query_embedding,
    top_k=5
)
results = retriever.retrieve(query)
def __init__(self, vector_store: BaseVectorStore):
    """
    Initialize hybrid retriever.

    Args:
        vector_store: Vector store to search
    """
    self.vector_store = vector_store
Initialize hybrid retriever.
Args: vector_store: Vector store to search
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
    """
    Retrieve documents using hybrid search (vector + keyword).

    Args:
        query: RetrievalQuery with both text and embedding

    Returns:
        List of SearchResult objects ordered by combined relevance

    Raises:
        ValueError: If query.text or query.embedding is None
    """
    if query.text is None:
        raise ValueError("HybridRetriever requires query.text")
    if query.embedding is None:
        raise ValueError("HybridRetriever requires query.embedding")

    return self.vector_store.search(
        query=query.text,
        query_embedding=query.embedding,
        top_k=query.top_k,
        filters=query.filters,
        search_type="hybrid"
    )
Retrieve documents using hybrid search (vector + keyword).
Args: query: RetrievalQuery with both text and embedding
Returns: List of SearchResult objects ordered by combined relevance
Raises: ValueError: If query.text or query.embedding is None
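The class description says scores are normalized and combined, typically with equal weights, but the exact method is left to each vector store. One common way to do this is min-max normalization followed by a weighted sum, sketched below. The function names, the choice of min-max scaling, and the sample scores are assumptions for illustration only.

```python
from typing import Dict


def min_max_normalize(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to the range [0, 1]; identical scores all map to 1.0."""
    if not scores:
        return {}
    low, high = min(scores.values()), max(scores.values())
    if high == low:
        return {doc_id: 1.0 for doc_id in scores}
    return {doc_id: (s - low) / (high - low) for doc_id, s in scores.items()}


def combine_scores(
    vector_scores: Dict[str, float],
    keyword_scores: Dict[str, float],
    vector_weight: float = 0.5,
) -> Dict[str, float]:
    """Weighted sum of normalized scores; a document missing from one list scores 0 there."""
    v = min_max_normalize(vector_scores)
    k = min_max_normalize(keyword_scores)
    return {
        doc_id: vector_weight * v.get(doc_id, 0.0)
        + (1 - vector_weight) * k.get(doc_id, 0.0)
        for doc_id in set(v) | set(k)
    }


vector_scores = {"a": 0.9, "b": 0.5, "c": 0.1}
keyword_scores = {"a": 2.0, "b": 8.0}
combined = combine_scores(vector_scores, keyword_scores)
best = max(combined, key=combined.get)
print(best)
```

Normalizing first matters because keyword scores (for example BM25) and vector similarities live on different scales.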
class MMRRetriever(BaseRetriever):
Maximal Marginal Relevance (MMR) retriever for diverse results.
MMR reranks initial retrieval results to balance relevance with diversity:
- High lambda (λ → 1.0): Prioritize relevance
- Low lambda (λ → 0.0): Prioritize diversity
- λ = 0.5: Balanced (default)
Algorithm:
1. Retrieve initial candidate documents (fetch_k results)
2. Select most relevant document first
3. For each remaining selection:
   - Score = λ * relevance - (1-λ) * max_similarity_to_selected
   - Choose document with highest score
4. Return top_k diverse results
Benefits:
- Reduces redundancy in results
- Improves coverage of different aspects
- Useful for exploration and summarization
Example:
from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
embedder = AzureOpenAIEmbeddings(...)
retriever = MMRRetriever(
    vector_store=vector_store,
    lambda_param=0.5,  # Balanced relevance/diversity
    fetch_k=20         # Retrieve 20 candidates, return top_k diverse
)
query_text = "machine learning algorithms"
query_embedding = embedder.embed_text(query_text)
# Returns 5 diverse results from 20 candidates
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
References: Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based reranking for reordering documents and producing summaries.
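To see how λ shifts the balance in the score formula, it helps to plug in numbers. The short sketch below compares two hypothetical candidates: one highly relevant but nearly identical to an already selected document, and one less relevant but quite different. The scores are made up for illustration.

```python
def mmr_score(relevance: float, max_similarity: float, lambda_param: float) -> float:
    """Score = λ * relevance - (1 - λ) * max_similarity_to_selected."""
    return lambda_param * relevance - (1 - lambda_param) * max_similarity


# Near-duplicate: relevance 0.9, similarity 0.95 to a selected document.
# Distinct:       relevance 0.7, similarity 0.2 to a selected document.
for lam in (0.2, 0.5, 0.9):
    near_duplicate = mmr_score(0.9, 0.95, lam)
    distinct = mmr_score(0.7, 0.2, lam)
    winner = "near-duplicate" if near_duplicate > distinct else "distinct"
    print(f"lambda={lam}: near_duplicate={near_duplicate:.3f}, distinct={distinct:.3f}, picks {winner}")
```

With these numbers the distinct candidate wins at λ = 0.2 and λ = 0.5, while the near-duplicate wins at λ = 0.9, which matches the description of high λ favoring relevance.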
def __init__(
    self,
    vector_store: BaseVectorStore,
    lambda_param: float = 0.5,
    fetch_k: int = 20
):
    """
    Initialize MMR retriever.

    Args:
        vector_store: Vector store to search
        lambda_param: Balance between relevance (1.0) and diversity (0.0)
        fetch_k: Number of initial candidates to fetch (should be >= top_k)

    Raises:
        ValueError: If lambda_param not in [0, 1] or fetch_k < 1
    """
    if not 0 <= lambda_param <= 1:
        raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
    if fetch_k < 1:
        raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")

    self.vector_store = vector_store
    self.lambda_param = lambda_param
    self.fetch_k = fetch_k
Initialize MMR retriever.
Args:
vector_store: Vector store to search
lambda_param: Balance between relevance (1.0) and diversity (0.0)
fetch_k: Number of initial candidates to fetch (should be >= top_k)
Raises: ValueError: If lambda_param not in [0, 1] or fetch_k < 1
def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
    """
    Retrieve diverse documents using MMR reranking.

    Args:
        query: RetrievalQuery with embedding (required)

    Returns:
        List of SearchResult objects with diverse, relevant results

    Raises:
        ValueError: If query.embedding is None
    """
    if query.embedding is None:
        raise ValueError("MMRRetriever requires query.embedding for diversity calculation")

    # Fetch initial candidates (more than top_k)
    fetch_k = max(self.fetch_k, query.top_k)
    candidates = self.vector_store.search(
        query_embedding=query.embedding,
        top_k=fetch_k,
        filters=query.filters,
        search_type="vector"
    )

    if len(candidates) == 0:
        return []

    if len(candidates) <= query.top_k:
        # Not enough candidates for MMR, return as-is
        return candidates[:query.top_k]

    # Extract embeddings and relevance scores
    query_embedding = np.array(query.embedding, dtype=np.float32)
    candidate_embeddings = []
    candidate_scores = []

    for result in candidates:
        if result.document.embedding is None:
            raise ValueError(
                f"Document {result.document.id} has no embedding. "
                "MMR requires all documents to have embeddings."
            )
        candidate_embeddings.append(result.document.embedding)
        candidate_scores.append(result.score)

    candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
    candidate_scores = np.array(candidate_scores, dtype=np.float32)

    # Normalize embeddings for cosine similarity
    query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
    candidate_norms = candidate_embeddings / (
        np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
    )

    # MMR selection
    selected_indices = []
    selected_embeddings = []

    for _ in range(min(query.top_k, len(candidates))):
        if len(selected_indices) == 0:
            # First selection: most relevant
            best_idx = int(np.argmax(candidate_scores))
        else:
            # Subsequent selections: balance relevance and diversity
            mmr_scores = []

            for idx in range(len(candidates)):
                if idx in selected_indices:
                    mmr_scores.append(-np.inf)
                    continue

                # Relevance score (already normalized 0-1 from vector store)
                relevance = candidate_scores[idx]

                # Diversity score: max similarity to selected documents
                doc_embedding = candidate_norms[idx]
                similarities = [
                    np.dot(doc_embedding, selected_embeddings[i])
                    for i in range(len(selected_embeddings))
                ]
                max_similarity = max(similarities)

                # MMR score: λ * relevance - (1-λ) * max_similarity
                mmr_score = (
                    self.lambda_param * relevance -
                    (1 - self.lambda_param) * max_similarity
                )
                mmr_scores.append(mmr_score)

            best_idx = int(np.argmax(mmr_scores))

        selected_indices.append(best_idx)
        selected_embeddings.append(candidate_norms[best_idx])

    # Build results with updated ranks
    mmr_results = []
    for rank, idx in enumerate(selected_indices):
        result = candidates[idx]
        mmr_results.append(SearchResult(
            document=result.document,
            score=result.score,  # Keep original relevance score
            rank=rank
        ))

    return mmr_results
Retrieve diverse documents using MMR reranking.
Args: query: RetrievalQuery with embedding (required)
Returns: List of SearchResult objects with diverse, relevant results
Raises: ValueError: If query.embedding is None, or if any candidate document returned by the vector store has no embedding
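The selection loop can be sketched outside the library. The following is a minimal standard-library illustration, not the package's implementation: `cosine` and `mmr_select` are hypothetical names, and the relevance scores and 2-D embeddings are made-up toy values. It follows the same rule as the source listing: take the most relevant candidate first, then repeatedly pick the candidate that maximizes λ · relevance − (1 − λ) · (max similarity to anything already selected).

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def mmr_select(
    relevance: Sequence[float],
    embeddings: Sequence[Sequence[float]],
    top_k: int,
    lambda_param: float = 0.5,
) -> List[int]:
    """Return candidate indices in MMR selection order (hypothetical helper)."""
    selected: List[int] = []
    remaining = list(range(len(relevance)))

    while remaining and len(selected) < top_k:

        def mmr_score(idx: int) -> float:
            if not selected:
                # First pick: pure relevance
                return relevance[idx]
            # Redundancy: similarity to the closest already-selected item
            redundancy = max(
                cosine(embeddings[idx], embeddings[s]) for s in selected
            )
            return lambda_param * relevance[idx] - (1 - lambda_param) * redundancy

        best = max(remaining, key=mmr_score)
        selected.append(best)
        remaining.remove(best)

    return selected


# Toy data: candidates 0 and 1 are near-duplicates, candidate 2 points elsewhere.
relevance = [0.90, 0.89, 0.60]
embeddings = [[1.0, 0.0], [0.99, 0.05], [0.0, 1.0]]

order = mmr_select(relevance, embeddings, top_k=2, lambda_param=0.5)
print(order)  # [0, 2]: the near-duplicate of candidate 0 is skipped
```

With `lambda_param=1.0` the diversity term vanishes and the same call returns plain relevance order, which is a quick way to check what the parameter is doing on real data.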
class ParentDocumentRetriever(BaseRetriever):
Retriever that searches child chunks but returns parent documents.
This pattern is useful when:
- Documents are chunked into small pieces for precise embedding
- But you want to return full parent documents for better context
- Multiple chunks from the same parent might match the query
Architecture:
- Child store: Contains small chunks with embeddings (searchable)
- Parent store: Contains full parent documents (for retrieval)
- Mapping: Children have metadata["parent_id"] pointing to parent
Workflow:
1. Search child store for relevant chunks
2. Extract parent_ids from child metadata
3. Retrieve parent documents from parent store
4. Deduplicate and rank by best child score
Example:
    from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
    from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
    from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

    embedder = AzureOpenAIEmbeddings(...)

    # Setup stores
    child_store = InMemoryVectorStore()
    parent_store = InMemoryVectorStore()

    # Create parent document
    parent = Document(
        id="doc_1",
        content="Full document with multiple sections...",
        embedding=embedder.embed_text("Full document...")
    )
    parent_store.add_documents([parent])

    # Create child chunks
    child1 = Document(
        id="doc_1_chunk_0",
        content="Introduction section...",
        embedding=embedder.embed_text("Introduction..."),
        metadata={"parent_id": "doc_1"}  # Link to parent
    )
    child2 = Document(
        id="doc_1_chunk_1",
        content="Methods section...",
        embedding=embedder.embed_text("Methods..."),
        metadata={"parent_id": "doc_1"}
    )
    child_store.add_documents([child1, child2])

    # Setup retriever
    retriever = ParentDocumentRetriever(
        child_store=child_store,
        parent_store=parent_store,
        parent_id_key="parent_id"
    )

    # Search chunks, return parent
    query_embedding = embedder.embed_text("introduction")
    results = retriever.retrieve_embedding(
        embedding=query_embedding,
        top_k=5
    )
    # Returns full parent documents, not chunks
Benefits:
- Precise search (small chunks)
- Rich context (full parents)
- Automatic deduplication
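The deduplication step in the workflow can be illustrated without any vector store. This is a minimal sketch under stated assumptions: `collapse_to_parents` is a hypothetical helper, and child hits are modelled as plain `(chunk_id, parent_id, score)` tuples rather than `SearchResult` objects. It sorts parents by best child score explicitly; the library's `retrieve` instead keeps first-occurrence order, which gives the same ranking whenever the child store returns hits sorted by descending score.

```python
from typing import Dict, List, Optional, Tuple

# (chunk_id, parent_id, score); parent_id is None for chunks with no parent link
ChildHit = Tuple[str, Optional[str], float]


def collapse_to_parents(
    child_hits: List[ChildHit], top_k: int
) -> List[Tuple[str, float]]:
    """Map child hits to unique parents, scored by their best child (hypothetical helper)."""
    best: Dict[str, float] = {}
    for _chunk_id, parent_id, score in child_hits:
        if parent_id is None:
            continue  # chunks without a parent link are skipped
        if parent_id not in best or score > best[parent_id]:
            best[parent_id] = score
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


hits: List[ChildHit] = [
    ("doc_1_chunk_0", "doc_1", 0.91),
    ("doc_2_chunk_3", "doc_2", 0.84),
    ("doc_1_chunk_1", "doc_1", 0.78),  # second hit for doc_1, lower score
    ("orphan_chunk", None, 0.70),      # no parent link
]

parents = collapse_to_parents(hits, top_k=5)
print(parents)  # [('doc_1', 0.91), ('doc_2', 0.84)]
```

Four child hits collapse to two parents, which is why the retriever fetches more child candidates than `top_k` before deduplicating.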
    def __init__(
        self,
        child_store: BaseVectorStore,
        parent_store: BaseVectorStore,
        parent_id_key: str = "parent_id",
        search_type: str = "vector"
    ):
        """
        Initialize parent document retriever.

        Args:
            child_store: Vector store containing child chunks with embeddings
            parent_store: Vector store containing parent documents
            parent_id_key: Metadata key in children pointing to parent ID
            search_type: Search type for child store ("vector", "keyword", "hybrid")

        Raises:
            ValueError: If search_type is invalid
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        self.child_store = child_store
        self.parent_store = parent_store
        self.parent_id_key = parent_id_key
        self.search_type = search_type
Initialize parent document retriever.
Args:
    child_store: Vector store containing child chunks with embeddings
    parent_store: Vector store containing parent documents
    parent_id_key: Metadata key in children pointing to parent ID
    search_type: Search type for child store ("vector", "keyword", "hybrid")
Raises: ValueError: If search_type is invalid
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve parent documents by searching child chunks.

        Args:
            query: RetrievalQuery with appropriate parameters for search_type

        Returns:
            List of SearchResult with parent documents, deduplicated and ranked

        Raises:
            ValueError: If required query parameters are missing
        """
        # Validate query based on search type
        if self.search_type in ["vector", "hybrid"] and query.embedding is None:
            raise ValueError(f"{self.search_type} search requires query.embedding")
        if self.search_type in ["keyword", "hybrid"] and query.text is None:
            raise ValueError(f"{self.search_type} search requires query.text")

        # Search child store (fetch more candidates to account for deduplication)
        fetch_k = query.top_k * 3  # Heuristic: fetch 3x to handle duplicates
        child_results = self.child_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type=self.search_type
        )

        if len(child_results) == 0:
            return []

        # Extract parent IDs and track best scores
        parent_scores: Dict[str, float] = {}  # parent_id -> best_score
        parent_order: OrderedDict[str, None] = OrderedDict()  # Track first occurrence

        for child_result in child_results:
            parent_id = child_result.document.metadata.get(self.parent_id_key)

            if parent_id is None:
                # Skip children without parent link
                continue

            # Track best score for each parent (in case multiple children match)
            if parent_id not in parent_scores:
                parent_scores[parent_id] = child_result.score
                parent_order[parent_id] = None  # Track insertion order
            else:
                # Update if this child has better score
                parent_scores[parent_id] = max(
                    parent_scores[parent_id],
                    child_result.score
                )

        if len(parent_scores) == 0:
            return []

        # Retrieve parent documents
        parent_results = []
        for rank, parent_id in enumerate(parent_order.keys()):
            if rank >= query.top_k:
                break

            parent_doc = self.parent_store.get_document(parent_id)
            if parent_doc is None:
                # Parent not found in store, skip
                continue

            parent_results.append(SearchResult(
                document=parent_doc,
                score=parent_scores[parent_id],  # Best child score
                rank=rank
            ))

        return parent_results
Retrieve parent documents by searching child chunks.
Args: query: RetrievalQuery with appropriate parameters for search_type
Returns: List of SearchResult with parent documents, deduplicated and ranked
Raises: ValueError: If required query parameters are missing
class EnsembleRetriever(BaseRetriever):

    def _reciprocal_rank_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using Reciprocal Rank Fusion (RRF).

        RRF score for document d:
            score(d) = Σ weight_i / (k + rank_i(d))

        Where rank_i(d) is the rank of d in retriever i's results.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            weight = self.weights[retriever_idx]

            for result in results:
                doc_id = result.document.id

                # RRF score contribution
                rrf_score = weight / (self.rrf_k + result.rank)
                doc_scores[doc_id] += rrf_score

                # Keep document object (use first occurrence)
                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0  # Will be set later
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _weighted_average_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using weighted average of normalized scores.

        Scores are normalized to [0, 1] within each retriever, then averaged.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            weight = self.weights[retriever_idx]

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id
                doc_scores[doc_id] += weight * norm_score

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _max_score_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using maximum normalized score across retrievers.

        Conservative strategy: document needs high score from at least one retriever.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id

                # Take maximum normalized score
                doc_scores[doc_id] = max(doc_scores[doc_id], norm_score)

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results
Ensemble retriever combining multiple retrieval strategies.
Combines results from multiple retrievers using score fusion techniques:
- Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
- Weighted Average: Score-based fusion with configurable weights
- Max Score: Keeps each document's best normalized score from any single retriever (no consensus required)
Benefits:
- Improved recall (combine different strategies)
- Robustness (less sensitive to individual retriever failures)
- Flexibility (combine vector, keyword, hybrid, MMR, etc.)
Example:
    from gmf_forge_ai_data.retrieval import (
        EnsembleRetriever,
        VectorRetriever,
        KeywordRetriever,
        MMRRetriever,
        RetrievalQuery,
    )
    from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

    embedder = AzureOpenAIEmbeddings(...)

    # Create multiple retrievers (vector_store is an existing vector store instance)
    vector_retriever = VectorRetriever(vector_store)
    keyword_retriever = KeywordRetriever(vector_store)
    mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

    # Combine with RRF fusion
    ensemble = EnsembleRetriever(
        retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
        weights=[0.5, 0.3, 0.2],  # Optional weights
        fusion_strategy="rrf"  # Reciprocal Rank Fusion
    )

    # Query requires parameters for all retrievers
    query_text = "machine learning"
    query_embedding = embedder.embed_text(query_text)

    query = RetrievalQuery(
        text=query_text,  # For keyword retriever
        embedding=query_embedding,  # For vector/MMR retrievers
        top_k=5
    )

    results = ensemble.retrieve(query)
    # Returns fused results from all retrievers
Fusion Strategies:
rrf (Reciprocal Rank Fusion):
- score(doc) = Σ weights[i] / (k + rank_i)
- k = 60 (default)
- Robust to different score scales
- Rank-based, not score-based
weighted_avg (Weighted Average):
- Normalize scores to [0, 1]
- score(doc) = Σ weights[i] * normalized_score_i
- Score-based fusion
max_score (Maximum Score):
- score(doc) = max(normalized_scores)
- A document ranks high if any single retriever scores it high; consensus is not required
- Retriever weights are not applied
References: Cormack, G. V., Clarke, C. L. A., & Büttcher, S. (2009). Reciprocal rank fusion outperforms Condorcet and individual rank learning methods. ACM SIGIR.
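The RRF formula above can be worked through on a small case. This is a minimal sketch, not the library's code: `reciprocal_rank_fusion` is a hypothetical standalone function, rankings are plain lists of document IDs, and ranks are assumed to be 0-based (the library uses whatever `rank` each retriever's `SearchResult` carries).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def reciprocal_rank_fusion(
    rankings: Sequence[Sequence[str]],
    weights: Sequence[float],
    k: int = 60,
) -> List[Tuple[str, float]]:
    """Weighted RRF: score(doc) = sum_i weights[i] / (k + rank_i(doc))."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking):  # assumption: 0-based ranks
            scores[doc_id] += weight / (k + rank)
    # Highest fused score first
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


vector_ranking = ["a", "b", "c"]
keyword_ranking = ["b", "d", "a"]

fused = reciprocal_rank_fusion(
    [vector_ranking, keyword_ranking], weights=[0.5, 0.5]
)
print([doc_id for doc_id, _ in fused])  # ['b', 'a', 'd', 'c']
```

Document `b` wins because it sits near the top of both lists, while `a` is first in one list but third in the other. Raw scores never enter the computation, which is why RRF tolerates retrievers whose scores live on different scales.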
    def __init__(
        self,
        retrievers: List[BaseRetriever],
        weights: List[float] = None,
        fusion_strategy: str = "rrf",
        rrf_k: int = 60
    ):
        """
        Initialize ensemble retriever.

        Args:
            retrievers: List of retrievers to combine
            weights: Optional weights for each retriever (default: equal weights)
            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
            rrf_k: k parameter for RRF (default: 60)

        Raises:
            ValueError: If retrievers list is empty or params are invalid
        """
        if not retrievers:
            raise ValueError("Must provide at least one retriever")

        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
            raise ValueError(
                f"Invalid fusion_strategy: {fusion_strategy}. "
                "Must be 'rrf', 'weighted_avg', or 'max_score'"
            )

        if weights is None:
            # Equal weights
            weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            if len(weights) != len(retrievers):
                raise ValueError(
                    f"Number of weights ({len(weights)}) must match "
                    f"number of retrievers ({len(retrievers)})"
                )
            # Normalize weights to sum to 1.0
            total = sum(weights)
            weights = [w / total for w in weights]

        self.retrievers = retrievers
        self.weights = weights
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k
Initialize ensemble retriever.
Args:
    retrievers: List of retrievers to combine
    weights: Optional weights for each retriever (default: equal weights)
    fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
    rrf_k: k parameter for RRF (default: 60)
Raises: ValueError: If retrievers list is empty or params are invalid
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using ensemble fusion.

        Args:
            query: RetrievalQuery with parameters for all retrievers

        Returns:
            List of SearchResult with fused scores and ranks

        Note:
            All retrievers are called with the same query. Ensure the query
            contains all required parameters (text, embedding) for your retrievers.
        """
        # Retrieve from all retrievers
        all_results: List[List[SearchResult]] = []

        for retriever in self.retrievers:
            try:
                results = retriever.retrieve(query)
                all_results.append(results)
            except ValueError:
                # Retriever might not support this query type, skip it
                all_results.append([])

        if not any(all_results):
            return []

        # Apply fusion strategy
        if self.fusion_strategy == "rrf":
            fused_results = self._reciprocal_rank_fusion(all_results)
        elif self.fusion_strategy == "weighted_avg":
            fused_results = self._weighted_average_fusion(all_results)
        else:  # max_score
            fused_results = self._max_score_fusion(all_results)

        # Sort by fused score and take top_k
        fused_results.sort(key=lambda x: x.score, reverse=True)
        top_results = fused_results[:query.top_k]

        # Update ranks
        for rank, result in enumerate(top_results):
            result.rank = rank

        return top_results
Retrieve documents using ensemble fusion.
Args: query: RetrievalQuery with parameters for all retrievers
Returns: List of SearchResult with fused scores and ranks
Note: All retrievers are called with the same query. Ensure the query contains all required parameters (text, embedding) for your retrievers.
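For the two score-based strategies, the per-retriever min-max normalization is what makes scores from different retrievers comparable. The sketch below is an illustration under stated assumptions, not the library's code: `min_max` and `fuse` are hypothetical helpers, and each retriever's output is modelled as a plain `{doc_id: score}` dict with made-up values (cosine-like scores for one retriever, BM25-like scores for the other).

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def min_max(scores: Sequence[float]) -> List[float]:
    """Scale scores to [0, 1]; if all scores are equal, map them all to 1.0."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]


def fuse(
    result_lists: Sequence[Dict[str, float]],
    weights: Sequence[float],
    strategy: str = "weighted_avg",
) -> Dict[str, float]:
    """Fuse per-retriever {doc_id: score} dicts (hypothetical helper)."""
    fused: Dict[str, float] = defaultdict(float)
    for results, weight in zip(result_lists, weights):
        if not results:
            continue  # a retriever that returned nothing contributes nothing
        doc_ids = list(results)
        normalized = min_max([results[d] for d in doc_ids])
        for doc_id, norm in zip(doc_ids, normalized):
            if strategy == "weighted_avg":
                fused[doc_id] += weight * norm
            else:  # "max_score": weights are ignored
                fused[doc_id] = max(fused[doc_id], norm)
    return dict(fused)


vector_scores = {"a": 0.9, "b": 0.7, "c": 0.5}  # cosine-like scale
keyword_scores = {"b": 12.0, "d": 4.0}          # BM25-like scale

avg = fuse([vector_scores, keyword_scores], weights=[0.5, 0.5])
best = fuse([vector_scores, keyword_scores], weights=[0.5, 0.5], strategy="max_score")
```

Here `avg` puts `b` first (about 0.75) because both retrievers returned it, while `best` ties `a` and `b` at 1.0 because each is the top hit of at least one retriever. Note that min-max scaling always maps each retriever's lowest-scoring hit to 0, so that hit contributes nothing under either strategy.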
class HierarchicalRetriever(BaseRetriever):

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform two-stage hierarchical retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from stage 2, optionally with combined scores
        """
        # Stage 1: Retrieve document summaries
        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")

        stage1_query = RetrievalQuery(
            text=query.text,
            embedding=query.embedding,
            top_k=self.stage1_top_k,
            filters=query.filters,
            metadata=query.metadata
        )

        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)

        if not summary_results:
            logger.warning("Stage 1 returned no results")
            return []

        # Extract document IDs from summaries
        document_ids = []
        document_scores = {}

        for result in summary_results:
            doc_id = result.document.metadata.get(self.document_id_field)
            if doc_id:
                document_ids.append(doc_id)
                document_scores[doc_id] = result.score
            else:
                # If no document_id field, use the document id itself
                document_ids.append(result.document.id)
                document_scores[result.document.id] = result.score

        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")

        # Stage 2: Retrieve detailed chunks from top documents
        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")

        all_chunks: List[SearchResult] = []

        for doc_id in document_ids:
            # Create filter for this document
            doc_filter = {self.document_id_field: doc_id}

            # Merge with existing filters if any
            if query.filters:
                doc_filter.update(query.filters)

            # Retrieve chunks for this document
            stage2_query = RetrievalQuery(
                text=query.text,
                embedding=query.embedding,
                top_k=self.stage2_top_k,
                filters=doc_filter,
                metadata=query.metadata
            )

            try:
                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)

                if self.combine_scores and doc_id in document_scores:
                    # Combine stage 1 and stage 2 scores
                    for result in chunk_results:
                        combined_score = (
                            self.stage1_weight * document_scores[doc_id] +
                            self.stage2_weight * result.score
                        )
                        # Overwrite the chunk's score in place with the combined score
                        result.score = combined_score

                all_chunks.extend(chunk_results)

            except ValueError as e:
                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
                continue

        if not all_chunks:
            logger.warning("Stage 2 returned no chunks")
            return []

        # Sort by score and limit to top_k
        all_chunks.sort(key=lambda x: x.score, reverse=True)
        final_results = all_chunks[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
            f"{len(document_ids)} documents"
        )

        return final_results
Two-stage hierarchical retrieval for efficient search in large collections.
Stage 1: Retrieve document summaries to identify relevant documents
Stage 2: Retrieve detailed chunks from top documents only
This reduces computational cost by focusing detailed retrieval on relevant documents.
Example:
    # Create summary and chunk retrievers
    summary_retriever = VectorRetriever(summary_store, embedder)
    chunk_retriever = VectorRetriever(chunk_store, embedder)

    # Create hierarchical retriever
    retriever = HierarchicalRetriever(
        summary_retriever=summary_retriever,
        chunk_retriever=chunk_retriever,
        stage1_top_k=10,  # Get top 10 documents
        stage2_top_k=5,  # Get 5 chunks per document
        document_id_field="document_id"
    )

    # Retrieve
    query = RetrievalQuery(text="machine learning", top_k=20)
    results = retriever.retrieve(query)
    def __init__(
        self,
        summary_retriever: BaseRetriever,
        chunk_retriever: BaseRetriever,
        stage1_top_k: int = 10,
        stage2_top_k: int = 5,
        document_id_field: str = "document_id",
        combine_scores: bool = True,
        stage1_weight: float = 0.3,
        stage2_weight: float = 0.7
    ):
        """
        Initialize hierarchical retriever.

        Args:
            summary_retriever: Retriever for document summaries (stage 1)
            chunk_retriever: Retriever for detailed chunks (stage 2)
            stage1_top_k: Number of documents to retrieve in stage 1
            stage2_top_k: Number of chunks to retrieve per document in stage 2
            document_id_field: Metadata field linking chunks to documents
            combine_scores: If True, combine stage1 and stage2 scores
            stage1_weight: Weight for stage 1 score (document relevance)
            stage2_weight: Weight for stage 2 score (chunk relevance)
        """
        self.summary_retriever = summary_retriever
        self.chunk_retriever = chunk_retriever
        self.stage1_top_k = stage1_top_k
        self.stage2_top_k = stage2_top_k
        self.document_id_field = document_id_field
        self.combine_scores = combine_scores
        self.stage1_weight = stage1_weight
        self.stage2_weight = stage2_weight

        # Normalize weights so they sum to 1 (skipped when the sum is not positive)
        total_weight = stage1_weight + stage2_weight
        if total_weight > 0:
            self.stage1_weight = stage1_weight / total_weight
            self.stage2_weight = stage2_weight / total_weight
Initialize hierarchical retriever.
Args:
    summary_retriever: Retriever for document summaries (stage 1)
    chunk_retriever: Retriever for detailed chunks (stage 2)
    stage1_top_k: Number of documents to retrieve in stage 1
    stage2_top_k: Number of chunks to retrieve per document in stage 2
    document_id_field: Metadata field linking chunks to documents
    combine_scores: If True, combine stage1 and stage2 scores
    stage1_weight: Weight for stage 1 score (document relevance)
    stage2_weight: Weight for stage 2 score (chunk relevance)
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform two-stage hierarchical retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from stage 2, optionally with combined scores
        """
        # Stage 1: Retrieve document summaries
        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")

        stage1_query = RetrievalQuery(
            text=query.text,
            embedding=query.embedding,
            top_k=self.stage1_top_k,
            filters=query.filters,
            metadata=query.metadata
        )

        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)

        if not summary_results:
            logger.warning("Stage 1 returned no results")
            return []

        # Extract document IDs from summaries
        document_ids = []
        document_scores = {}

        for result in summary_results:
            doc_id = result.document.metadata.get(self.document_id_field)
            if not doc_id:
                # If no document_id field, use the document id itself
                doc_id = result.document.id

            if doc_id not in document_scores:
                document_ids.append(doc_id)
                document_scores[doc_id] = result.score
            else:
                # Several summaries can point at the same document; query it
                # once in stage 2 and keep the best stage 1 score
                document_scores[doc_id] = max(document_scores[doc_id], result.score)

        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")

        # Stage 2: Retrieve detailed chunks from top documents
        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")

        all_chunks: List[SearchResult] = []

        for doc_id in document_ids:
            # Start from the caller's filters, then pin the document ID so a
            # caller-supplied filter on the same field cannot override it
            doc_filter = dict(query.filters) if query.filters else {}
            doc_filter[self.document_id_field] = doc_id

            # Retrieve chunks for this document
            stage2_query = RetrievalQuery(
                text=query.text,
                embedding=query.embedding,
                top_k=self.stage2_top_k,
                filters=doc_filter,
                metadata=query.metadata
            )

            try:
                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)

                if self.combine_scores and doc_id in document_scores:
                    # Combine stage 1 and stage 2 scores
                    for result in chunk_results:
                        combined_score = (
                            self.stage1_weight * document_scores[doc_id] +
                            self.stage2_weight * result.score
                        )
                        # Overwrite the chunk's score in place with the combined score
                        result.score = combined_score

                all_chunks.extend(chunk_results)

            except ValueError as e:
                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
                continue

        if not all_chunks:
            logger.warning("Stage 2 returned no chunks")
            return []

        # Sort by score and limit to top_k
        all_chunks.sort(key=lambda x: x.score, reverse=True)
        final_results = all_chunks[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
            f"{len(document_ids)} documents"
        )

        return final_results
Perform two-stage hierarchical retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments
Returns: List of SearchResult objects from stage 2, optionally with combined scores
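The score combination and final ranking can be sketched in isolation. This is a minimal illustration that uses plain tuples instead of the package's SearchResult objects; `normalize_weights`, `combine_and_rank`, and the sample scores are hypothetical names and values chosen for the example, not part of the library.

```python
from typing import Dict, List, Tuple


def normalize_weights(w1: float, w2: float) -> Tuple[float, float]:
    """Scale two weights so they sum to 1 (unchanged if the sum is not positive)."""
    total = w1 + w2
    if total > 0:
        return w1 / total, w2 / total
    return w1, w2


def combine_and_rank(
    document_scores: Dict[str, float],
    chunks: List[Tuple[str, str, float]],  # (chunk_id, doc_id, chunk_score)
    stage1_weight: float = 0.3,
    stage2_weight: float = 0.7,
    top_k: int = 3,
) -> List[Tuple[int, str, float]]:
    """Return (rank, chunk_id, combined_score) for the top_k chunks."""
    w1, w2 = normalize_weights(stage1_weight, stage2_weight)
    scored = [
        (chunk_id, w1 * document_scores[doc_id] + w2 * chunk_score)
        for chunk_id, doc_id, chunk_score in chunks
    ]
    # Highest combined score first, then truncate and assign 1-based ranks
    scored.sort(key=lambda item: item[1], reverse=True)
    return [
        (rank, chunk_id, round(score, 4))
        for rank, (chunk_id, score) in enumerate(scored[:top_k], start=1)
    ]


# Made-up stage 1 (document) and stage 2 (chunk) scores
document_scores = {"doc-a": 0.9, "doc-b": 0.5}
chunks = [("a1", "doc-a", 0.6), ("b1", "doc-b", 0.8), ("a2", "doc-a", 0.4)]
ranked = combine_and_rank(document_scores, chunks)
```

With the default 0.3/0.7 weights, chunk `b1` ranks first even though its document scored lower in stage 1, because the chunk-level score carries more weight.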
class GraphRetriever(BaseRetriever):
Graph-based retrieval using entity relationships.
This retriever:
1. Extracts entities from the query
2. Finds matching entities in the knowledge graph
3. Traverses the graph to find related entities
4. Retrieves documents associated with relevant entities
Example:
    import networkx as nx

    # Create knowledge graph
    graph = nx.DiGraph()
    graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
    graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

    # Create entity-document mapping
    entity_docs = {
        "Python": ["doc1", "doc2"],
        "Machine Learning": ["doc3", "doc4"],
        "Neural Networks": ["doc5"]
    }

    # Create retriever
    retriever = GraphRetriever(
        vector_store=store,
        knowledge_graph=graph,
        entity_document_mapping=entity_docs,
        embedder=embedder,
        max_hops=2
    )

    # Retrieve
    query = RetrievalQuery(text="Python for deep learning", top_k=5)
    results = retriever.retrieve(query)
    def __init__(
        self,
        vector_store: BaseVectorStore,
        knowledge_graph: Optional['nx.DiGraph'] = None,
        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
        embedder: Optional[Any] = None,
        max_hops: int = 2,
        min_relation_weight: float = 0.5,
        combine_vector_scores: bool = True,
        graph_weight: float = 0.4,
        vector_weight: float = 0.6
    ):
        """
        Initialize graph retriever.

        Args:
            vector_store: Vector store for document retrieval
            knowledge_graph: NetworkX DiGraph with entity relationships
            entity_document_mapping: Dict mapping entity IDs to document IDs
            embedder: Embedding provider for query encoding
            max_hops: Maximum number of hops in graph traversal
            min_relation_weight: Minimum weight for relationship edges
            combine_vector_scores: If True, combine graph and vector scores
            graph_weight: Weight for graph-based relevance
            vector_weight: Weight for vector similarity
        """
        if not NETWORKX_AVAILABLE:
            raise ImportError(
                "NetworkX is required for GraphRetriever. "
                "Install with: pip install networkx"
            )

        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph or nx.DiGraph()
        self.entity_document_mapping = entity_document_mapping or {}
        self.embedder = embedder
        self.max_hops = max_hops
        self.min_relation_weight = min_relation_weight
        self.combine_vector_scores = combine_vector_scores
        self.graph_weight = graph_weight
        self.vector_weight = vector_weight

        # Normalize weights so they sum to 1 (skipped when the sum is not positive)
        total_weight = graph_weight + vector_weight
        if total_weight > 0:
            self.graph_weight = graph_weight / total_weight
            self.vector_weight = vector_weight / total_weight
Initialize graph retriever.
Args:
    vector_store: Vector store for document retrieval
    knowledge_graph: NetworkX DiGraph with entity relationships
    entity_document_mapping: Dict mapping entity IDs to document IDs
    embedder: Embedding provider for query encoding
    max_hops: Maximum number of hops in graph traversal
    min_relation_weight: Minimum weight for relationship edges
    combine_vector_scores: If True, combine graph and vector scores
    graph_weight: Weight for graph-based relevance
    vector_weight: Weight for vector similarity
    def extract_entities(self, query_text: str) -> List[str]:
        """
        Extract entities from query text.

        Simple implementation: looks for entity names in the knowledge graph.
        For production, use NER models (spaCy, Hugging Face, etc.).

        Args:
            query_text: Query text

        Returns:
            List of entity IDs found in the query
        """
        query_lower = query_text.lower()
        entities = []

        # Simple matching: check if entity names appear in query
        # (case-insensitive substring test against every graph node)
        for entity_id in self.knowledge_graph.nodes():
            entity_name = str(entity_id).lower()
            if entity_name in query_lower:
                entities.append(entity_id)

        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
        return entities
Extract entities from query text.
Simple implementation: looks for entity names in the knowledge graph. For production, use NER models (spaCy, Hugging Face, etc.).
Args:
    query_text: Query text
Returns: List of entity IDs found in the query
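Because the matching described above is a plain substring test, a short entity name can match inside an unrelated word. The sketch below contrasts that behaviour with a stricter whole-word variant. Both functions, the entity list (including the one-letter entity "R"), and the query are hypothetical and written only for illustration; the whole-word variant is one possible refinement, not the library's behaviour.

```python
import re
from typing import Iterable, List


def match_entities_substring(query_text: str, entity_names: Iterable[str]) -> List[str]:
    """Case-insensitive substring matching, as in the simple implementation."""
    query_lower = query_text.lower()
    return [name for name in entity_names if str(name).lower() in query_lower]


def match_entities_whole_word(query_text: str, entity_names: Iterable[str]) -> List[str]:
    """Stricter variant: the name must not be embedded in a longer word."""
    matches = []
    for name in entity_names:
        # Lookarounds require a non-word character (or string edge) on both sides
        pattern = r"(?<!\w)" + re.escape(str(name)) + r"(?!\w)"
        if re.search(pattern, query_text, flags=re.IGNORECASE):
            matches.append(name)
    return matches


entities = ["Python", "Machine Learning", "R"]
query = "Python for machine learning research"

loose = match_entities_substring(query, entities)    # also picks up "R"
strict = match_entities_whole_word(query, entities)  # "R" is not a standalone word
```

The substring version reports "R" because the letter occurs inside other words, which is the kind of false positive an NER model or a whole-word check avoids.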
    def traverse_graph(
        self,
        seed_entities: List[str],
        max_hops: int
    ) -> Dict[str, float]:
        """
        Traverse knowledge graph from seed entities.

        Args:
            seed_entities: Starting entities
            max_hops: Maximum number of hops

        Returns:
            Dict mapping entity IDs to relevance scores (0-1)
        """
        entity_scores: Dict[str, float] = {}

        # Initialize seed entities with score 1.0
        for entity in seed_entities:
            if entity in self.knowledge_graph:
                entity_scores[entity] = 1.0

        if not entity_scores:
            return entity_scores

        # BFS traversal
        visited: Set[str] = set()
        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)

        while current_level:
            next_level = []

            for entity_id, score, hop in current_level:
                if entity_id in visited or hop >= max_hops:
                    continue

                visited.add(entity_id)

                # Get neighbors
                if entity_id not in self.knowledge_graph:
                    continue

                for neighbor in self.knowledge_graph.neighbors(entity_id):
                    # Get edge weight
                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0

                    # Skip weak relationships
                    if edge_weight < self.min_relation_weight:
                        continue

                    # Calculate neighbor score (decay with hops)
                    decay_factor = 0.7 ** (hop + 1)
                    neighbor_score = score * edge_weight * decay_factor

                    # Update score if better
                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
                        entity_scores[neighbor] = neighbor_score
                        next_level.append((neighbor, neighbor_score, hop + 1))

            current_level = next_level

        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
        return entity_scores
Traverse knowledge graph from seed entities.
Args:
    seed_entities: Starting entities
    max_hops: Maximum number of hops
Returns: Dict mapping entity IDs to relevance scores (0-1)
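To see how scores fall off with distance, the traversal rule (parent score × edge weight × 0.7^(hop+1)) can be reproduced on a small graph. The sketch below uses a plain adjacency dict instead of a NetworkX graph so that it runs without extra dependencies; `propagate_scores`, the `Graph` alias, and the "Statistics" edge are hypothetical additions for illustration.

```python
from typing import Dict, List, Set, Tuple

# Stand-in for the directed knowledge graph: node -> [(neighbor, edge_weight)]
Graph = Dict[str, List[Tuple[str, float]]]


def propagate_scores(
    graph: Graph,
    seeds: List[str],
    max_hops: int = 2,
    min_relation_weight: float = 0.5,
    decay: float = 0.7,
) -> Dict[str, float]:
    """Breadth-first score propagation with per-hop decay."""
    scores: Dict[str, float] = {seed: 1.0 for seed in seeds if seed in graph}
    visited: Set[str] = set()
    frontier = [(seed, 1.0, 0) for seed in scores]  # (node, score, hop)

    while frontier:
        next_frontier = []
        for node, score, hop in frontier:
            if node in visited or hop >= max_hops:
                continue
            visited.add(node)
            for neighbor, weight in graph.get(node, []):
                if weight < min_relation_weight:
                    continue  # skip weak relationships
                candidate = score * weight * decay ** (hop + 1)
                if neighbor not in scores or candidate > scores[neighbor]:
                    scores[neighbor] = candidate
                    next_frontier.append((neighbor, candidate, hop + 1))
        frontier = next_frontier
    return scores


graph: Graph = {
    "Python": [("Machine Learning", 0.9)],
    "Machine Learning": [("Neural Networks", 0.8), ("Statistics", 0.3)],
    "Neural Networks": [],
}
scores = propagate_scores(graph, ["Python"])
# Python: 1.0
# Machine Learning: 1.0 * 0.9 * 0.7    = 0.63
# Neural Networks:  0.63 * 0.8 * 0.7^2 = 0.24696
# Statistics is never scored: its edge weight 0.3 is below the 0.5 threshold
```

Note that the decay compounds: the parent's score already carries the earlier hop's decay, so a two-hop neighbor is discounted by 0.7 × 0.7² = 0.343 on top of the edge weights.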
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform graph-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects ranked by combined graph + vector scores
        """
        if not query.text:
            raise ValueError("GraphRetriever requires query.text")

        # Step 1: Extract entities from query
        seed_entities = self.extract_entities(query.text)

        if not seed_entities:
            logger.warning("No entities found in query, falling back to vector search")
            # Fallback to pure vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 2: Traverse graph to find related entities
        entity_scores = self.traverse_graph(seed_entities, self.max_hops)

        # Step 3: Collect documents associated with relevant entities,
        # keeping the best entity score per document
        doc_graph_scores: Dict[str, float] = {}

        for entity_id, entity_score in entity_scores.items():
            doc_ids = self.entity_document_mapping.get(entity_id, [])
            for doc_id in doc_ids:
                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
                    doc_graph_scores[doc_id] = entity_score

        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")

        if not doc_graph_scores:
            logger.warning("No documents found via graph, falling back to vector search")
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        combined_results: List[SearchResult] = []

        # Step 4: Get vector scores if combining
        if self.combine_vector_scores:
            # Get embeddings for vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            # Retrieve more documents for scoring
            vector_results = self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k * 3,  # Get more for better coverage
                filters=query.filters
            )

            # Create document ID to vector score mapping
            doc_vector_scores: Dict[str, float] = {}
            for result in vector_results:
                doc_vector_scores[result.document.id] = result.score

            # Combine scores
            for doc_id, graph_score in doc_graph_scores.items():
                # Get document from vector store
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                # Get vector score (0 if not found)
                vector_score = doc_vector_scores.get(doc_id, 0.0)

                # Combine scores
                combined_score = (
                    self.graph_weight * graph_score +
                    self.vector_weight * vector_score
                )

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=combined_score,
                        rank=0  # Will be set later
                    )
                )
        else:
            # Use only graph scores
            for doc_id, graph_score in doc_graph_scores.items():
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=graph_score,
                        rank=0
                    )
                )

        # Sort by score and limit to top_k
        combined_results.sort(key=lambda x: x.score, reverse=True)
        final_results = combined_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(f"Graph retrieval complete: {len(final_results)} results")

        return final_results
Perform graph-based retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments
Returns: List of SearchResult objects ranked by combined graph + vector scores
class SQLRetriever(BaseRetriever):
    """
    Retrieval from structured databases using SQL queries.

    This retriever:
    1. Converts natural language queries to SQL (via LLM or rule-based)
    2. Executes SQL against a database
    3. Converts results to Document objects
    4. Returns SearchResult objects with relevance scores

    Example:
        ```python
        import sqlite3

        # Create database connection
        conn = sqlite3.connect("products.db")

        # Define schema
        schema = SQLSchema(
            table_name="products",
            columns=[
                {"name": "id", "type": "INTEGER", "description": "Product ID"},
                {"name": "name", "type": "TEXT", "description": "Product name"},
                {"name": "price", "type": "REAL", "description": "Price in USD"},
                {"name": "category", "type": "TEXT", "description": "Product category"}
            ],
            primary_key="id",
            description="E-commerce product catalog"
        )

        # Create retriever
        retriever = SQLRetriever(
            db_connection=conn,
            schema=schema,
            text_to_sql_fn=my_text_to_sql_function,
            db_type="sqlite"
        )

        # Retrieve
        query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score

    def _simple_text_to_sql(self, query_text: str, schema: SQLSchema) -> SQLQuery:
        """
        Simple rule-based text-to-SQL conversion.

        For production, use LLM-based conversion or specialized libraries.

        Args:
            query_text: Natural language query
            schema: Database schema

        Returns:
            SQLQuery object
        """
        import re

        # Extract keywords for filtering
        query_lower = query_text.lower()

        # Build SELECT clause
        if self.content_columns:
            columns = ", ".join(self.content_columns)
        else:
            columns = "*"

        # Add score column if specified
        if self.score_column:
            columns = f"{columns}, {self.score_column}"

        # Build basic SELECT
        sql = f"SELECT {columns} FROM {schema.table_name}"

        # Simple WHERE clause based on keywords
        conditions = []

        # Look for numeric comparisons
        if "under" in query_lower or "less than" in query_lower or "<" in query_lower:
            # Try to find price column
            price_cols = [c["name"] for c in schema.columns if "price" in c["name"].lower() or "cost" in c["name"].lower()]
            if price_cols:
                # Extract number (first run of digits in the query)
                numbers = re.findall(r'\d+', query_text)
                if numbers:
                    conditions.append(f"{price_cols[0]} < {numbers[0]}")

        if "over" in query_lower or "more than" in query_lower or ">" in query_lower:
            price_cols = [c["name"] for c in schema.columns if "price" in c["name"].lower() or "cost" in c["name"].lower()]
            if price_cols:
                numbers = re.findall(r'\d+', query_text)
                if numbers:
                    conditions.append(f"{price_cols[0]} > {numbers[0]}")

        # Look for text columns to search
        text_cols = [c["name"] for c in schema.columns if c["type"].upper() in ["TEXT", "VARCHAR", "STRING"]]

        # Build LIKE conditions for text search
        search_terms = []
        # Remove common words
        stop_words = {"in", "the", "a", "an", "under", "over", "less", "more", "than", "products", "items"}
        words = [w for w in query_lower.split() if w not in stop_words and not w.isdigit()]

        for word in words:
            if len(word) > 2:  # Skip very short words
                search_terms.append(word)

        if search_terms and text_cols:
            like_conditions = []
            for col in text_cols:
                for term in search_terms:
                    # Caution: the term is interpolated into the SQL text rather
                    # than bound as a parameter, so quotes in the query reach the
                    # statement unescaped
                    like_conditions.append(f"{col} LIKE '%{term}%'")

            if like_conditions:
                conditions.append(f"({' OR '.join(like_conditions)})")

        # Add WHERE clause if conditions exist
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)

        # Add LIMIT
        sql += " LIMIT 100"  # Safety limit

        return SQLQuery(
            sql=sql,
            explanation="Simple keyword-based SQL generation"
        )

    def _execute_sql(self, sql_query: SQLQuery) -> List[Dict[str, Any]]:
        """
        Execute SQL query and return results.

        Args:
            sql_query: SQL query to execute

        Returns:
            List of result rows as dictionaries
        """
        cursor = self.db_connection.cursor()

        try:
            logger.info(f"Executing SQL: {sql_query.sql}")

            if sql_query.parameters:
                cursor.execute(sql_query.sql, sql_query.parameters)
            else:
                cursor.execute(sql_query.sql)

            # Get column names
            if cursor.description:
                columns = [desc[0] for desc in cursor.description]

                # Fetch results
                rows = cursor.fetchall()

                # Convert to dictionaries
                results = []
                for row in rows:
                    result_dict = dict(zip(columns, row))
                    results.append(result_dict)

                logger.info(f"SQL returned {len(results)} rows")
                return results
            else:
                return []

        except Exception as e:
            logger.error(f"SQL execution failed: {e}")
            raise
        finally:
            cursor.close()

    def _row_to_document(self, row: Dict[str, Any], index: int) -> Document:
        """
        Convert database row to Document object.

        Args:
            row: Database row as dictionary
            index: Row index for ID generation

        Returns:
            Document object
        """
        # Extract score if available
        score = self.default_score
        if self.score_column and self.score_column in row:
            score = float(row[self.score_column])

        # Build content from specified columns or all columns
        if self.content_columns:
            content_parts = []
            for col in self.content_columns:
                if col in row:
                    content_parts.append(f"{col}: {row[col]}")
            content = "\n".join(content_parts)
        else:
            # Use JSON representation
            content = json.dumps(row, indent=2, default=str)

        # Use primary key if available, otherwise use index
        doc_id = str(row.get(self.schema.primary_key, f"sql_result_{index}"))

        # Store all row data in metadata
        metadata = {
            "source": "sql_database",
            "table": self.schema.table_name,
            "row_data": row,
            "sql_score": score
        }

        return Document(
            id=doc_id,
            content=content,
            metadata=metadata
        )

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results
Retrieval from structured databases using SQL queries.
This retriever:
- Converts natural language queries to SQL (via LLM or rule-based)
- Executes SQL against a database
- Converts results to Document objects
- Returns SearchResult objects with relevance scores
Example:
import sqlite3

# Create database connection
conn = sqlite3.connect("products.db")

# Define schema
schema = SQLSchema(
    table_name="products",
    columns=[
        {"name": "id", "type": "INTEGER", "description": "Product ID"},
        {"name": "name", "type": "TEXT", "description": "Product name"},
        {"name": "price", "type": "REAL", "description": "Price in USD"},
        {"name": "category", "type": "TEXT", "description": "Product category"}
    ],
    primary_key="id",
    description="E-commerce product catalog"
)

# Create retriever
retriever = SQLRetriever(
    db_connection=conn,
    schema=schema,
    text_to_sql_fn=my_text_to_sql_function,
    db_type="sqlite"
)

# Retrieve
query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
results = retriever.retrieve(query)
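The example above passes `my_text_to_sql_function` without defining it. The following is a minimal sketch of such a callable, assuming only the `Callable[[str, SQLSchema], SQLQuery]` shape from the constructor signature. `SQLSchema` and `SQLQuery` are local stand-ins that mirror the dataclasses documented on this page, and `price_filter_to_sql` with its single regex rule is hypothetical. How `SQLRetriever` binds `parameters` is not shown here, so the sketch runs the query directly with `sqlite3` named placeholders.

```python
import re
import sqlite3
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SQLSchema:  # stand-in mirroring the documented dataclass
    table_name: str
    columns: List[Dict[str, str]]
    primary_key: Optional[str] = None
    description: Optional[str] = None


@dataclass
class SQLQuery:  # stand-in mirroring the documented dataclass
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None


def price_filter_to_sql(text: str, schema: SQLSchema) -> SQLQuery:
    """Hypothetical rule-based converter: understands only 'under $N'."""
    match = re.search(r"under \$?(\d+(?:\.\d+)?)", text)
    if match is None:
        return SQLQuery(
            sql=f"SELECT * FROM {schema.table_name}",
            explanation="No rule matched; returning all rows",
        )
    # The table name comes from the trusted schema; the user-derived
    # value is bound as a parameter instead of being interpolated.
    return SQLQuery(
        sql=f"SELECT * FROM {schema.table_name} WHERE price < :max_price",
        parameters={"max_price": float(match.group(1))},
        explanation="Matched an 'under $N' price filter",
    )


schema = SQLSchema(
    table_name="products",
    columns=[
        {"name": "id", "type": "INTEGER", "description": "Product ID"},
        {"name": "price", "type": "REAL", "description": "Price in USD"},
    ],
    primary_key="id",
)

sql_query = price_filter_to_sql("products under $100 in electronics", schema)

# Run against an in-memory database to check the generated SQL.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, price REAL)")
conn.executemany("INSERT INTO products VALUES (?, ?)", [(1, 49.0), (2, 150.0)])
rows = conn.execute(sql_query.sql, sql_query.parameters).fetchall()
```

An LLM-backed converter would have the same signature; only the body that produces `sql` and `parameters` would change.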
    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score
Initialize SQL retriever.
Args:
db_connection: Database connection object (e.g., sqlite3.Connection)
schema: Database schema information
text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
db_type: Database type ("sqlite", "postgresql", "mysql")
content_columns: Columns to use for document content (None = all columns)
score_column: Column to use for relevance score (None = use default_score)
default_score: Default score when no score_column specified
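How `content_columns`, `score_column`, and `default_score` interact can be sketched outside the class. The two helpers below (`build_content`, `pick_score`) are hypothetical names; they mirror the row-to-document logic in the source shown on this page rather than calling the library itself.

```python
import json
from typing import Any, Dict, List, Optional


def build_content(row: Dict[str, Any], content_columns: Optional[List[str]]) -> str:
    """Build document text from a row, following the logic shown on this page."""
    if content_columns:
        # Only the listed columns, one "name: value" pair per line.
        return "\n".join(
            f"{col}: {row[col]}" for col in content_columns if col in row
        )
    # No columns listed: serialize the whole row as JSON.
    return json.dumps(row, indent=2, default=str)


def pick_score(
    row: Dict[str, Any], score_column: Optional[str], default_score: float
) -> float:
    """Use the score column when it is present in the row, else the default."""
    if score_column and score_column in row:
        return float(row[score_column])
    return default_score


row = {"id": 7, "name": "USB hub", "price": 19.5, "rating": 4}
content = build_content(row, ["name", "price"])
score = pick_score(row, "rating", default_score=1.0)
fallback = pick_score(row, None, default_score=1.0)
```

With no `score_column`, every row gets the same score, so the result order is whatever order the SQL query returned.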
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results
Perform SQL-based retrieval.
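Args:
query: The retrieval query; query.text must be non-empty
**kwargs: Additional arguments (accepted but not used by the source shown above)
Returns:
List of SearchResult objects from database query, truncated to query.top_k
Raises:
ValueError: If query.text is empty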
@dataclass
class SQLSchema:
    """Represents a database schema."""
    table_name: str
    columns: List[Dict[str, str]]  # [{"name": "col1", "type": "int", "description": "..."}]
    primary_key: Optional[str] = None
    description: Optional[str] = None
Represents a database schema.
@dataclass
class SQLQuery:
    """Represents a generated SQL query."""
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None
Represents a generated SQL query.
31class MultiIndexRetriever(BaseRetriever): 32 """ 33 Retrieve from multiple indices/sources and merge results. 34 35 This retriever: 36 1. Queries multiple independent retrievers in parallel (or sequentially) 37 2. Merges results from all sources 38 3. Applies source-specific weights and boosts 39 4. Re-ranks using RRF or weighted scoring 40 41 Use cases: 42 - Multi-domain search (HR + Finance + IT) 43 - Federated search across multiple databases 44 - Combining different data sources (documents + SQL + graph) 45 46 Example: 47 ```python 48 # Create retrievers for different sources 49 hr_retriever = VectorRetriever(hr_store, embedder) 50 finance_retriever = VectorRetriever(finance_store, embedder) 51 it_retriever = VectorRetriever(it_store, embedder) 52 53 # Create multi-index retriever 54 retriever = MultiIndexRetriever( 55 sources=[ 56 SourceConfig("HR", hr_retriever, weight=1.0), 57 SourceConfig("Finance", finance_retriever, weight=1.5), 58 SourceConfig("IT", it_retriever, weight=1.0) 59 ], 60 fusion_strategy="rrf" 61 ) 62 63 # Retrieve across all sources 64 query = RetrievalQuery(text="employee benefits policy", top_k=10) 65 results = retriever.retrieve(query) 66 ``` 67 """ 68 69 def __init__( 70 self, 71 sources: List[SourceConfig], 72 fusion_strategy: str = "rrf", 73 rrf_k: int = 60, 74 normalize_scores: bool = True 75 ): 76 """ 77 Initialize multi-index retriever. 
78 79 Args: 80 sources: List of SourceConfig objects defining retrievers 81 fusion_strategy: Strategy for merging results: 82 - "rrf": Reciprocal Rank Fusion 83 - "weighted_average": Weighted average of normalized scores 84 - "max_score": Take maximum score across sources 85 rrf_k: Constant for RRF formula (default 60) 86 normalize_scores: Normalize scores to [0,1] before fusion 87 """ 88 self.sources = sources 89 self.fusion_strategy = fusion_strategy 90 self.rrf_k = rrf_k 91 self.normalize_scores = normalize_scores 92 93 # Validate sources 94 if not sources: 95 raise ValueError("At least one source must be provided") 96 97 # Normalize weights 98 total_weight = sum(s.weight for s in sources if s.enabled) 99 if total_weight > 0: 100 for source in sources: 101 source.weight = source.weight / total_weight 102 103 def _normalize_scores_for_source( 104 self, 105 results: List[SearchResult] 106 ) -> List[SearchResult]: 107 """ 108 Normalize scores to [0, 1] range. 109 110 Args: 111 results: Search results 112 113 Returns: 114 Results with normalized scores 115 """ 116 if not results: 117 return results 118 119 scores = [r.score for r in results] 120 min_score = min(scores) 121 max_score = max(scores) 122 123 if max_score - min_score == 0: 124 # All scores are the same 125 for result in results: 126 result.score = 1.0 127 else: 128 for result in results: 129 result.score = (result.score - min_score) / (max_score - min_score) 130 131 return results 132 133 def _rrf_fusion( 134 self, 135 source_results: Dict[str, List[SearchResult]] 136 ) -> List[SearchResult]: 137 """ 138 Merge results using Reciprocal Rank Fusion. 
139 140 Args: 141 source_results: Dict mapping source names to their results 142 143 Returns: 144 Merged and re-ranked results 145 """ 146 # Track all unique documents and their RRF scores 147 doc_rrf_scores: Dict[str, float] = {} 148 doc_objects: Dict[str, SearchResult] = {} 149 150 # Calculate RRF scores 151 for source_name, results in source_results.items(): 152 # Get source config 153 source_config = next((s for s in self.sources if s.name == source_name), None) 154 if not source_config or not source_config.enabled: 155 continue 156 157 weight = source_config.weight 158 boost = source_config.boost_factor 159 160 for result in results: 161 doc_id = result.document.id 162 163 # RRF score: weight / (k + rank) 164 rrf_score = (weight * boost) / (self.rrf_k + result.rank) 165 166 if doc_id not in doc_rrf_scores: 167 doc_rrf_scores[doc_id] = rrf_score 168 doc_objects[doc_id] = result 169 # Add source metadata 170 result.document.metadata["retrieval_source"] = source_name 171 else: 172 doc_rrf_scores[doc_id] += rrf_score 173 # If document appears in multiple sources, mark it 174 existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "") 175 if source_name not in existing_source: 176 doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}" 177 178 # Create final results 179 merged_results = [] 180 for doc_id, rrf_score in doc_rrf_scores.items(): 181 result = doc_objects[doc_id] 182 result.score = rrf_score 183 merged_results.append(result) 184 185 # Sort by RRF score 186 merged_results.sort(key=lambda x: x.score, reverse=True) 187 188 return merged_results 189 190 def _weighted_average_fusion( 191 self, 192 source_results: Dict[str, List[SearchResult]] 193 ) -> List[SearchResult]: 194 """ 195 Merge results using weighted average of scores. 
196 197 Args: 198 source_results: Dict mapping source names to their results 199 200 Returns: 201 Merged and re-ranked results 202 """ 203 # Track all unique documents and their weighted scores 204 doc_weighted_scores: Dict[str, float] = {} 205 doc_score_counts: Dict[str, int] = {} 206 doc_objects: Dict[str, SearchResult] = {} 207 208 # Calculate weighted scores 209 for source_name, results in source_results.items(): 210 # Get source config 211 source_config = next((s for s in self.sources if s.name == source_name), None) 212 if not source_config or not source_config.enabled: 213 continue 214 215 # Normalize scores for this source 216 if self.normalize_scores: 217 results = self._normalize_scores_for_source(results) 218 219 weight = source_config.weight 220 boost = source_config.boost_factor 221 222 for result in results: 223 doc_id = result.document.id 224 weighted_score = result.score * weight * boost 225 226 if doc_id not in doc_weighted_scores: 227 doc_weighted_scores[doc_id] = weighted_score 228 doc_score_counts[doc_id] = 1 229 doc_objects[doc_id] = result 230 result.document.metadata["retrieval_source"] = source_name 231 else: 232 doc_weighted_scores[doc_id] += weighted_score 233 doc_score_counts[doc_id] += 1 234 existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "") 235 if source_name not in existing_source: 236 doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}" 237 238 # Create final results with averaged scores 239 merged_results = [] 240 for doc_id, total_score in doc_weighted_scores.items(): 241 result = doc_objects[doc_id] 242 # Average the weighted scores 243 result.score = total_score / doc_score_counts[doc_id] 244 merged_results.append(result) 245 246 # Sort by score 247 merged_results.sort(key=lambda x: x.score, reverse=True) 248 249 return merged_results 250 251 def _max_score_fusion( 252 self, 253 source_results: Dict[str, List[SearchResult]] 254 ) -> List[SearchResult]: 255 
""" 256 Merge results using maximum score across sources. 257 258 Args: 259 source_results: Dict mapping source names to their results 260 261 Returns: 262 Merged and re-ranked results 263 """ 264 # Track all unique documents and their max scores 265 doc_max_scores: Dict[str, float] = {} 266 doc_objects: Dict[str, SearchResult] = {} 267 268 # Find max scores 269 for source_name, results in source_results.items(): 270 # Get source config 271 source_config = next((s for s in self.sources if s.name == source_name), None) 272 if not source_config or not source_config.enabled: 273 continue 274 275 # Normalize scores for this source 276 if self.normalize_scores: 277 results = self._normalize_scores_for_source(results) 278 279 boost = source_config.boost_factor 280 281 for result in results: 282 doc_id = result.document.id 283 boosted_score = result.score * boost 284 285 if doc_id not in doc_max_scores or boosted_score > doc_max_scores[doc_id]: 286 doc_max_scores[doc_id] = boosted_score 287 doc_objects[doc_id] = result 288 result.document.metadata["retrieval_source"] = source_name 289 else: 290 # Update source metadata if from multiple sources 291 existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "") 292 if source_name not in existing_source: 293 doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}" 294 295 # Create final results 296 merged_results = [] 297 for doc_id, max_score in doc_max_scores.items(): 298 result = doc_objects[doc_id] 299 result.score = max_score 300 merged_results.append(result) 301 302 # Sort by score 303 merged_results.sort(key=lambda x: x.score, reverse=True) 304 305 return merged_results 306 307 def retrieve( 308 self, 309 query: RetrievalQuery, 310 **kwargs 311 ) -> List[SearchResult]: 312 """ 313 Perform multi-source retrieval. 
314 315 Args: 316 query: The retrieval query 317 **kwargs: Additional arguments passed to each retriever 318 319 Returns: 320 List of SearchResult objects merged from all sources 321 """ 322 # Retrieve from each source 323 source_results: Dict[str, List[SearchResult]] = {} 324 325 for source in self.sources: 326 if not source.enabled: 327 logger.info(f"Skipping disabled source: {source.name}") 328 continue 329 330 try: 331 logger.info(f"Retrieving from source: {source.name}") 332 results = source.retriever.retrieve(query, **kwargs) 333 source_results[source.name] = results 334 logger.info(f"Source {source.name} returned {len(results)} results") 335 336 except Exception as e: 337 logger.warning(f"Retrieval from source {source.name} failed: {e}") 338 source_results[source.name] = [] 339 340 # Check if we got any results 341 total_results = sum(len(results) for results in source_results.values()) 342 if total_results == 0: 343 logger.warning("No results from any source") 344 return [] 345 346 # Merge results based on fusion strategy 347 logger.info(f"Merging results using {self.fusion_strategy} strategy") 348 349 if self.fusion_strategy == "rrf": 350 merged_results = self._rrf_fusion(source_results) 351 elif self.fusion_strategy == "weighted_average": 352 merged_results = self._weighted_average_fusion(source_results) 353 elif self.fusion_strategy == "max_score": 354 merged_results = self._max_score_fusion(source_results) 355 else: 356 raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}") 357 358 # Limit to top_k 359 final_results = merged_results[:query.top_k] 360 361 # Update ranks 362 for i, result in enumerate(final_results): 363 result.rank = i + 1 364 365 logger.info( 366 f"Multi-index retrieval complete: {len(final_results)} results " 367 f"from {len(source_results)} sources" 368 ) 369 370 return final_results
Retrieve from multiple indices/sources and merge results.
This retriever:
- Queries each enabled retriever in turn (the source shown here runs them sequentially, not in parallel)
- Merges results from all sources
- Applies source-specific weights and boosts
- Re-ranks using RRF, weighted-average, or max-score fusion
Use cases:
- Multi-domain search (HR + Finance + IT)
- Federated search across multiple databases
- Combining different data sources (documents + SQL + graph)
Example:
# Create retrievers for different sources
hr_retriever = VectorRetriever(hr_store, embedder)
finance_retriever = VectorRetriever(finance_store, embedder)
it_retriever = VectorRetriever(it_store, embedder)

# Create multi-index retriever
retriever = MultiIndexRetriever(
    sources=[
        SourceConfig("HR", hr_retriever, weight=1.0),
        SourceConfig("Finance", finance_retriever, weight=1.5),
        SourceConfig("IT", it_retriever, weight=1.0)
    ],
    fusion_strategy="rrf"
)

# Retrieve across all sources
query = RetrievalQuery(text="employee benefits policy", top_k=10)
results = retriever.retrieve(query)
    def __init__(
        self,
        sources: List[SourceConfig],
        fusion_strategy: str = "rrf",
        rrf_k: int = 60,
        normalize_scores: bool = True
    ):
        """
        Initialize multi-index retriever.

        Args:
            sources: List of SourceConfig objects defining retrievers
            fusion_strategy: Strategy for merging results:
                - "rrf": Reciprocal Rank Fusion
                - "weighted_average": Weighted average of normalized scores
                - "max_score": Take maximum score across sources
            rrf_k: Constant for RRF formula (default 60)
            normalize_scores: Normalize scores to [0,1] before fusion
        """
        self.sources = sources
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k
        self.normalize_scores = normalize_scores

        # Validate sources
        if not sources:
            raise ValueError("At least one source must be provided")

        # Normalize weights
        total_weight = sum(s.weight for s in sources if s.enabled)
        if total_weight > 0:
            for source in sources:
                source.weight = source.weight / total_weight
Initialize multi-index retriever.
Args:
sources: List of SourceConfig objects defining retrievers
fusion_strategy: Strategy for merging results:
- "rrf": Reciprocal Rank Fusion
- "weighted_average": Weighted average of normalized scores
- "max_score": Take maximum score across sources
rrf_k: Constant for RRF formula (default 60)
normalize_scores: Normalize scores to [0,1] before fusion
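A small worked sketch of the `"rrf"` strategy may help. It assumes the behaviour visible in the source on this page: weights are divided by their total in the constructor, and each source contributes `weight * boost_factor / (rrf_k + rank)` per document. The source names, weights, and document IDs are made up, and all sources are assumed to be enabled.

```python
from typing import Dict, List, Tuple

# (source name, weight, boost_factor): illustrative values only.
sources: List[Tuple[str, float, float]] = [
    ("HR", 1.0, 1.0),
    ("Finance", 1.5, 1.0),
    ("IT", 1.0, 1.0),
]

# Divide each weight by the total so the weights sum to 1.
total = sum(weight for _, weight, _ in sources)
normalized = {name: weight / total for name, weight, _ in sources}
boosts = {name: boost for name, _, boost in sources}

# Ranked document IDs per source (rank 1 first); made-up data.
ranked: Dict[str, List[str]] = {
    "HR": ["doc_a", "doc_b"],
    "Finance": ["doc_b", "doc_c"],
    "IT": ["doc_a"],
}


def rrf_scores(rrf_k: int = 60) -> Dict[str, float]:
    """Sum weight * boost / (rrf_k + rank) over every source listing a document."""
    scores: Dict[str, float] = {}
    for name, doc_ids in ranked.items():
        for rank, doc_id in enumerate(doc_ids, start=1):
            contribution = normalized[name] * boosts[name] / (rrf_k + rank)
            scores[doc_id] = scores.get(doc_id, 0.0) + contribution
    return scores


scores = rrf_scores()
ordering = sorted(scores, key=scores.get, reverse=True)
```

Here `doc_b` ranks first: it appears in two sources, one of which (Finance) carries the largest weight. Because RRF uses only ranks, the raw score scales of the individual retrievers do not matter.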
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform multi-source retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments passed to each retriever

        Returns:
            List of SearchResult objects merged from all sources
        """
        # Retrieve from each source
        source_results: Dict[str, List[SearchResult]] = {}

        for source in self.sources:
            if not source.enabled:
                logger.info(f"Skipping disabled source: {source.name}")
                continue

            try:
                logger.info(f"Retrieving from source: {source.name}")
                results = source.retriever.retrieve(query, **kwargs)
                source_results[source.name] = results
                logger.info(f"Source {source.name} returned {len(results)} results")

            except Exception as e:
                logger.warning(f"Retrieval from source {source.name} failed: {e}")
                source_results[source.name] = []

        # Check if we got any results
        total_results = sum(len(results) for results in source_results.values())
        if total_results == 0:
            logger.warning("No results from any source")
            return []

        # Merge results based on fusion strategy
        logger.info(f"Merging results using {self.fusion_strategy} strategy")

        if self.fusion_strategy == "rrf":
            merged_results = self._rrf_fusion(source_results)
        elif self.fusion_strategy == "weighted_average":
            merged_results = self._weighted_average_fusion(source_results)
        elif self.fusion_strategy == "max_score":
            merged_results = self._max_score_fusion(source_results)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")

        # Limit to top_k
        final_results = merged_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Multi-index retrieval complete: {len(final_results)} results "
            f"from {len(source_results)} sources"
        )

        return final_results
Perform multi-source retrieval.
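Args:
query: The retrieval query
**kwargs: Additional arguments passed to each retriever
Returns:
List of SearchResult objects merged from all sources, truncated to query.top_k. A source whose retriever raises is logged as a warning and treated as returning no results.
Raises:
ValueError: If fusion_strategy is not "rrf", "weighted_average", or "max_score" (checked only after at least one source returns results)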
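The score-based strategies can be sketched in the same spirit. This assumes the behaviour in the source on this page: scores are min-max normalized per source, `"weighted_average"` divides the summed weighted scores by the number of sources that returned the document, and `"max_score"` keeps the best normalized score without applying weights. The data is made up and `boost_factor` is left at its default of 1.0, so it is omitted.

```python
from typing import Dict, List


def min_max(scores: List[float]) -> List[float]:
    """Rescale to [0, 1]; if every score is equal, each becomes 1.0."""
    if not scores:
        return []
    low, high = min(scores), max(scores)
    if high == low:
        return [1.0 for _ in scores]
    return [(s - low) / (high - low) for s in scores]


# Raw scores per source, on different scales; made-up data.
raw: Dict[str, Dict[str, float]] = {
    "vector": {"doc_a": 0.75, "doc_b": 0.5, "doc_c": 0.25},
    "keyword": {"doc_b": 12.0, "doc_d": 4.0},
}
weights = {"vector": 0.5, "keyword": 0.5}  # already normalized to sum to 1

totals: Dict[str, float] = {}
counts: Dict[str, int] = {}
best: Dict[str, float] = {}
for source, doc_scores in raw.items():
    normalized = min_max(list(doc_scores.values()))
    for doc_id, score in zip(doc_scores, normalized):
        # weighted_average: accumulate weighted scores and count the hits.
        totals[doc_id] = totals.get(doc_id, 0.0) + score * weights[source]
        counts[doc_id] = counts.get(doc_id, 0) + 1
        # max_score: keep the best normalized score (weights not applied).
        best[doc_id] = max(best.get(doc_id, 0.0), score)

weighted_average = {doc_id: totals[doc_id] / counts[doc_id] for doc_id in totals}
```

In this data `doc_b` is returned by both sources yet scores 0.375 under `"weighted_average"`, below `doc_a` at 0.5, because the sum is divided by the hit count. Under `"max_score"` both reach 1.0. This is worth keeping in mind when choosing a strategy.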
@dataclass
class SourceConfig:
    """Configuration for a retrieval source."""
    name: str
    retriever: BaseRetriever
    weight: float = 1.0
    boost_factor: float = 1.0
    enabled: bool = True
Configuration for a retrieval source.
class BaseConnector(ABC):
    """
    Abstract base class for all data source connectors.

    Connectors handle one concern only: sourcing raw content and converting
    it to Documents. They do NOT chunk, embed, or index; those steps follow.

    Contract for all implementations:
    - Every returned Document must have a non-empty ``id`` and ``content``.
    - ``embedding`` is always ``None``; the caller is responsible for embedding.
    - Source-specific metadata (path, URL, container, etc.) must be stored in
      ``document.metadata`` under consistent, documented keys.
    - Files that cannot be read should be skipped with a printed warning rather
      than crashing the entire load.
    """

    @abstractmethod
    def load(self) -> List[Document]:
        """
        Load documents from the data source.

        Returns:
            List of Document objects with ``id``, ``content``, ``timestamp``,
            and ``metadata`` populated. ``embedding`` is always ``None``.
        """
Abstract base class for all data source connectors.
Connectors handle one concern only: sourcing raw content and converting it to Documents. They do NOT chunk, embed, or index — those steps follow.
Contract for all implementations:
- Every returned Document must have a non-empty id and content.
- embedding is always None; the caller is responsible for embedding.
- Source-specific metadata (path, URL, container, etc.) must be stored in document.metadata under consistent, documented keys.
- Files that cannot be read should be skipped with a printed warning rather than crashing the entire load.
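The contract above can be illustrated with a minimal connector. This is a sketch under stated assumptions: `Document` and `BaseConnector` are local stand-ins carrying only the fields and method this page mentions, and `InMemoryConnector` with its `"mem_"` ID prefix is hypothetical, not part of the package.

```python
import hashlib
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class Document:  # stand-in with the fields this page mentions
    id: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None


class BaseConnector(ABC):  # stand-in for the documented base class
    @abstractmethod
    def load(self) -> List[Document]:
        ...


class InMemoryConnector(BaseConnector):
    """Hypothetical connector over a dict of {name: raw text or None}."""

    def __init__(self, items: Dict[str, Optional[str]]):
        self.items = items

    def load(self) -> List[Document]:
        documents: List[Document] = []
        for name, text in sorted(self.items.items()):
            if not text or not text.strip():
                # Unreadable or empty input: warn and skip, do not abort.
                logger.warning("Skipping item %s: no readable content", name)
                continue
            documents.append(Document(
                id="mem_" + hashlib.md5(name.encode()).hexdigest()[:12],
                content=text.strip(),
                timestamp=datetime.now(),
                metadata={"source": name},  # source-specific metadata key
            ))
        return documents


docs = InMemoryConnector({"a.txt": "hello", "b.txt": None, "c.txt": "  "}).load()
```

Only `a.txt` produces a Document; the other two entries are skipped with a warning, and `embedding` stays `None` for the caller to fill in later.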
    @abstractmethod
    def load(self) -> List[Document]:
        """
        Load documents from the data source.

        Returns:
            List of Document objects with ``id``, ``content``, ``timestamp``,
            and ``metadata`` populated. ``embedding`` is always ``None``.
        """
Load documents from the data source.
Returns:
List of Document objects with id, content, timestamp,
and metadata populated. embedding is always None.
59class FilesystemConnector(BaseConnector): 60 """ 61 Loads documents from a local directory into Document objects. 62 63 Scans a root directory (optionally recursively) and converts each matching 64 file into a Document. The document ``id`` is a stable MD5 hash of the 65 absolute file path, so re-runs produce consistent IDs for upsert workflows. 66 67 Supported formats: 68 - Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .js, .ts, 69 .java, .cpp, .c, .h, .cs, .html, .htm, .xml, .toml, .ini, .cfg, .env 70 - PDF (.pdf) — requires ``pip install pypdf`` 71 - Word Document (.docx) — requires ``pip install python-docx`` 72 73 Metadata keys set on every returned Document: 74 ``source`` Absolute file path as a string 75 ``file_name`` File name including extension 76 ``extension`` Lowercase extension including dot (e.g. ``".md"``) 77 ``size_bytes`` File size in bytes 78 ``modified_at`` Last modification time in ISO 8601 format 79 80 Example: 81 ```python 82 from gmf_forge_ai_data.connectors import FilesystemConnector 83 84 connector = FilesystemConnector( 85 root_path="/data/docs", 86 extensions=[".txt", ".md", ".pdf"], 87 recursive=True, 88 ) 89 docs = connector.load() 90 # docs is List[Document] — pass to a chunker next 91 ``` 92 """ 93 94 def __init__( 95 self, 96 root_path: Union[str, Path], 97 extensions: Optional[List[str]] = None, 98 recursive: bool = True, 99 encoding: str = "utf-8-sig", 100 skip_empty: bool = True, 101 ): 102 """ 103 Args: 104 root_path: Root directory to scan. 105 extensions: Explicit list of extensions to include (e.g. 106 ``[".txt", ".md"]``). Include the leading dot. 107 If ``None``, all supported formats are loaded. 108 recursive: If ``True`` (default), scan subdirectories recursively. 109 encoding: Text encoding for native text files (default ``"utf-8"``). 110 skip_empty: If ``True`` (default), skip files that produce no 111 text content after stripping whitespace. 
112 """ 113 self.root_path = Path(root_path).resolve() 114 self.extensions: Optional[Set[str]] = ( 115 {ext.lower() for ext in extensions} 116 if extensions is not None 117 else None # None = all supported formats 118 ) 119 self.recursive = recursive 120 self.encoding = encoding 121 self.skip_empty = skip_empty 122 self._logger = BasicLogger(__name__) 123 124 def load(self) -> List[Document]: 125 """ 126 Scan the root directory and return one Document per matching file. 127 128 Files that raise an error during reading are skipped with a warning 129 printed to stdout so the rest of the load continues uninterrupted. 130 131 Returns: 132 List of Document objects, one per successfully loaded file, 133 sorted by file path for deterministic ordering. 134 135 Raises: 136 FileNotFoundError: If ``root_path`` does not exist. 137 NotADirectoryError: If ``root_path`` is not a directory. 138 """ 139 if not self.root_path.exists(): 140 raise FileNotFoundError( 141 f"Root path does not exist: {self.root_path}" 142 ) 143 if not self.root_path.is_dir(): 144 raise NotADirectoryError( 145 f"Root path is not a directory: {self.root_path}" 146 ) 147 148 pattern = "**/*" if self.recursive else "*" 149 all_files = sorted(p for p in self.root_path.glob(pattern) if p.is_file()) 150 151 documents: List[Document] = [] 152 for file_path in all_files: 153 ext = file_path.suffix.lower() 154 if not self._is_accepted(ext): 155 continue 156 157 try: 158 content = self._read_file(file_path, ext) 159 except Exception as e: 160 self._logger.warning("Skipping file", file=file_path.name, error=str(e)) 161 continue 162 163 if self.skip_empty and not content.strip(): 164 continue 165 166 stat = file_path.stat() 167 doc_id = "fs_" + hashlib.md5(str(file_path).encode()).hexdigest()[:12] 168 documents.append(Document( 169 id=doc_id, 170 content=content.strip(), 171 timestamp=datetime.fromtimestamp(stat.st_mtime), 172 metadata={ 173 "source": str(file_path), 174 "file_name": file_path.name, 175 "extension": 
ext, 176 "size_bytes": stat.st_size, 177 "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat(), 178 }, 179 )) 180 181 return documents 182 183 # ── Private helpers ────────────────────────────────────────────────────── 184 185 def _is_accepted(self, ext: str) -> bool: 186 """Return True if this extension should be loaded.""" 187 if self.extensions is not None: 188 return ext in self.extensions 189 # No filter specified: accept all natively supported + optional formats 190 return ext in _NATIVE_TEXT_EXTENSIONS or ext in {".pdf", ".docx"} 191 192 def _read_file(self, path: Path, ext: str) -> str: 193 if ext == ".pdf": 194 return _read_pdf(path) 195 if ext == ".docx": 196 return _read_docx(path) 197 return path.read_text(encoding=self.encoding, errors="replace")
Loads documents from a local directory into Document objects.
Scans a root directory (optionally recursively) and converts each matching
file into a Document. The document id is "fs_" followed by the first 12 hex
characters of the MD5 hash of the absolute file path, so re-runs produce
consistent IDs for upsert workflows.
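The ID scheme can be reproduced in a few lines. The helper name `filesystem_doc_id` is hypothetical; the hashing expression follows the source shown on this page.

```python
import hashlib
from pathlib import Path


def filesystem_doc_id(path: Path) -> str:
    """Return 'fs_' plus the first 12 hex digits of the MD5 of the path string."""
    return "fs_" + hashlib.md5(str(path).encode()).hexdigest()[:12]


first = filesystem_doc_id(Path("/data/docs/guide.md"))
second = filesystem_doc_id(Path("/data/docs/guide.md"))
other = filesystem_doc_id(Path("/data/docs/faq.md"))
```

The ID depends only on the path, not on the file content: editing a file keeps its ID, while renaming or moving it produces a new one.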
Supported formats:
- Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .js, .ts, .java, .cpp, .c, .h, .cs, .html, .htm, .xml, .toml, .ini, .cfg, .env
- PDF (.pdf): requires pip install pypdf
- Word Document (.docx): requires pip install python-docx
Metadata keys set on every returned Document:
source Absolute file path as a string
file_name File name including extension
extension Lowercase extension including dot (e.g. ".md")
size_bytes File size in bytes
modified_at Last modification time in ISO 8601 format
Example:
from gmf_forge_ai_data.connectors import FilesystemConnector

connector = FilesystemConnector(
    root_path="/data/docs",
    extensions=[".txt", ".md", ".pdf"],
    recursive=True,
)
docs = connector.load()
# docs is List[Document]; pass to a chunker next
    def __init__(
        self,
        root_path: Union[str, Path],
        extensions: Optional[List[str]] = None,
        recursive: bool = True,
        encoding: str = "utf-8-sig",
        skip_empty: bool = True,
    ):
        """
        Args:
            root_path:  Root directory to scan.
            extensions: Explicit list of extensions to include (e.g.
                        ``[".txt", ".md"]``). Include the leading dot.
                        If ``None``, all supported formats are loaded.
            recursive:  If ``True`` (default), scan subdirectories recursively.
            encoding:   Text encoding for native text files (default ``"utf-8-sig"``).
            skip_empty: If ``True`` (default), skip files that produce no
                        text content after stripping whitespace.
        """
        self.root_path = Path(root_path).resolve()
        self.extensions: Optional[Set[str]] = (
            {ext.lower() for ext in extensions}
            if extensions is not None
            else None  # None = all supported formats
        )
        self.recursive = recursive
        self.encoding = encoding
        self.skip_empty = skip_empty
        self._logger = BasicLogger(__name__)
Args:
root_path: Root directory to scan.
extensions: Explicit list of extensions to include (e.g.
[".txt", ".md"]). Include the leading dot.
If None, all supported formats are loaded.
recursive: If True (default), scan subdirectories recursively.
encoding: Text encoding for native text files (default "utf-8-sig").
skip_empty: If True (default), skip files that produce no
text content after stripping whitespace.
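The signature default for `encoding` is `"utf-8-sig"`. A short sketch of why that matters for files saved with a byte order mark (BOM), using only the standard library; the file name is arbitrary.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "note.txt"
    # Write the UTF-8 byte order mark followed by the text.
    path.write_bytes(b"\xef\xbb\xbfhello")

    # Plain UTF-8 keeps the BOM as a leading U+FEFF character.
    with_bom = path.read_text(encoding="utf-8", errors="replace")
    # utf-8-sig strips the BOM if present and is harmless if absent.
    without_bom = path.read_text(encoding="utf-8-sig", errors="replace")
```

With plain `"utf-8"` the stray U+FEFF would end up at the start of the document content, which is why `"utf-8-sig"` is a sensible default for mixed file sources.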
    def load(self) -> List[Document]:
        """
        Scan the root directory and return one Document per matching file.

        Files that raise an error during reading are skipped and a warning
        is logged, so the rest of the load continues uninterrupted.

        Returns:
            List of Document objects, one per successfully loaded file,
            sorted by file path for deterministic ordering.

        Raises:
            FileNotFoundError: If ``root_path`` does not exist.
            NotADirectoryError: If ``root_path`` is not a directory.
        """
        if not self.root_path.exists():
            raise FileNotFoundError(
                f"Root path does not exist: {self.root_path}"
            )
        if not self.root_path.is_dir():
            raise NotADirectoryError(
                f"Root path is not a directory: {self.root_path}"
            )

        pattern = "**/*" if self.recursive else "*"
        all_files = sorted(p for p in self.root_path.glob(pattern) if p.is_file())

        documents: List[Document] = []
        for file_path in all_files:
            ext = file_path.suffix.lower()
            if not self._is_accepted(ext):
                continue

            try:
                content = self._read_file(file_path, ext)
            except Exception as e:
                self._logger.warning("Skipping file", file=file_path.name, error=str(e))
                continue

            if self.skip_empty and not content.strip():
                continue

            stat = file_path.stat()
            doc_id = "fs_" + hashlib.md5(str(file_path).encode()).hexdigest()[:12]
            documents.append(Document(
                id=doc_id,
                content=content.strip(),
                timestamp=datetime.fromtimestamp(stat.st_mtime),
                metadata={
                    "source": str(file_path),
                    "file_name": file_path.name,
                    "extension": ext,
                    "size_bytes": stat.st_size,
                    "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat(),
                },
            ))

        return documents
Scan the root directory and return one Document per matching file.
Files that raise an error during reading are skipped and a warning is logged, so the rest of the load continues uninterrupted.
Returns:
List of Document objects, one per successfully loaded file, sorted by file path for deterministic ordering.
Raises:
FileNotFoundError: If root_path does not exist.
NotADirectoryError: If root_path is not a directory.
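The scan itself comes down to a `pathlib` glob, a sort, and a suffix filter. The sketch below uses a hypothetical helper, `list_files`. As a simplification, it returns every file when `extensions` is `None`, whereas the connector falls back to its supported-format set.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Set


def list_files(root: Path, extensions: Optional[Set[str]], recursive: bool) -> List[Path]:
    """Return sorted files under root whose lowercase suffix is in extensions."""
    pattern = "**/*" if recursive else "*"
    files = sorted(p for p in root.glob(pattern) if p.is_file())
    if extensions is None:
        return files  # simplification: no supported-format fallback here
    return [p for p in files if p.suffix.lower() in extensions]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    (root / "a.md").write_text("top level")
    (root / "sub" / "b.MD").write_text("nested")
    (root / "c.bin").write_text("ignored")

    deep = [p.name for p in list_files(root, {".md"}, recursive=True)]
    shallow = [p.name for p in list_files(root, {".md"}, recursive=False)]
```

Lowercasing the suffix is what lets `b.MD` match `".md"`, and sorting the paths gives the deterministic order mentioned in the return description.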
35class BlobStorageConnector(BaseConnector): 36 """ 37 Loads blobs from an Azure Blob Storage container into Document objects. 38 39 Lists all blobs in the container (optionally filtered by a path prefix), 40 downloads their content, and extracts text. Blobs whose extension is not 41 in the supported set are skipped automatically. 42 43 Supported file types: 44 - Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .xml 45 - PDF (.pdf) — requires ``pip install pypdf`` 46 - Word Document (.docx) — requires ``pip install python-docx`` 47 48 Metadata keys set on every returned Document: 49 ``source`` Full blob URL 50 ``file_name`` Last path segment of the blob name 51 ``blob_name`` Full blob path inside the container 52 ``extension`` Lowercase extension including dot 53 ``size_bytes`` Blob content length in bytes 54 ``modified_at`` Last modified time in ISO 8601 format 55 ``container`` Container name 56 57 Example:: 58 59 from gmf_forge_ai_data.connectors import BlobStorageConnector 60 61 connector = BlobStorageConnector( 62 account_name="myaccount", 63 access_key="your-storage-account-access-key", 64 container_name="documents", 65 prefix="knowledge-base/", 66 ) 67 docs = connector.load() 68 """ 69 70 def __init__( 71 self, 72 account_name: str, 73 access_key: str, 74 container_name: str, 75 prefix: str = "", 76 ssl_cert_path: Optional[str] = None, 77 ): 78 """ 79 Args: 80 account_name: Storage account name. 81 access_key: Storage account access key. 82 container_name: Name of the blob container to read from. 83 prefix: Optional blob name prefix for filtering, acting 84 like a folder path (e.g. ``"knowledge-base/"``). 85 Pass ``""`` (default) to list the entire container. 86 ssl_cert_path: Optional path to a CA bundle PEM file for 87 environments with corporate SSL inspection. 
88 """ 89 self.account_name = account_name 90 self.access_key = access_key 91 self.container_name = container_name 92 self.connection_string = ( 93 f"DefaultEndpointsProtocol=https;" 94 f"AccountName={account_name};" 95 f"AccountKey={access_key};" 96 f"EndpointSuffix=core.windows.net" 97 ) 98 self.account_url = f"https://{account_name}.blob.core.windows.net" 99 self.prefix = prefix 100 self.ssl_cert_path = ssl_cert_path 101 self._logger = BasicLogger(__name__) 102 103 def load(self) -> List[Document]: 104 """ 105 List and download all supported blobs in the container under ``prefix``. 106 107 Returns: 108 List of Document objects, one per successfully loaded blob. 109 110 Raises: 111 ImportError: If the ``azure-storage-blob`` package is not installed. 112 """ 113 try: 114 from azure.storage.blob import BlobServiceClient # type: ignore 115 except ImportError: 116 raise ImportError( 117 "azure-storage-blob is required for BlobStorageConnector. " 118 "Install it with: pip install azure-storage-blob" 119 ) 120 121 client = BlobServiceClient.from_connection_string(self.connection_string) 122 container_client = client.get_container_client(self.container_name) 123 124 documents: List[Document] = [] 125 for blob in container_client.list_blobs(name_starts_with=self.prefix or None): 126 name: str = blob.name 127 ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." 
in name else "" 128 if ext not in _SUPPORTED_EXTENSIONS: 129 continue 130 131 try: 132 content = self._download_blob(container_client, name, ext) 133 except Exception as e: 134 self._logger.warning("Skipping blob", blob=name, error=str(e)) 135 continue 136 137 if not content.strip(): 138 continue 139 140 modified = blob.last_modified 141 size = blob.size or 0 142 blob_url = f"{self.account_url}/{self.container_name}/{name}" 143 doc_id = "blob_" + hashlib.md5(blob_url.encode()).hexdigest()[:12] 144 file_name = name.split("/")[-1] 145 146 documents.append(Document( 147 id=doc_id, 148 content=content.strip(), 149 timestamp=( 150 modified if isinstance(modified, datetime) else datetime.now() 151 ), 152 metadata={ 153 "source": blob_url, 154 "file_name": file_name, 155 "blob_name": name, 156 "extension": ext, 157 "size_bytes": size, 158 "modified_at": modified.isoformat() if modified else "", 159 "container": self.container_name, 160 }, 161 )) 162 163 return documents 164 165 # ── Private helpers ────────────────────────────────────────────────────── 166 167 def _download_blob(self, container_client, name: str, ext: str) -> str: 168 raw: bytes = container_client.download_blob(name).readall() 169 if ext == ".pdf": 170 return self._extract_pdf(raw) 171 if ext == ".docx": 172 return self._extract_docx(raw) 173 return raw.decode("utf-8-sig", errors="replace") 174 175 @staticmethod 176 def _extract_pdf(raw: bytes) -> str: 177 try: 178 import pypdf # type: ignore 179 except ImportError: 180 raise ImportError( 181 "pypdf required for PDF support: pip install pypdf" 182 ) 183 import logging 184 logging.getLogger("pypdf").setLevel(logging.ERROR) 185 reader = pypdf.PdfReader(io.BytesIO(raw)) 186 return "\n".join(page.extract_text() or "" for page in reader.pages) 187 188 @staticmethod 189 def _extract_docx(raw: bytes) -> str: 190 try: 191 import docx # type: ignore 192 except ImportError: 193 raise ImportError( 194 "python-docx required for DOCX support: pip install python-docx" 
195 ) 196 doc = docx.Document(io.BytesIO(raw)) 197 return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
Loads blobs from an Azure Blob Storage container into Document objects.
Lists all blobs in the container (optionally filtered by a path prefix), downloads their content, and extracts text. Blobs whose extension is not in the supported set are skipped automatically.
Supported file types:
- Native text: .txt, .md, .rst, .csv, .json, .yaml, .yml, .py, .html, .xml
- PDF (.pdf): requires pip install pypdf
- Word Document (.docx): requires pip install python-docx
Metadata keys set on every returned Document:
source Full blob URL
file_name Last path segment of the blob name
blob_name Full blob path inside the container
extension Lowercase extension including dot
size_bytes Blob content length in bytes
modified_at Last modified time in ISO 8601 format
container Container name
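As a rough illustration of how these metadata values relate to a blob name, the sketch below derives them with the standard library only. The function names and the `SUPPORTED_EXTENSIONS` set are assumptions made for this example (the set is copied from the list of supported file types above); the package's own `_SUPPORTED_EXTENSIONS` constant is not shown in this section and may differ.

```python
from datetime import datetime, timezone
from typing import Optional

# Assumption: mirrors the "Supported file types" list above; the real
# _SUPPORTED_EXTENSIONS constant is not visible in this section.
SUPPORTED_EXTENSIONS = {
    ".txt", ".md", ".rst", ".csv", ".json", ".yaml", ".yml",
    ".py", ".html", ".xml", ".pdf", ".docx",
}


def blob_extension(name: str) -> str:
    """Return the lowercase extension including the dot, or "" if there is none."""
    return ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""


def blob_metadata(
    account_name: str,
    container: str,
    name: str,
    size: Optional[int],
    modified: Optional[datetime],
) -> Optional[dict]:
    """Build the metadata dict for one blob, or None if its type is unsupported."""
    ext = blob_extension(name)
    if ext not in SUPPORTED_EXTENSIONS:
        return None  # unsupported blobs are skipped
    url = f"https://{account_name}.blob.core.windows.net/{container}/{name}"
    return {
        "source": url,
        "file_name": name.split("/")[-1],
        "blob_name": name,
        "extension": ext,
        "size_bytes": size or 0,
        "modified_at": modified.isoformat() if modified else "",
        "container": container,
    }


meta = blob_metadata(
    "myaccount",
    "documents",
    "knowledge-base/Guide.PDF",
    2048,
    datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc),
)
print(meta)
```

Note that the extension is lowercased for matching while `file_name` and `blob_name` keep their original casing.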
Example::

    from gmf_forge_ai_data.connectors import BlobStorageConnector

    connector = BlobStorageConnector(
        account_name="myaccount",
        access_key="your-storage-account-access-key",
        container_name="documents",
        prefix="knowledge-base/",
    )
    docs = connector.load()
    def __init__(
        self,
        account_name: str,
        access_key: str,
        container_name: str,
        prefix: str = "",
        ssl_cert_path: Optional[str] = None,
    ):
        """
        Args:
            account_name: Storage account name.
            access_key: Storage account access key.
            container_name: Name of the blob container to read from.
            prefix: Optional blob name prefix for filtering, acting
                like a folder path (e.g. ``"knowledge-base/"``).
                Pass ``""`` (default) to list the entire container.
            ssl_cert_path: Optional path to a CA bundle PEM file for
                environments with corporate SSL inspection.
        """
        self.account_name = account_name
        self.access_key = access_key
        self.container_name = container_name
        self.connection_string = (
            f"DefaultEndpointsProtocol=https;"
            f"AccountName={account_name};"
            f"AccountKey={access_key};"
            f"EndpointSuffix=core.windows.net"
        )
        self.account_url = f"https://{account_name}.blob.core.windows.net"
        self.prefix = prefix
        self.ssl_cert_path = ssl_cert_path
        self._logger = BasicLogger(__name__)
Args:
account_name: Storage account name.
access_key: Storage account access key.
container_name: Name of the blob container to read from.
prefix: Optional blob name prefix for filtering, acting
like a folder path (e.g. "knowledge-base/").
Pass "" (default) to list the entire container.
ssl_cert_path: Optional path to a CA bundle PEM file for
environments with corporate SSL inspection.
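The constructor turns `account_name` and `access_key` into a connection string and an account URL. The following is a minimal sketch of that derivation; `build_connection_settings` is a hypothetical helper written for this example, not a function of the package.

```python
from typing import Dict


def build_connection_settings(account_name: str, access_key: str) -> Dict[str, str]:
    """Derive the connection string and account URL from the two credentials."""
    connection_string = (
        "DefaultEndpointsProtocol=https;"
        f"AccountName={account_name};"
        f"AccountKey={access_key};"
        "EndpointSuffix=core.windows.net"
    )
    return {
        "connection_string": connection_string,
        "account_url": f"https://{account_name}.blob.core.windows.net",
    }


settings = build_connection_settings("myaccount", "your-storage-account-access-key")
print(settings["account_url"])
```

Because the access key is embedded in the connection string, the resulting value should be treated as a secret and kept out of logs.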
    def load(self) -> List[Document]:
        """
        List and download all supported blobs in the container under ``prefix``.

        Returns:
            List of Document objects, one per successfully loaded blob.

        Raises:
            ImportError: If the ``azure-storage-blob`` package is not installed.
        """
        try:
            from azure.storage.blob import BlobServiceClient  # type: ignore
        except ImportError:
            raise ImportError(
                "azure-storage-blob is required for BlobStorageConnector. "
                "Install it with: pip install azure-storage-blob"
            )

        client = BlobServiceClient.from_connection_string(self.connection_string)
        container_client = client.get_container_client(self.container_name)

        documents: List[Document] = []
        for blob in container_client.list_blobs(name_starts_with=self.prefix or None):
            name: str = blob.name
            ext = ("." + name.rsplit(".", 1)[-1]).lower() if "." in name else ""
            if ext not in _SUPPORTED_EXTENSIONS:
                continue

            try:
                content = self._download_blob(container_client, name, ext)
            except Exception as e:
                self._logger.warning("Skipping blob", blob=name, error=str(e))
                continue

            if not content.strip():
                continue

            modified = blob.last_modified
            size = blob.size or 0
            blob_url = f"{self.account_url}/{self.container_name}/{name}"
            doc_id = "blob_" + hashlib.md5(blob_url.encode()).hexdigest()[:12]
            file_name = name.split("/")[-1]

            documents.append(Document(
                id=doc_id,
                content=content.strip(),
                timestamp=(
                    modified if isinstance(modified, datetime) else datetime.now()
                ),
                metadata={
                    "source": blob_url,
                    "file_name": file_name,
                    "blob_name": name,
                    "extension": ext,
                    "size_bytes": size,
                    "modified_at": modified.isoformat() if modified else "",
                    "container": self.container_name,
                },
            ))

        return documents
List and download all supported blobs in the container under prefix.
Returns: List of Document objects, one per successfully loaded blob.
Raises:
ImportError: If the azure-storage-blob package is not installed.
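"One per successfully loaded blob" means that a failure on a single blob is logged and skipped rather than aborting the whole load, and that blobs with no text are dropped. The sketch below reproduces that control flow against a fake downloader so it runs without Azure; `load_tolerantly`, `fake_download`, and the plain-dict documents are assumptions for illustration, and only the ID scheme (`"blob_"` plus the first 12 hex characters of the URL's MD5) is taken from the source.

```python
import hashlib
from typing import Callable, Dict, Iterable, List


def stable_doc_id(blob_url: str) -> str:
    """Deterministic ID: the same blob URL always yields the same document ID."""
    return "blob_" + hashlib.md5(blob_url.encode()).hexdigest()[:12]


def load_tolerantly(
    names: Iterable[str],
    download: Callable[[str], str],
    base_url: str,
) -> List[Dict[str, str]]:
    """Download each name, skipping failures and blobs with no text."""
    documents = []
    for name in names:
        try:
            content = download(name)
        except Exception as exc:  # one bad blob must not stop the run
            print(f"Skipping blob {name}: {exc}")
            continue
        if not content.strip():
            continue  # whitespace-only content produces no document
        url = f"{base_url}/{name}"
        documents.append(
            {"id": stable_doc_id(url), "content": content.strip(), "source": url}
        )
    return documents


def fake_download(name: str) -> str:
    """Stand-in for the real blob download, used only in this example."""
    if name == "broken.pdf":
        raise IOError("corrupt file")
    return {"empty.txt": "   \n", "notes.md": "  # Notes\n"}[name]


docs = load_tolerantly(
    ["notes.md", "broken.pdf", "empty.txt"],
    fake_download,
    "https://myaccount.blob.core.windows.net/documents",
)
print([d["source"] for d in docs])
```

Deterministic IDs make repeated loads of the same container idempotent from the point of view of a downstream store that upserts by ID.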
class BaseIndexBuilder(ABC):
    """
    Abstract base class for index builders.

    Each backend (Azure AI Search, Cosmos DB, MongoDB) provides a concrete
    subclass that exposes backend-specific tuning parameters while sharing
    the same management interface.
    """

    # ------------------------------------------------------------------ #
    # Core lifecycle                                                     #
    # ------------------------------------------------------------------ #

    @abstractmethod
    def create_index(self) -> None:
        """Create the index if it does not already exist.

        Safe to call multiple times — must be a no-op when the index exists.
        Use this for idempotent provisioning (CI/CD pipelines, first-run
        setup scripts).
        """

    @abstractmethod
    def create_or_replace_index(self) -> None:
        """Delete the index if it exists, then create it fresh.

        Use this when you need to apply schema changes that cannot be done
        via an in-place update (e.g. changing HNSW parameters or adding a
        new vector field).

        Warning: All documents are lost. Only use in dev/staging or after a
        full re-ingestion has been planned.
        """

    @abstractmethod
    def delete_index(self) -> None:
        """Permanently delete the index and all its documents.

        Raises:
            RuntimeError: If the index does not exist.
        """

    @abstractmethod
    def index_exists(self) -> bool:
        """Return True if the index currently exists, False otherwise."""

    @abstractmethod
    def list_indexes(self) -> List[str]:
        """Return the names of all indexes on this backend/service."""
Abstract base class for index builders.
Each backend (Azure AI Search, Cosmos DB, MongoDB) provides a concrete subclass that exposes backend-specific tuning parameters while sharing the same management interface.
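To make the shared management interface concrete, here is a minimal in-memory sketch. `IndexBuilderInterface` is a stand-in that mirrors the five abstract methods so the example runs without the package installed, and `InMemoryIndexBuilder` is a hypothetical backend that stores "indexes" in a plain dict. Neither class is part of the package.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class IndexBuilderInterface(ABC):
    """Stand-in mirroring the BaseIndexBuilder method set."""

    @abstractmethod
    def create_index(self) -> None: ...

    @abstractmethod
    def create_or_replace_index(self) -> None: ...

    @abstractmethod
    def delete_index(self) -> None: ...

    @abstractmethod
    def index_exists(self) -> bool: ...

    @abstractmethod
    def list_indexes(self) -> List[str]: ...


class InMemoryIndexBuilder(IndexBuilderInterface):
    """Hypothetical backend: index name -> list of stored documents."""

    def __init__(self, index_name: str, service: Dict[str, list]) -> None:
        self.index_name = index_name
        self._service = service

    def create_index(self) -> None:
        # Idempotent: an existing index and its documents are left untouched.
        self._service.setdefault(self.index_name, [])

    def create_or_replace_index(self) -> None:
        # Destructive: any existing documents are discarded.
        self._service[self.index_name] = []

    def delete_index(self) -> None:
        if not self.index_exists():
            raise RuntimeError(
                f"Cannot delete index '{self.index_name}': it does not exist."
            )
        del self._service[self.index_name]

    def index_exists(self) -> bool:
        return self.index_name in self._service

    def list_indexes(self) -> List[str]:
        return list(self._service)


service: Dict[str, list] = {}
builder = InMemoryIndexBuilder("docs", service)
builder.create_index()
service["docs"].append("doc-1")
builder.create_index()  # second call is a no-op
print(service)
```

Code written against the interface (provisioning scripts, tests) can then switch backends by swapping the builder instance.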
    @abstractmethod
    def create_index(self) -> None:
        """Create the index if it does not already exist.

        Safe to call multiple times — must be a no-op when the index exists.
        Use this for idempotent provisioning (CI/CD pipelines, first-run
        setup scripts).
        """
Create the index if it does not already exist.
Safe to call multiple times — must be a no-op when the index exists. Use this for idempotent provisioning (CI/CD pipelines, first-run setup scripts).
    @abstractmethod
    def create_or_replace_index(self) -> None:
        """Delete the index if it exists, then create it fresh.

        Use this when you need to apply schema changes that cannot be done
        via an in-place update (e.g. changing HNSW parameters or adding a
        new vector field).

        Warning: All documents are lost. Only use in dev/staging or after a
        full re-ingestion has been planned.
        """
Delete the index if it exists, then create it fresh.
Use this when you need to apply schema changes that cannot be done via an in-place update (e.g. changing HNSW parameters or adding a new vector field).
Warning: All documents are lost. Only use in dev/staging or after a full re-ingestion has been planned.
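One way to enforce the "dev/staging only" advice is to wrap the call in an explicit environment check. This is a sketch of that idea, not something the package provides: `replace_index_guarded`, the environment names, and the `RecordingBuilder` stub are all assumptions for illustration.

```python
from typing import Iterable


class RecordingBuilder:
    """Stub builder that only records whether a replace was requested."""

    def __init__(self) -> None:
        self.replaced = False

    def create_or_replace_index(self) -> None:
        self.replaced = True


def replace_index_guarded(
    builder,
    environment: str,
    allowed: Iterable[str] = ("dev", "staging"),
) -> None:
    """Run the destructive replace only in explicitly allowed environments."""
    if environment not in set(allowed):
        raise PermissionError(
            f"Refusing to replace index in environment '{environment}'."
        )
    builder.create_or_replace_index()


dev_builder = RecordingBuilder()
replace_index_guarded(dev_builder, "dev")
print(dev_builder.replaced)
```

In production, a planned re-ingestion would pass the environment name through `allowed` deliberately rather than bypassing the guard.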
    @abstractmethod
    def delete_index(self) -> None:
        """Permanently delete the index and all its documents.

        Raises:
            RuntimeError: If the index does not exist.
        """
Permanently delete the index and all its documents.
Raises: RuntimeError: If the index does not exist.
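Because `delete_index` raises when the index is missing, teardown scripts that should tolerate an absent index can check first. The helper below is a hypothetical convenience wrapper, and `FakeBuilder` is a stub used only to exercise it.

```python
class FakeBuilder:
    """Stub with the two methods the helper relies on."""

    def __init__(self, exists: bool) -> None:
        self._exists = exists

    def index_exists(self) -> bool:
        return self._exists

    def delete_index(self) -> None:
        if not self._exists:
            raise RuntimeError("Cannot delete index: it does not exist.")
        self._exists = False


def delete_index_if_exists(builder) -> bool:
    """Delete the index when present; return True if something was deleted."""
    if not builder.index_exists():
        return False
    builder.delete_index()
    return True


print(delete_index_if_exists(FakeBuilder(exists=True)))
print(delete_index_if_exists(FakeBuilder(exists=False)))
```

The check and the delete are two separate calls, so a concurrent deletion between them can still surface as a `RuntimeError`.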
91class AzureAISearchIndexBuilder(BaseIndexBuilder): 92 """ 93 Builds and manages Azure AI Search indexes with full developer control. 94 95 The builder owns *schema* concerns only. Document operations (add, 96 search, delete) belong to ``AzureAISearchVectorStore``. 97 98 Parameters 99 ---------- 100 endpoint: 101 Azure AI Search service endpoint URL. 102 api_key: 103 Azure AI Search admin API key. Use for local development or 104 when managed identity is not available. 105 token_provider: 106 Zero-argument callable that returns a bearer token string. 107 Use for managed identity / workload identity scenarios. 108 The callable must request the **Azure AI Search** scope:: 109 110 from azure.identity import DefaultAzureCredential, get_bearer_token_provider 111 token_provider = get_bearer_token_provider( 112 DefaultAzureCredential(), 113 "https://search.azure.com/.default" 114 ) 115 116 Note: this scope is different from Azure OpenAI / Cognitive Services 117 (``https://cognitiveservices.azure.com/.default``) — each service 118 requires its own token_provider. 119 index_name: 120 Name of the index to create / manage. 121 embedding_dimension: 122 Number of dimensions in the embedding vectors (must match the 123 embedding model — e.g. 1536 for text-embedding-ada-002, 3072 for 124 text-embedding-3-large). 125 document_type: 126 Optional Document subclass. When provided, all dataclass fields 127 not in the base Document are automatically added as indexed fields 128 (filterable, sortable, facetable where appropriate). 129 hnsw_m: 130 Number of bi-directional links created per node. Higher = better 131 recall but more memory. Typical range 4–16. Default: 4. 132 hnsw_ef_construction: 133 Size of the candidate list during index construction. Higher = 134 better recall, slower build time. Typical range 100–800. 135 Default: 400. 136 hnsw_ef_search: 137 Size of the candidate list during search. Higher = better recall, 138 slower queries. Typical range 100–1000. Default: 500. 
139 metric: 140 Similarity metric. One of ``"cosine"``, ``"euclidean"``, 141 ``"dotProduct"``. Default: ``"cosine"``. 142 ssl_cert_path: 143 Optional path to a PEM certificate bundle for corporate SSL 144 inspection proxies. Sets ``REQUESTS_CA_BUNDLE`` and 145 ``SSL_CERT_FILE`` environment variables before building the client. 146 semantic_config: 147 Optional semantic search configuration. When provided the index is 148 provisioned with a ``SemanticSearch`` configuration that enables 149 Azure AI semantic reranking (``BoostedRerankerScore``). 150 151 Expected keys: 152 153 - ``name`` (str) — semantic config name (default 154 ``"default-semantic-config"``) 155 - ``title_field`` (str, optional) — field used as the document title 156 - ``content_fields`` (list[str]) — primary body content fields 157 - ``keyword_fields`` (list[str], optional) — keyword/facet fields 158 159 Example:: 160 161 { 162 "name": "policyhub-semantic-config", 163 "title_field": "document_name", 164 "content_fields": ["content"], 165 "keyword_fields": ["language", "locale", "source"], 166 } 167 """ 168 169 def __init__( 170 self, 171 endpoint: str, 172 index_name: str, 173 api_key: Optional[str] = None, 174 token_provider: Optional[Callable[[], str]] = None, 175 embedding_dimension: int = 1536, 176 document_type: Type[Document] = Document, 177 hnsw_m: int = 4, 178 hnsw_ef_construction: int = 400, 179 hnsw_ef_search: int = 500, 180 metric: str = "cosine", 181 ssl_cert_path: Optional[str] = None, 182 semantic_config: Optional[dict] = None, 183 ) -> None: 184 self.index_name = index_name 185 self.embedding_dimension = embedding_dimension 186 self.document_type = document_type 187 self.hnsw_m = hnsw_m 188 self.hnsw_ef_construction = hnsw_ef_construction 189 self.hnsw_ef_search = hnsw_ef_search 190 self.metric = metric 191 self.semantic_config = semantic_config 192 193 if not api_key and not token_provider: 194 raise ValueError( 195 "Either api_key or token_provider must be supplied to 
AzureAISearchIndexBuilder." 196 ) 197 198 if ssl_cert_path: 199 import os as _os 200 _os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path) 201 _os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path) 202 203 if token_provider: 204 credential = _TokenProviderCredential(token_provider) 205 else: 206 credential = AzureKeyCredential(api_key) 207 self._index_client = SearchIndexClient( 208 endpoint=endpoint, 209 credential=credential, 210 ) 211 212 # ------------------------------------------------------------------ # 213 # BaseIndexBuilder interface # 214 # ------------------------------------------------------------------ # 215 216 def create_index(self) -> None: 217 """Create the index if it does not already exist (idempotent).""" 218 if self.index_exists(): 219 logger.info("Index already exists — skipping creation", index=self.index_name) 220 return 221 self._create(self.index_name) 222 logger.info("Index created successfully", index=self.index_name) 223 224 def create_or_replace_index(self) -> None: 225 """Delete the existing index (if any) then create it fresh. 226 227 Warning: All documents are permanently lost. 228 """ 229 if self.index_exists(): 230 self._index_client.delete_index(self.index_name) 231 logger.info("Index deleted for replacement", index=self.index_name) 232 self._create(self.index_name) 233 logger.info("Index created (replaced)", index=self.index_name) 234 235 def delete_index(self) -> None: 236 """Permanently delete the index and all its documents. 237 238 Raises: 239 RuntimeError: If the index does not exist. 240 """ 241 if not self.index_exists(): 242 raise RuntimeError( 243 f"Cannot delete index '{self.index_name}': it does not exist." 
244 ) 245 self._index_client.delete_index(self.index_name) 246 logger.info("Index deleted", index=self.index_name) 247 248 def index_exists(self) -> bool: 249 """Return True if the index currently exists.""" 250 try: 251 self._index_client.get_index(self.index_name) 252 return True 253 except ResourceNotFoundError: 254 return False 255 except Exception: 256 # Treat any other error as non-existence to keep callers safe 257 return False 258 259 def list_indexes(self) -> List[str]: 260 """Return the names of all indexes on this Azure AI Search service.""" 261 return [idx.name for idx in self._index_client.list_indexes()] 262 263 # ------------------------------------------------------------------ # 264 # Internal helpers # 265 # ------------------------------------------------------------------ # 266 267 def _build_fields(self) -> list: 268 """Build the Azure Search field list from base + document_type fields.""" 269 fields = [ 270 SimpleField( 271 name="id", 272 type=SearchFieldDataType.String, 273 key=True, 274 filterable=True, 275 ), 276 SearchableField( 277 name="content", 278 type=SearchFieldDataType.String, 279 searchable=True, 280 ), 281 SearchField( 282 name="embedding", 283 type=SearchFieldDataType.Collection(SearchFieldDataType.Single), 284 searchable=True, 285 vector_search_dimensions=self.embedding_dimension, 286 vector_search_profile_name="default-vector-profile", 287 ), 288 SimpleField( 289 name="timestamp", 290 type=SearchFieldDataType.DateTimeOffset, 291 filterable=True, 292 sortable=True, 293 ), 294 # Stores serialised metadata dict and any non-indexed custom fields 295 SimpleField( 296 name="document_data", 297 type=SearchFieldDataType.String, 298 filterable=False, 299 ), 300 ] 301 302 # Infer custom fields from the document_type dataclass 303 if dataclasses.is_dataclass(self.document_type): 304 base_field_names = {"id", "content", "embedding", "timestamp", "metadata"} 305 for field in dataclasses.fields(self.document_type): 306 if field.name in 
base_field_names: 307 continue 308 azure_type = self._map_python_type(field.type) 309 if azure_type is None: 310 continue 311 scalar_types = { 312 SearchFieldDataType.String, 313 SearchFieldDataType.Int32, 314 SearchFieldDataType.Int64, 315 SearchFieldDataType.Double, 316 SearchFieldDataType.DateTimeOffset, 317 SearchFieldDataType.Boolean, 318 } 319 # Per-field overrides from dataclass field metadata. 320 # Falls back to the original defaults when not specified, so 321 # existing document types without metadata are unaffected. 322 meta = field.metadata 323 is_searchable = meta.get("searchable", False) 324 is_filterable = meta.get("filterable", True) 325 is_sortable = meta.get("sortable", azure_type in scalar_types) 326 is_facetable = meta.get( 327 "facetable", 328 azure_type in {SearchFieldDataType.String, SearchFieldDataType.Boolean}, 329 ) 330 if is_searchable: 331 fields.append( 332 SearchableField( 333 name=field.name, 334 filterable=is_filterable, 335 sortable=is_sortable, 336 facetable=is_facetable, 337 ) 338 ) 339 else: 340 fields.append( 341 SimpleField( 342 name=field.name, 343 type=azure_type, 344 filterable=is_filterable, 345 sortable=is_sortable, 346 facetable=is_facetable, 347 ) 348 ) 349 logger.info( 350 "Added indexed field", 351 field=field.name, 352 azure_type=str(azure_type), 353 searchable=is_searchable, 354 ) 355 356 return fields 357 358 def _create(self, index_name: str) -> None: 359 """Internal: build and submit the index definition to Azure.""" 360 fields = self._build_fields() 361 362 vector_search = VectorSearch( 363 algorithms=[ 364 HnswAlgorithmConfiguration( 365 name="default-hnsw", 366 parameters={ 367 "m": self.hnsw_m, 368 "efConstruction": self.hnsw_ef_construction, 369 "efSearch": self.hnsw_ef_search, 370 "metric": self.metric, 371 }, 372 ) 373 ], 374 profiles=[ 375 VectorSearchProfile( 376 name="default-vector-profile", 377 algorithm_configuration_name="default-hnsw", 378 ) 379 ], 380 ) 381 382 semantic_search = None 383 if 
self.semantic_config: 384 sc = self.semantic_config 385 title_field = ( 386 SemanticField(field_name=sc["title_field"]) 387 if sc.get("title_field") else None 388 ) 389 semantic_search = SemanticSearch( 390 configurations=[ 391 SemanticConfiguration( 392 name=sc.get("name", "default-semantic-config"), 393 prioritized_fields=SemanticPrioritizedFields( 394 title_field=title_field, 395 content_fields=[ 396 SemanticField(field_name=f) 397 for f in sc.get("content_fields", []) 398 ], 399 keywords_fields=[ 400 SemanticField(field_name=f) 401 for f in sc.get("keyword_fields", []) 402 ], 403 ), 404 ) 405 ] 406 ) 407 408 index = SearchIndex( 409 name=index_name, 410 fields=fields, 411 vector_search=vector_search, 412 semantic_search=semantic_search, 413 ) 414 415 self._index_client.create_index(index) 416 logger.info( 417 "Azure AI Search index provisioned", 418 index=index_name, 419 dim=self.embedding_dimension, 420 metric=self.metric, 421 hnsw_m=self.hnsw_m, 422 ef_construction=self.hnsw_ef_construction, 423 ef_search=self.hnsw_ef_search, 424 fields=len(fields), 425 ) 426 427 @staticmethod 428 def _map_python_type(python_type) -> Optional[SearchFieldDataType]: 429 """Map a Python / dataclass field type to an Azure Search field type.""" 430 _map = { 431 str: SearchFieldDataType.String, 432 int: SearchFieldDataType.Int64, 433 float: SearchFieldDataType.Double, 434 bool: SearchFieldDataType.Boolean, 435 datetime: SearchFieldDataType.DateTimeOffset, 436 } 437 438 # Handle Optional[X] → extract X 439 if hasattr(python_type, "__origin__"): 440 args = getattr(python_type, "__args__", ()) 441 for arg in args: 442 if arg is type(None): 443 continue 444 return _map.get(arg, SearchFieldDataType.String) 445 446 # Handle string annotations 447 if isinstance(python_type, str): 448 s = python_type.lower() 449 if "datetime" in s: 450 return SearchFieldDataType.DateTimeOffset 451 if "int" in s: 452 return SearchFieldDataType.Int64 453 if "float" in s or "double" in s: 454 return 
SearchFieldDataType.Double 455 if "bool" in s: 456 return SearchFieldDataType.Boolean 457 return SearchFieldDataType.String 458 459 return _map.get(python_type, SearchFieldDataType.String)
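The field inference performed by `_build_fields` and `_map_python_type` can be approximated with the standard library alone. In this sketch the Azure field types are represented by their EDM names as plain strings, `BaseDoc` is a hypothetical stand-in for the package's `Document`, and `PolicyDoc` is an invented subclass. Only `Optional[X]` is unwrapped here, which is narrower than the generic-type handling in the source.

```python
import dataclasses
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Union, get_args, get_origin

BASE_FIELD_NAMES = {"id", "content", "embedding", "timestamp", "metadata"}
EDM_TYPES = {
    str: "Edm.String",
    int: "Edm.Int64",
    float: "Edm.Double",
    bool: "Edm.Boolean",
    datetime: "Edm.DateTimeOffset",
}
SCALAR_TYPES = set(EDM_TYPES.values()) | {"Edm.Int32"}


def edm_type_for(annotation) -> str:
    """Map a field annotation to an EDM type name, unwrapping Optional[X]."""
    if get_origin(annotation) is Union:
        non_none = [a for a in get_args(annotation) if a is not type(None)]
        if non_none:
            annotation = non_none[0]
    return EDM_TYPES.get(annotation, "Edm.String")  # unknown types fall back to string


@dataclass
class BaseDoc:  # hypothetical stand-in for the package's Document
    id: str
    content: str
    embedding: Optional[List[float]] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, str] = field(default_factory=dict)


@dataclass
class PolicyDoc(BaseDoc):  # invented subclass for this example
    language: str = ""
    page_count: Optional[int] = None
    # Per-field override carried in dataclass field metadata.
    title: str = field(default="", metadata={"searchable": True})


def custom_fields(document_type) -> List[dict]:
    """Describe every non-base dataclass field as an index field spec."""
    specs = []
    for f in dataclasses.fields(document_type):
        if f.name in BASE_FIELD_NAMES:
            continue
        edm = edm_type_for(f.type)
        specs.append({
            "name": f.name,
            "type": edm,
            "searchable": f.metadata.get("searchable", False),
            "filterable": f.metadata.get("filterable", True),
            "sortable": f.metadata.get("sortable", edm in SCALAR_TYPES),
            "facetable": f.metadata.get(
                "facetable", edm in {"Edm.String", "Edm.Boolean"}
            ),
        })
    return specs


for spec in custom_fields(PolicyDoc):
    print(spec)
```

Putting the overrides in `field(metadata=...)` keeps the index behaviour next to the field definition, and document types that declare no metadata get the defaults.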
Builds and manages Azure AI Search indexes with full developer control.
The builder owns schema concerns only. Document operations (add,
search, delete) belong to AzureAISearchVectorStore.
Parameters
endpoint:
Azure AI Search service endpoint URL.
api_key:
Azure AI Search admin API key. Use for local development or
when managed identity is not available.
token_provider:
Zero-argument callable that returns a bearer token string.
Use for managed identity / workload identity scenarios.
The callable must request the Azure AI Search scope::

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://search.azure.com/.default"
    )

Note: this scope is different from Azure OpenAI / Cognitive Services
(https://cognitiveservices.azure.com/.default); each service
requires its own token_provider.
index_name:
Name of the index to create / manage.
embedding_dimension:
Number of dimensions in the embedding vectors (must match the
embedding model — e.g. 1536 for text-embedding-ada-002, 3072 for
text-embedding-3-large).
document_type:
Optional Document subclass. When provided, all dataclass fields
not in the base Document are automatically added as indexed fields
(filterable, sortable, facetable where appropriate).
hnsw_m:
Number of bi-directional links created per node. Higher = better
recall but more memory. Azure AI Search accepts values from 4 to 10. Default: 4.
hnsw_ef_construction:
Size of the candidate list during index construction. Higher =
better recall, slower build time. Typical range 100–800.
Default: 400.
hnsw_ef_search:
Size of the candidate list during search. Higher = better recall,
slower queries. Typical range 100–1000. Default: 500.
metric:
Similarity metric. One of "cosine", "euclidean",
"dotProduct". Default: "cosine".
ssl_cert_path:
Optional path to a PEM certificate bundle for corporate SSL
inspection proxies. Sets REQUESTS_CA_BUNDLE and
SSL_CERT_FILE environment variables before building the client.
semantic_config:
Optional semantic search configuration. When provided the index is
provisioned with a SemanticSearch configuration that enables
Azure AI semantic reranking (BoostedRerankerScore).
Expected keys:
- ``name`` (str) — semantic config name (default ``"default-semantic-config"``)
- ``title_field`` (str, optional) — field used as the document title
- ``content_fields`` (list[str]) — primary body content fields
- ``keyword_fields`` (list[str], optional) — keyword/facet fields
Example::
    {
        "name": "policyhub-semantic-config",
        "title_field": "document_name",
        "content_fields": ["content"],
        "keyword_fields": ["language", "locale", "source"],
    }
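The tuning and semantic parameters above end up as two small structures: an HNSW parameter dict (keys `m`, `efConstruction`, `efSearch`, `metric`) and a semantic configuration with defaults filled in. The sketch below builds both with plain Python. The validation steps are additions for illustration; the constructor shown in this section stores the values without checking them, and `hnsw_parameters` and `normalise_semantic_config` are hypothetical helper names.

```python
from typing import Dict, Set, Union

ALLOWED_METRICS = {"cosine", "euclidean", "dotProduct"}


def hnsw_parameters(
    m: int = 4,
    ef_construction: int = 400,
    ef_search: int = 500,
    metric: str = "cosine",
) -> Dict[str, Union[int, str]]:
    """Assemble the HNSW parameter dict, rejecting obviously invalid input."""
    if metric not in ALLOWED_METRICS:
        raise ValueError(f"Unsupported metric: {metric!r}")
    if min(m, ef_construction, ef_search) < 1:
        raise ValueError("HNSW parameters must be positive integers")
    return {
        "m": m,
        "efConstruction": ef_construction,
        "efSearch": ef_search,
        "metric": metric,
    }


def normalise_semantic_config(config: dict, index_fields: Set[str]) -> dict:
    """Fill in defaults and check that every referenced field is in the index."""
    normalised = {
        "name": config.get("name", "default-semantic-config"),
        "title_field": config.get("title_field") or None,
        "content_fields": list(config.get("content_fields", [])),
        "keyword_fields": list(config.get("keyword_fields", [])),
    }
    referenced = normalised["content_fields"] + normalised["keyword_fields"]
    if normalised["title_field"]:
        referenced.append(normalised["title_field"])
    missing = sorted(set(referenced) - index_fields)
    if missing:
        raise ValueError(f"Semantic config references unknown fields: {missing}")
    return normalised


params = hnsw_parameters(m=8, ef_search=800)
semantic = normalise_semantic_config(
    {"title_field": "document_name", "content_fields": ["content"]},
    index_fields={"id", "content", "document_name", "language"},
)
print(params)
print(semantic)
```

Checking field names before provisioning reports a typo in `semantic_config` locally, with the offending names listed.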
    def __init__(
        self,
        endpoint: str,
        index_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        embedding_dimension: int = 1536,
        document_type: Type[Document] = Document,
        hnsw_m: int = 4,
        hnsw_ef_construction: int = 400,
        hnsw_ef_search: int = 500,
        metric: str = "cosine",
        ssl_cert_path: Optional[str] = None,
        semantic_config: Optional[dict] = None,
    ) -> None:
        self.index_name = index_name
        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.hnsw_m = hnsw_m
        self.hnsw_ef_construction = hnsw_ef_construction
        self.hnsw_ef_search = hnsw_ef_search
        self.metric = metric
        self.semantic_config = semantic_config

        if not api_key and not token_provider:
            raise ValueError(
                "Either api_key or token_provider must be supplied to AzureAISearchIndexBuilder."
            )

        if ssl_cert_path:
            import os as _os
            _os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path)
            _os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path)

        if token_provider:
            credential = _TokenProviderCredential(token_provider)
        else:
            credential = AzureKeyCredential(api_key)
        self._index_client = SearchIndexClient(
            endpoint=endpoint,
            credential=credential,
        )
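Two details of this constructor are easy to miss: `os.environ.setdefault` leaves an already-set CA bundle variable unchanged, and `token_provider` takes precedence when both credentials are supplied. The sketch below isolates both behaviours without any Azure dependency; `apply_ca_bundle` and `select_credential` are hypothetical helpers, and the returned tuples stand in for the real credential objects.

```python
import os
from typing import Callable, Optional, Tuple, Union


def apply_ca_bundle(ssl_cert_path: Optional[str]) -> None:
    """Point HTTP clients at a CA bundle unless the variables are already set."""
    if not ssl_cert_path:
        return
    os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path)
    os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path)


def select_credential(
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
) -> Tuple[str, Union[str, Callable[[], str]]]:
    """Return which credential would be used; the token provider wins over the key."""
    if not api_key and not token_provider:
        raise ValueError("Either api_key or token_provider must be supplied.")
    if token_provider:
        return ("token_provider", token_provider)
    return ("api_key", api_key)


# A pre-existing value is preserved; only the unset variable is filled in.
os.environ["REQUESTS_CA_BUNDLE"] = "/etc/ssl/existing.pem"
os.environ.pop("SSL_CERT_FILE", None)
apply_ca_bundle("/opt/certs/corp.pem")
print(os.environ["REQUESTS_CA_BUNDLE"], os.environ["SSL_CERT_FILE"])

kind, _ = select_credential(api_key="admin-key", token_provider=lambda: "token")
print(kind)
```

Because the variables are process-wide, a bundle configured here also applies to other clients in the same process that read them.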
    def create_index(self) -> None:
        """Create the index if it does not already exist (idempotent)."""
        if self.index_exists():
            logger.info("Index already exists — skipping creation", index=self.index_name)
            return
        self._create(self.index_name)
        logger.info("Index created successfully", index=self.index_name)
Create the index if it does not already exist (idempotent).
    def create_or_replace_index(self) -> None:
        """Delete the existing index (if any) then create it fresh.

        Warning: All documents are permanently lost.
        """
        if self.index_exists():
            self._index_client.delete_index(self.index_name)
            logger.info("Index deleted for replacement", index=self.index_name)
        self._create(self.index_name)
        logger.info("Index created (replaced)", index=self.index_name)
Delete the existing index (if any) then create it fresh.
Warning: All documents are permanently lost.
    def delete_index(self) -> None:
        """Permanently delete the index and all its documents.

        Raises:
            RuntimeError: If the index does not exist.
        """
        if not self.index_exists():
            raise RuntimeError(
                f"Cannot delete index '{self.index_name}': it does not exist."
            )
        self._index_client.delete_index(self.index_name)
        logger.info("Index deleted", index=self.index_name)
Permanently delete the index and all its documents.
Raises: RuntimeError: If the index does not exist.
    def index_exists(self) -> bool:
        """Return True if the index currently exists."""
        try:
            self._index_client.get_index(self.index_name)
            return True
        except ResourceNotFoundError:
            return False
        except Exception:
            # Treat any other error as non-existence to keep callers safe
            return False
Return True if the index currently exists.
67class CosmosDBIndexBuilder(BaseIndexBuilder): 68 """ 69 Builds and manages Cosmos DB databases and containers for vector search. 70 71 The builder owns *schema / provisioning* concerns only. Document 72 operations (add, search, delete) belong to ``AzureCosmosDBVectorStore``. 73 74 Parameters 75 ---------- 76 endpoint: 77 Cosmos DB account endpoint URL. 78 api_key: 79 Cosmos DB account primary or secondary key. 80 database_name: 81 Name of the Cosmos DB database to create / manage. 82 container_name: 83 Name of the container to create / manage. 84 embedding_dimension: 85 Number of dimensions in the embedding vectors. 86 distance_function: 87 Vector similarity function. ``"cosine"`` (default), ``"euclidean"``, 88 or ``"dotproduct"``. 89 vector_index_type: 90 Index structure for vector search. ``"quantizedFlat"`` (default, 91 lower memory) or ``"diskANN"`` (higher recall on large datasets). 92 partition_key: 93 Cosmos DB partition key path. Default: ``"/id"``. 94 throughput: 95 Manual RU/s throughput for the container. ``None`` uses the Cosmos 96 DB account default. Ignored if the container already exists. 97 ssl_cert_path: 98 Optional path to a PEM certificate bundle for corporate SSL 99 inspection proxies. 
100 """ 101 102 def __init__( 103 self, 104 endpoint: str, 105 api_key: str, 106 database_name: str, 107 container_name: str, 108 embedding_dimension: int = 1536, 109 distance_function: DistanceFunction = "cosine", 110 vector_index_type: VectorIndexType = "quantizedFlat", 111 partition_key: str = "/id", 112 throughput: Optional[int] = None, 113 ssl_cert_path: Optional[str] = None, 114 ) -> None: 115 self.database_name = database_name 116 self.container_name = container_name 117 self.embedding_dimension = embedding_dimension 118 self.distance_function = distance_function 119 self.vector_index_type = vector_index_type 120 self.partition_key = partition_key 121 self.throughput = throughput 122 self._ssl_cert_path = ssl_cert_path 123 self._endpoint = endpoint 124 self._api_key = api_key 125 126 self._client = self._build_client(endpoint, api_key, ssl_cert_path) 127 128 # ------------------------------------------------------------------ # 129 # BaseIndexBuilder interface # 130 # ------------------------------------------------------------------ # 131 132 def create_index(self) -> None: 133 """Create the Cosmos DB database and container if they don't exist. 134 135 Safe to call multiple times — no-op if both already exist. 136 """ 137 self._ensure_database() 138 if self.index_exists(): 139 logger.info( 140 "Cosmos DB container already exists — skipping creation", 141 database=self.database_name, 142 container=self.container_name, 143 ) 144 return 145 self._create_container() 146 logger.info( 147 "Cosmos DB container created", 148 database=self.database_name, 149 container=self.container_name, 150 dim=self.embedding_dimension, 151 distance_function=self.distance_function, 152 vector_index_type=self.vector_index_type, 153 ) 154 155 def create_or_replace_index(self) -> None: 156 """Delete the container if it exists then create it fresh. 157 158 Warning: All documents are permanently lost. 
159 """ 160 self._ensure_database() 161 if self.index_exists(): 162 db = self._client.get_database_client(self.database_name) 163 db.delete_container(self.container_name) 164 logger.info( 165 "Cosmos DB container deleted for replacement", 166 database=self.database_name, 167 container=self.container_name, 168 ) 169 self._create_container() 170 logger.info( 171 "Cosmos DB container created (replaced)", 172 database=self.database_name, 173 container=self.container_name, 174 ) 175 176 def delete_index(self) -> None: 177 """Permanently delete the container and all its documents. 178 179 Raises: 180 RuntimeError: If the container does not exist. 181 """ 182 if not self.index_exists(): 183 raise RuntimeError( 184 f"Cannot delete container '{self.database_name}/{self.container_name}': " 185 "it does not exist." 186 ) 187 db = self._client.get_database_client(self.database_name) 188 db.delete_container(self.container_name) 189 logger.info( 190 "Cosmos DB container deleted", 191 database=self.database_name, 192 container=self.container_name, 193 ) 194 195 def index_exists(self) -> bool: 196 """Return True if the container currently exists.""" 197 try: 198 db = self._client.get_database_client(self.database_name) 199 db.get_container_client(self.container_name).read() 200 return True 201 except Exception: 202 return False 203 204 def list_indexes(self) -> List[str]: 205 """Return the names of all containers in the database.""" 206 try: 207 db = self._client.get_database_client(self.database_name) 208 return [c["id"] for c in db.list_containers()] 209 except Exception: 210 return [] 211 212 # ------------------------------------------------------------------ # 213 # Internal helpers # 214 # ------------------------------------------------------------------ # 215 216 @staticmethod 217 def _build_client(endpoint: str, api_key: str, ssl_cert_path: Optional[str]): 218 """Build a CosmosClient, optionally with a corporate SSL bundle.""" 219 from azure.cosmos import CosmosClient 220 
kwargs = {"url": endpoint, "credential": api_key} 221 if ssl_cert_path: 222 import ssl, os 223 os.environ.setdefault("REQUESTS_CA_BUNDLE", ssl_cert_path) 224 os.environ.setdefault("SSL_CERT_FILE", ssl_cert_path) 225 return CosmosClient(**kwargs) 226 227 def _ensure_database(self) -> None: 228 """Create the database if it does not already exist.""" 229 self._client.create_database_if_not_exists(self.database_name) 230 231 def _create_container(self) -> None: 232 """Create the container with vector embedding and indexing policies.""" 233 from azure.cosmos import PartitionKey 234 from azure.cosmos.exceptions import CosmosHttpResponseError 235 236 db = self._client.get_database_client(self.database_name) 237 238 vector_embedding_policy = { 239 "vectorEmbeddings": [ 240 { 241 "path": "/embedding", 242 "dataType": "float32", 243 "distanceFunction": self.distance_function, 244 "dimensions": self.embedding_dimension, 245 } 246 ] 247 } 248 249 indexing_policy = { 250 "includedPaths": [{"path": "/*"}], 251 "excludedPaths": [{"path": "/embedding/*"}], 252 "vectorIndexes": [ 253 {"path": "/embedding", "type": self.vector_index_type} 254 ], 255 } 256 257 kwargs = dict( 258 id=self.container_name, 259 partition_key=PartitionKey(path=self.partition_key), 260 vector_embedding_policy=vector_embedding_policy, 261 indexing_policy=indexing_policy, 262 ) 263 if self.throughput is not None: 264 kwargs["offer_throughput"] = self.throughput 265 266 try: 267 db.create_container(**kwargs) 268 except CosmosHttpResponseError as exc: 269 if "Vector Policy" in str(exc) or "capability" in str(exc): 270 raise RuntimeError( 271 "Vector Search capability is not enabled on this Cosmos DB account. " 272 "Enable via: az cosmosdb update --resource-group <RG> " 273 "--name <ACCOUNT> --capabilities EnableNoSQLVectorSearch" 274 ) from exc 275 raise
Builds and manages Cosmos DB databases and containers for vector search.
The builder owns schema / provisioning concerns only. Document
operations (add, search, delete) belong to AzureCosmosDBVectorStore.
Parameters
endpoint:
Cosmos DB account endpoint URL.
api_key:
Cosmos DB account primary or secondary key.
database_name:
Name of the Cosmos DB database to create / manage.
container_name:
Name of the container to create / manage.
embedding_dimension:
Number of dimensions in the embedding vectors.
distance_function:
Vector similarity function. "cosine" (default), "euclidean",
or "dotproduct".
vector_index_type:
Index structure for vector search. "quantizedFlat" (default,
lower memory) or "diskANN" (higher recall on large datasets).
partition_key:
Cosmos DB partition key path. Default: "/id".
throughput:
Manual RU/s throughput for the container. None uses the Cosmos
DB account default. Ignored if the container already exists.
ssl_cert_path:
Optional path to a PEM certificate bundle for corporate SSL
inspection proxies.
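As a rough illustration of how these parameters end up in the container definition, the sketch below builds the two policy dictionaries in the same shape the builder's internal `_create_container` helper uses. `build_vector_policies` and its `embedding_path` argument are hypothetical names introduced only for this example (the builder itself hard-codes the `/embedding` path), and nothing here talks to Cosmos DB.

```python
from typing import Any, Dict, Tuple


def build_vector_policies(
    embedding_dimension: int = 1536,
    distance_function: str = "cosine",
    vector_index_type: str = "quantizedFlat",
    embedding_path: str = "/embedding",  # assumption: fixed to "/embedding" in the builder
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (vector_embedding_policy, indexing_policy) as plain dicts."""
    vector_embedding_policy = {
        "vectorEmbeddings": [
            {
                "path": embedding_path,
                "dataType": "float32",
                "distanceFunction": distance_function,
                "dimensions": embedding_dimension,
            }
        ]
    }
    indexing_policy = {
        "includedPaths": [{"path": "/*"}],
        # Keep the raw vector out of the regular index; the vector index covers it.
        "excludedPaths": [{"path": f"{embedding_path}/*"}],
        "vectorIndexes": [{"path": embedding_path, "type": vector_index_type}],
    }
    return vector_embedding_policy, indexing_policy


embedding_policy, indexing_policy = build_vector_policies(
    embedding_dimension=3072,
    distance_function="dotproduct",
    vector_index_type="diskANN",
)
```

The dimension in the embedding policy has to match the vectors you later write, so it is worth deriving `embedding_dimension` from the embedding model configuration rather than typing it twice.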
    def __init__(
        self,
        endpoint: str,
        api_key: str,
        database_name: str,
        container_name: str,
        embedding_dimension: int = 1536,
        distance_function: DistanceFunction = "cosine",
        vector_index_type: VectorIndexType = "quantizedFlat",
        partition_key: str = "/id",
        throughput: Optional[int] = None,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        self.database_name = database_name
        self.container_name = container_name
        self.embedding_dimension = embedding_dimension
        self.distance_function = distance_function
        self.vector_index_type = vector_index_type
        self.partition_key = partition_key
        self.throughput = throughput
        self._ssl_cert_path = ssl_cert_path
        self._endpoint = endpoint
        self._api_key = api_key

        self._client = self._build_client(endpoint, api_key, ssl_cert_path)
132 def create_index(self) -> None: 133 """Create the Cosmos DB database and container if they don't exist. 134 135 Safe to call multiple times — no-op if both already exist. 136 """ 137 self._ensure_database() 138 if self.index_exists(): 139 logger.info( 140 "Cosmos DB container already exists — skipping creation", 141 database=self.database_name, 142 container=self.container_name, 143 ) 144 return 145 self._create_container() 146 logger.info( 147 "Cosmos DB container created", 148 database=self.database_name, 149 container=self.container_name, 150 dim=self.embedding_dimension, 151 distance_function=self.distance_function, 152 vector_index_type=self.vector_index_type, 153 )
Create the Cosmos DB database and container if they don't exist.
Safe to call multiple times — no-op if both already exist.
    def create_or_replace_index(self) -> None:
        """Delete the container if it exists then create it fresh.

        Warning: All documents are permanently lost.
        """
        self._ensure_database()
        if self.index_exists():
            db = self._client.get_database_client(self.database_name)
            db.delete_container(self.container_name)
            logger.info(
                "Cosmos DB container deleted for replacement",
                database=self.database_name,
                container=self.container_name,
            )
        self._create_container()
        logger.info(
            "Cosmos DB container created (replaced)",
            database=self.database_name,
            container=self.container_name,
        )
Delete the container if it exists then create it fresh.
Warning: All documents are permanently lost.
    def delete_index(self) -> None:
        """Permanently delete the container and all its documents.

        Raises:
            RuntimeError: If the container does not exist.
        """
        if not self.index_exists():
            raise RuntimeError(
                f"Cannot delete container '{self.database_name}/{self.container_name}': "
                "it does not exist."
            )
        db = self._client.get_database_client(self.database_name)
        db.delete_container(self.container_name)
        logger.info(
            "Cosmos DB container deleted",
            database=self.database_name,
            container=self.container_name,
        )
Permanently delete the container and all its documents.
Raises: RuntimeError: If the container does not exist.
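The lifecycle contract described above (idempotent create, destructive replace, strict delete) can be sketched with an in-memory stand-in. `InMemoryIndexBuilder` is a hypothetical class written for this example; it does not inherit from `BaseIndexBuilder` and never contacts Cosmos DB, so treat it as a model of the documented behaviour rather than the real implementation.

```python
from typing import List, Set


class InMemoryIndexBuilder:
    """Hypothetical stand-in that mimics the documented lifecycle semantics."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._containers: Set[str] = set()
        self.created_count = 0  # how many times a container was actually created

    def index_exists(self) -> bool:
        return self.name in self._containers

    def create_index(self) -> None:
        # No-op when the container already exists.
        if self.index_exists():
            return
        self._containers.add(self.name)
        self.created_count += 1

    def create_or_replace_index(self) -> None:
        # Drop whatever is there, then create from scratch.
        self._containers.discard(self.name)
        self._containers.add(self.name)
        self.created_count += 1

    def delete_index(self) -> None:
        # Deleting a missing container is an error, not a silent no-op.
        if not self.index_exists():
            raise RuntimeError(f"Cannot delete '{self.name}': it does not exist.")
        self._containers.remove(self.name)

    def list_indexes(self) -> List[str]:
        return sorted(self._containers)


builder = InMemoryIndexBuilder("docs")
builder.create_index()
builder.create_index()  # second call changes nothing
```

Because `delete_index` raises on a missing container, teardown code that may run twice would typically guard the call with `index_exists()`.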
    def index_exists(self) -> bool:
        """Return True if the container currently exists."""
        try:
            db = self._client.get_database_client(self.database_name)
            db.get_container_client(self.container_name).read()
            return True
        except Exception:
            return False
Return True if the container currently exists.
    def list_indexes(self) -> List[str]:
        """Return the names of all containers in the database."""
        try:
            db = self._client.get_database_client(self.database_name)
            return [c["id"] for c in db.list_containers()]
        except Exception:
            return []
Return the names of all containers in the database.
63class MongoDBIndexBuilder(BaseIndexBuilder): 64 """ 65 Builds and manages Atlas Vector Search and text indexes for a MongoDB 66 collection. 67 68 The builder owns *schema / provisioning* concerns only. Document 69 operations belong to ``MongoDBVectorStore``. 70 71 Parameters 72 ---------- 73 connection_string: 74 MongoDB Atlas connection string, e.g. 75 ``"mongodb+srv://user:pass@cluster.mongodb.net/"``. 76 database_name: 77 Name of the MongoDB database. 78 collection_name: 79 Name of the collection to index. 80 embedding_dimension: 81 Number of dimensions in the embedding vectors. 82 document_type: 83 Document dataclass whose *extra* fields will be added as Atlas 84 filter fields (fields other than ``id``, ``content``, ``embedding``, 85 ``timestamp``, and ``metadata``). 86 vector_index_name: 87 Name of the Atlas Vector Search index. Must match the 88 ``vector_index_name`` used when constructing ``MongoDBVectorStore``. 89 Default: ``"vector_index"``. 90 similarity: 91 Similarity metric for the vector index. ``"cosine"`` (default), 92 ``"euclidean"``, or ``"dotProduct"``. 93 extra_filter_paths: 94 Additional document paths to register as Atlas filter fields beyond 95 those inferred from *document_type*. Useful for arbitrary metadata 96 fields stored in the ``metadata`` sub-document. 97 text_index_fields: 98 Fields to include in the MongoDB ``$text`` full-text index. 99 Default: ``["content"]``. 100 ssl_cert_path: 101 Path to a CA certificate bundle (PEM) for TLS verification in 102 corporate environments with custom certificate authorities. 
103 """ 104 105 _BASE_KEYS = frozenset({"id", "content", "embedding", "timestamp", "metadata"}) 106 107 def __init__( 108 self, 109 connection_string: str, 110 database_name: str, 111 collection_name: str, 112 embedding_dimension: int = 1536, 113 document_type: Type[Document] = Document, 114 vector_index_name: str = "vector_index", 115 similarity: Similarity = "cosine", 116 extra_filter_paths: Optional[List[str]] = None, 117 text_index_fields: Optional[List[str]] = None, 118 ssl_cert_path: Optional[str] = None, 119 ) -> None: 120 try: 121 import pymongo # noqa: F401 122 except ImportError as exc: 123 raise ImportError( 124 "pymongo is required for MongoDBIndexBuilder. " 125 "Install it with: pip install pymongo" 126 ) from exc 127 128 import pymongo 129 130 self.database_name = database_name 131 self.collection_name = collection_name 132 self.embedding_dimension = embedding_dimension 133 self.document_type = document_type 134 self.vector_index_name = vector_index_name 135 self.similarity = similarity 136 self.extra_filter_paths: List[str] = extra_filter_paths or [] 137 self.text_index_fields: List[str] = text_index_fields or ["content"] 138 139 client_kwargs: Dict[str, Any] = {} 140 if ssl_cert_path: 141 client_kwargs["tlsCAFile"] = ssl_cert_path 142 143 self._client = pymongo.MongoClient(connection_string, **client_kwargs) 144 self._db = self._client[database_name] 145 self._collection = self._db[collection_name] 146 147 # ------------------------------------------------------------------ # 148 # BaseIndexBuilder interface # 149 # ------------------------------------------------------------------ # 150 151 def create_index(self) -> None: 152 """Create the Atlas Vector Search index and the text index if they 153 don't already exist. 154 155 Safe to call multiple times — each component is idempotent. 
156 """ 157 self._create_vector_index(replace=False) 158 self._create_text_index() 159 160 def create_or_replace_index(self) -> None: 161 """Drop the Atlas Vector Search index if it exists, then create it 162 fresh alongside the text index. 163 164 The text index is not recreated if it already exists (text indexes 165 are schema-agnostic and need no replacement). 166 167 Warning: Existing vector index data is lost. 168 """ 169 import time 170 171 if self.index_exists(): 172 self._collection.drop_search_index(self.vector_index_name) 173 logger.info( 174 "Atlas Vector Search index dropped for replacement", 175 index=self.vector_index_name, 176 ) 177 # Atlas drops are asynchronous — poll until the index is gone 178 # before submitting the creation request with the same name. 179 deadline = time.monotonic() + 60 180 while time.monotonic() < deadline: 181 if self.vector_index_name not in [ 182 idx["name"] for idx in self._collection.list_search_indexes() 183 ]: 184 break 185 time.sleep(2) 186 else: 187 raise RuntimeError( 188 f"Timed out waiting for Atlas to finish dropping index " 189 f"'{self.vector_index_name}'. Try again in a moment." 190 ) 191 self._create_vector_index(replace=True) 192 self._create_text_index() 193 194 def delete_index(self) -> None: 195 """Drop the Atlas Vector Search index. 196 197 The MongoDB text index (``content_text``) is left in place because 198 it is independent of vector dimensionality. 199 200 Raises: 201 RuntimeError: If the vector index does not exist. 202 """ 203 if not self.index_exists(): 204 raise RuntimeError( 205 f"Cannot delete vector index '{self.vector_index_name}' on " 206 f"'{self.database_name}.{self.collection_name}': it does not exist." 
207 ) 208 self._collection.drop_search_index(self.vector_index_name) 209 logger.info( 210 "Atlas Vector Search index deleted", 211 index=self.vector_index_name, 212 database=self.database_name, 213 collection=self.collection_name, 214 ) 215 216 def index_exists(self) -> bool: 217 """Return True if the Atlas Vector Search index currently exists.""" 218 existing = [idx["name"] for idx in self._collection.list_search_indexes()] 219 return self.vector_index_name in existing 220 221 def list_indexes(self) -> List[str]: 222 """Return the names of all Atlas Vector Search indexes on the collection.""" 223 return [idx["name"] for idx in self._collection.list_search_indexes()] 224 225 # ------------------------------------------------------------------ # 226 # Additional helpers # 227 # ------------------------------------------------------------------ # 228 229 def list_text_indexes(self) -> List[str]: 230 """Return the names of all standard MongoDB indexes on the collection.""" 231 return [idx["name"] for idx in self._collection.list_indexes()] 232 233 # ------------------------------------------------------------------ # 234 # Internal helpers # 235 # ------------------------------------------------------------------ # 236 237 def _build_filter_fields(self) -> List[Dict[str, str]]: 238 """Build the list of Atlas filter field definitions. 239 240 Includes custom fields inferred from *document_type* plus any 241 *extra_filter_paths* provided at construction. 
242 """ 243 paths = set(self.extra_filter_paths) 244 245 if dataclasses.is_dataclass(self.document_type): 246 for field in dataclasses.fields(self.document_type): 247 if field.name not in self._BASE_KEYS: 248 paths.add(field.name) 249 250 return [{"type": "filter", "path": p} for p in sorted(paths)] 251 252 def _create_vector_index(self, replace: bool = False) -> None: 253 """Submit the Atlas Vector Search index creation request.""" 254 if not replace and self.index_exists(): 255 logger.info( 256 "Atlas Vector Search index already exists — skipping", 257 index=self.vector_index_name, 258 database=self.database_name, 259 collection=self.collection_name, 260 ) 261 return 262 263 filter_fields = self._build_filter_fields() 264 265 index_spec: Dict[str, Any] = { 266 "name": self.vector_index_name, 267 "type": "vectorSearch", 268 "definition": { 269 "fields": [ 270 { 271 "type": "vector", 272 "path": "embedding", 273 "numDimensions": self.embedding_dimension, 274 "similarity": self.similarity, 275 }, 276 *filter_fields, 277 ] 278 }, 279 } 280 281 self._collection.create_search_index(index_spec) 282 logger.info( 283 "Atlas Vector Search index created", 284 index=self.vector_index_name, 285 database=self.database_name, 286 collection=self.collection_name, 287 dim=self.embedding_dimension, 288 similarity=self.similarity, 289 filter_fields=len(filter_fields), 290 ) 291 292 def _create_text_index(self) -> None: 293 """Create a MongoDB ``$text`` index if it does not already exist.""" 294 existing = [idx["name"] for idx in self._collection.list_indexes()] 295 if "content_text" in existing: 296 logger.info( 297 "Text index already exists — skipping", 298 index="content_text", 299 database=self.database_name, 300 collection=self.collection_name, 301 ) 302 return 303 304 keys = [(field, "text") for field in self.text_index_fields] 305 self._collection.create_index(keys, name="content_text") 306 logger.info( 307 "Text index created", 308 index="content_text", 309 
database=self.database_name, 310 collection=self.collection_name, 311 fields=self.text_index_fields, 312 )
Builds and manages Atlas Vector Search and text indexes for a MongoDB collection.
The builder owns schema / provisioning concerns only. Document
operations belong to MongoDBVectorStore.
Parameters
connection_string:
MongoDB Atlas connection string, e.g.
"mongodb+srv://user:pass@cluster.mongodb.net/".
database_name:
Name of the MongoDB database.
collection_name:
Name of the collection to index.
embedding_dimension:
Number of dimensions in the embedding vectors.
document_type:
Document dataclass whose extra fields will be added as Atlas
filter fields (fields other than id, content, embedding,
timestamp, and metadata).
vector_index_name:
Name of the Atlas Vector Search index. Must match the
vector_index_name used when constructing MongoDBVectorStore.
Default: "vector_index".
similarity:
Similarity metric for the vector index. "cosine" (default),
"euclidean", or "dotProduct".
extra_filter_paths:
Additional document paths to register as Atlas filter fields beyond
those inferred from document_type. Useful for arbitrary metadata
fields stored in the metadata sub-document.
text_index_fields:
Fields to include in the MongoDB $text full-text index.
Default: ["content"].
ssl_cert_path:
Path to a CA certificate bundle (PEM) for TLS verification in
corporate environments with custom certificate authorities.
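To make the interaction between `document_type` and `extra_filter_paths` concrete, here is a minimal sketch that derives the filter fields and assembles the vector index definition in the same shape the builder submits. `PolicyDocument`, `BASE_KEYS`, and `build_index_spec` are hypothetical names used only for this example; no MongoDB connection is involved.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Type

# Fields that every document carries; they are never registered as filter fields.
BASE_KEYS = frozenset({"id", "content", "embedding", "timestamp", "metadata"})


@dataclasses.dataclass
class PolicyDocument:
    """Hypothetical document type with two extra fields beyond the base keys."""

    id: str
    content: str
    embedding: List[float] = dataclasses.field(default_factory=list)
    department: str = ""
    region: str = ""


def build_index_spec(
    document_type: Type[Any],
    embedding_dimension: int = 1536,
    similarity: str = "cosine",
    vector_index_name: str = "vector_index",
    extra_filter_paths: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Assemble a vectorSearch index definition as a plain dict."""
    paths = set(extra_filter_paths or [])
    if dataclasses.is_dataclass(document_type):
        for field in dataclasses.fields(document_type):
            if field.name not in BASE_KEYS:
                paths.add(field.name)
    # Sorting keeps the definition stable across runs.
    filter_fields = [{"type": "filter", "path": p} for p in sorted(paths)]
    return {
        "name": vector_index_name,
        "type": "vectorSearch",
        "definition": {
            "fields": [
                {
                    "type": "vector",
                    "path": "embedding",
                    "numDimensions": embedding_dimension,
                    "similarity": similarity,
                },
                *filter_fields,
            ]
        },
    }


spec = build_index_spec(PolicyDocument, extra_filter_paths=["metadata.source"])
filter_paths = [f["path"] for f in spec["definition"]["fields"][1:]]
```

Here `department` and `region` are inferred from the dataclass, while `metadata.source` has to be listed explicitly because it lives inside the `metadata` sub-document.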
    def __init__(
        self,
        connection_string: str,
        database_name: str,
        collection_name: str,
        embedding_dimension: int = 1536,
        document_type: Type[Document] = Document,
        vector_index_name: str = "vector_index",
        similarity: Similarity = "cosine",
        extra_filter_paths: Optional[List[str]] = None,
        text_index_fields: Optional[List[str]] = None,
        ssl_cert_path: Optional[str] = None,
    ) -> None:
        try:
            import pymongo  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "pymongo is required for MongoDBIndexBuilder. "
                "Install it with: pip install pymongo"
            ) from exc

        import pymongo

        self.database_name = database_name
        self.collection_name = collection_name
        self.embedding_dimension = embedding_dimension
        self.document_type = document_type
        self.vector_index_name = vector_index_name
        self.similarity = similarity
        self.extra_filter_paths: List[str] = extra_filter_paths or []
        self.text_index_fields: List[str] = text_index_fields or ["content"]

        client_kwargs: Dict[str, Any] = {}
        if ssl_cert_path:
            client_kwargs["tlsCAFile"] = ssl_cert_path

        self._client = pymongo.MongoClient(connection_string, **client_kwargs)
        self._db = self._client[database_name]
        self._collection = self._db[collection_name]
151 def create_index(self) -> None: 152 """Create the Atlas Vector Search index and the text index if they 153 don't already exist. 154 155 Safe to call multiple times — each component is idempotent. 156 """ 157 self._create_vector_index(replace=False) 158 self._create_text_index()
Create the Atlas Vector Search index and the text index if they don't already exist.
Safe to call multiple times — each component is idempotent.
160 def create_or_replace_index(self) -> None: 161 """Drop the Atlas Vector Search index if it exists, then create it 162 fresh alongside the text index. 163 164 The text index is not recreated if it already exists (text indexes 165 are schema-agnostic and need no replacement). 166 167 Warning: Existing vector index data is lost. 168 """ 169 import time 170 171 if self.index_exists(): 172 self._collection.drop_search_index(self.vector_index_name) 173 logger.info( 174 "Atlas Vector Search index dropped for replacement", 175 index=self.vector_index_name, 176 ) 177 # Atlas drops are asynchronous — poll until the index is gone 178 # before submitting the creation request with the same name. 179 deadline = time.monotonic() + 60 180 while time.monotonic() < deadline: 181 if self.vector_index_name not in [ 182 idx["name"] for idx in self._collection.list_search_indexes() 183 ]: 184 break 185 time.sleep(2) 186 else: 187 raise RuntimeError( 188 f"Timed out waiting for Atlas to finish dropping index " 189 f"'{self.vector_index_name}'. Try again in a moment." 190 ) 191 self._create_vector_index(replace=True) 192 self._create_text_index()
Drop the Atlas Vector Search index if it exists, then create it fresh alongside the text index.
The text index is not recreated if it already exists (text indexes are schema-agnostic and need no replacement).
Warning: Existing vector index data is lost.
    def delete_index(self) -> None:
        """Drop the Atlas Vector Search index.

        The MongoDB text index (``content_text``) is left in place because
        it is independent of vector dimensionality.

        Raises:
            RuntimeError: If the vector index does not exist.
        """
        if not self.index_exists():
            raise RuntimeError(
                f"Cannot delete vector index '{self.vector_index_name}' on "
                f"'{self.database_name}.{self.collection_name}': it does not exist."
            )
        self._collection.drop_search_index(self.vector_index_name)
        logger.info(
            "Atlas Vector Search index deleted",
            index=self.vector_index_name,
            database=self.database_name,
            collection=self.collection_name,
        )
Drop the Atlas Vector Search index.
The MongoDB text index (content_text) is left in place because
it is independent of vector dimensionality.
Raises: RuntimeError: If the vector index does not exist.
    def index_exists(self) -> bool:
        """Return True if the Atlas Vector Search index currently exists."""
        existing = [idx["name"] for idx in self._collection.list_search_indexes()]
        return self.vector_index_name in existing
Return True if the Atlas Vector Search index currently exists.
    def list_indexes(self) -> List[str]:
        """Return the names of all Atlas Vector Search indexes on the collection."""
        return [idx["name"] for idx in self._collection.list_search_indexes()]
Return the names of all Atlas Vector Search indexes on the collection.
    def list_text_indexes(self) -> List[str]:
        """Return the names of all standard MongoDB indexes on the collection."""
        return [idx["name"] for idx in self._collection.list_indexes()]
Return the names of all standard MongoDB indexes on the collection.
20class RelevanceFilter: 21 """ 22 Filters retrieved documents by minimum similarity score. 23 24 Vector stores return results ordered by score but never automatically 25 remove low-confidence matches. The filter enforces a quality floor so 26 that weakly related documents are excluded before they reach the LLM. 27 28 The threshold depends on the embedding model and index: 29 - text-embedding-ada-002 / cosine: 0.75–0.85 is typical 30 - Lower threshold (0.6): broad recall, more noise 31 - Higher threshold (0.9): tight precision, fewer results 32 33 Example: 34 ```python 35 from gmf_forge_ai_data.context import RelevanceFilter 36 37 filter_ = RelevanceFilter(min_score=0.80) 38 kept = filter_.filter(retrieved_results) 39 # Only results with score >= 0.80 are returned 40 ``` 41 """ 42 43 def __init__(self, min_score: float = 0.75): 44 """ 45 Args: 46 min_score: Minimum similarity score in [0, 1] to retain a result. 47 Results with score < min_score are discarded. 48 """ 49 if not 0.0 <= min_score <= 1.0: 50 raise ValueError(f"min_score must be in [0, 1], got {min_score}") 51 self.min_score = min_score 52 53 def filter(self, results: List[SearchResult]) -> List[SearchResult]: 54 """ 55 Remove results below the minimum score threshold. 56 57 Rank values are NOT reassigned — they preserve the original retrieval 58 rank for traceability. Use a downstream component or sort afterward 59 if contiguous ranks are required. 60 61 Args: 62 results: Retrieved SearchResult list, typically sorted by score. 63 64 Returns: 65 Subset of results where score >= min_score, in original order. 66 """ 67 return [r for r in results if r.score >= self.min_score]
Filters retrieved documents by minimum similarity score.
Vector stores return results ordered by score but never automatically remove low-confidence matches. The filter enforces a quality floor so that weakly related documents are excluded before they reach the LLM.
The threshold depends on the embedding model and index:
- text-embedding-ada-002 / cosine: 0.75–0.85 is typical
- Lower threshold (0.6): broad recall, more noise
- Higher threshold (0.9): tight precision, fewer results
Example:
from gmf_forge_ai_data.context import RelevanceFilter
filter_ = RelevanceFilter(min_score=0.80)
kept = filter_.filter(retrieved_results)
# Only results with score >= 0.80 are returned
    def __init__(self, min_score: float = 0.75):
        """
        Args:
            min_score: Minimum similarity score in [0, 1] to retain a result.
                       Results with score < min_score are discarded.
        """
        if not 0.0 <= min_score <= 1.0:
            raise ValueError(f"min_score must be in [0, 1], got {min_score}")
        self.min_score = min_score
Args: min_score: Minimum similarity score in [0, 1] to retain a result. Results with score < min_score are discarded.
53 def filter(self, results: List[SearchResult]) -> List[SearchResult]: 54 """ 55 Remove results below the minimum score threshold. 56 57 Rank values are NOT reassigned — they preserve the original retrieval 58 rank for traceability. Use a downstream component or sort afterward 59 if contiguous ranks are required. 60 61 Args: 62 results: Retrieved SearchResult list, typically sorted by score. 63 64 Returns: 65 Subset of results where score >= min_score, in original order. 66 """ 67 return [r for r in results if r.score >= self.min_score]
Remove results below the minimum score threshold.
Rank values are NOT reassigned — they preserve the original retrieval rank for traceability. Use a downstream component or sort afterward if contiguous ranks are required.
Args: results: Retrieved SearchResult list, typically sorted by score.
Returns: Subset of results where score >= min_score, in original order.
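The note about ranks is easiest to see on a small input. The sketch below applies the same `score >= min_score` rule to a few stub results; `StubResult` and `filter_by_score` are hypothetical stand-ins for `SearchResult` and `RelevanceFilter.filter`, carrying only the fields this example needs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubResult:
    """Hypothetical stand-in for SearchResult (only the fields used here)."""

    doc_id: str
    score: float
    rank: int


def filter_by_score(results: List[StubResult], min_score: float = 0.75) -> List[StubResult]:
    """Keep results whose score is at least min_score, in their original order."""
    if not 0.0 <= min_score <= 1.0:
        raise ValueError(f"min_score must be in [0, 1], got {min_score}")
    return [r for r in results if r.score >= min_score]


results = [
    StubResult("a", 0.91, rank=1),
    StubResult("b", 0.78, rank=2),
    StubResult("c", 0.80, rank=3),
    StubResult("d", 0.62, rank=4),
]

kept = filter_by_score(results, min_score=0.80)
kept_ranks = [r.rank for r in kept]  # ranks 1 and 3 survive, leaving a gap at 2
```

A result scoring exactly at the threshold is retained, and the surviving ranks are no longer contiguous, which is why a downstream step has to renumber them if it needs consecutive values.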
20class ContextDeduplicator: 21 """ 22 Removes near-duplicate passages from a retrieved result list. 23 24 Two documents are considered near-duplicates if their character trigram 25 Jaccard similarity exceeds `similarity_threshold`. The higher-ranked 26 (lower `rank` value / earlier in list) document is always kept. 27 28 Near-duplicates arise from: 29 - Overlapping chunk windows (sliding-window chunking creates 30–50% overlap) 30 - The same passage indexed in multiple documents 31 - Repeated LLM-generated content in a knowledge base 32 33 Example: 34 ```python 35 from gmf_forge_ai_data.context import ContextDeduplicator 36 37 deduper = ContextDeduplicator(similarity_threshold=0.85) 38 unique = deduper.deduplicate(retrieved_results) 39 # Near-identical passages are removed, keeping the higher-ranked one 40 ``` 41 """ 42 43 def __init__(self, similarity_threshold: float = 0.85, ngram_size: int = 3): 44 """ 45 Args: 46 similarity_threshold: Jaccard similarity above which two documents 47 are treated as duplicates. Range [0, 1]. 48 0.85 catches heavy overlaps; lower values 49 (0.6–0.7) catch paraphrase-level duplicates. 50 ngram_size: Character n-gram size for fingerprinting. 51 3 (trigrams) is a good default. 52 """ 53 if not 0.0 <= similarity_threshold <= 1.0: 54 raise ValueError( 55 f"similarity_threshold must be in [0, 1], got {similarity_threshold}" 56 ) 57 self.similarity_threshold = similarity_threshold 58 self.ngram_size = ngram_size 59 60 def deduplicate(self, results: List[SearchResult]) -> List[SearchResult]: 61 """ 62 Remove near-duplicate passages, keeping the highest-ranked copy. 63 64 Processes results in order (index 0 = highest rank). For each result, 65 checks if it is a near-duplicate of any already-kept result. If so, 66 it is discarded; otherwise it is added to the output. 67 68 Args: 69 results: List of SearchResult objects, ordered by relevance rank. 70 71 Returns: 72 Deduplicated list in the same relative order. 
73 """ 74 kept: List[SearchResult] = [] 75 kept_fingerprints: List[Set[str]] = [] 76 77 for result in results: 78 fp = self._fingerprint(result.document.content) 79 if not any( 80 self._jaccard(fp, existing) >= self.similarity_threshold 81 for existing in kept_fingerprints 82 ): 83 kept.append(result) 84 kept_fingerprints.append(fp) 85 86 return kept 87 88 def _fingerprint(self, text: str) -> Set[str]: 89 """Build a character n-gram set for the text.""" 90 n = self.ngram_size 91 text = text.lower() 92 if len(text) < n: 93 return {text} 94 return {text[i:i + n] for i in range(len(text) - n + 1)} 95 96 @staticmethod 97 def _jaccard(a: Set[str], b: Set[str]) -> float: 98 if not a or not b: 99 return 0.0 100 return len(a & b) / len(a | b)
Removes near-duplicate passages from a retrieved result list.
Two documents are considered near-duplicates if the Jaccard similarity of
their character n-gram sets (trigrams by default) is at least similarity_threshold.
The higher-ranked (lower rank value / earlier in list) document is always kept.
Near-duplicates arise from:
- Overlapping chunk windows (sliding-window chunking creates 30–50% overlap)
- The same passage indexed in multiple documents
- Repeated LLM-generated content in a knowledge base
Example:
from gmf_forge_ai_data.context import ContextDeduplicator
deduper = ContextDeduplicator(similarity_threshold=0.85)
unique = deduper.deduplicate(retrieved_results)
# Near-identical passages are removed, keeping the higher-ranked one
    def __init__(self, similarity_threshold: float = 0.85, ngram_size: int = 3):
        """
        Args:
            similarity_threshold: Jaccard similarity above which two documents
                                  are treated as duplicates. Range [0, 1].
                                  0.85 catches heavy overlaps; lower values
                                  (0.6–0.7) catch paraphrase-level duplicates.
            ngram_size: Character n-gram size for fingerprinting.
                        3 (trigrams) is a good default.
        """
        if not 0.0 <= similarity_threshold <= 1.0:
            raise ValueError(
                f"similarity_threshold must be in [0, 1], got {similarity_threshold}"
            )
        self.similarity_threshold = similarity_threshold
        self.ngram_size = ngram_size
Args:
similarity_threshold: Jaccard similarity at or above which two documents are treated as duplicates. Range [0, 1]. 0.85 catches heavy overlaps; lower values (0.6–0.7) catch paraphrase-level duplicates.
ngram_size: Character n-gram size for fingerprinting. 3 (trigrams) is a good default.
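For a feel of what the threshold is compared against, the sketch below computes the character n-gram Jaccard similarity for two short strings. The free functions `fingerprint` and `jaccard` are written for this example and follow the logic of the class's private helpers as shown in the source; the input strings are made up.

```python
from typing import Set


def fingerprint(text: str, n: int = 3) -> Set[str]:
    """Lower-case the text and return its set of character n-grams."""
    text = text.lower()
    if len(text) < n:
        return {text}  # too short to slice, so treat the whole text as one gram
    return {text[i:i + n] for i in range(len(text) - n + 1)}


def jaccard(a: Set[str], b: Set[str]) -> float:
    """Size of the intersection divided by size of the union."""
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


left = fingerprint("abcde")   # {"abc", "bcd", "cde"}
right = fingerprint("abcdf")  # {"abc", "bcd", "cdf"}
similarity = jaccard(left, right)  # 2 shared grams out of 4 distinct grams
```

Two of the four distinct trigrams are shared, so the similarity is 0.5, well under the 0.85 default. On short strings a single character change moves the score a lot, while on paragraph-length chunks the same change barely registers.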
    def deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
        """
        Remove near-duplicate passages, keeping the highest-ranked copy.

        Processes results in order (index 0 = highest rank). For each result,
        checks if it is a near-duplicate of any already-kept result. If so,
        it is discarded; otherwise it is added to the output.

        Args:
            results: List of SearchResult objects, ordered by relevance rank.

        Returns:
            Deduplicated list in the same relative order.
        """
        kept: List[SearchResult] = []
        kept_fingerprints: List[Set[str]] = []

        for result in results:
            fp = self._fingerprint(result.document.content)
            if not any(
                self._jaccard(fp, existing) >= self.similarity_threshold
                for existing in kept_fingerprints
            ):
                kept.append(result)
                kept_fingerprints.append(fp)

        return kept
Remove near-duplicate passages, keeping the highest-ranked copy.
Processes results in order (index 0 = highest rank). For each result, checks if it is a near-duplicate of any already-kept result. If so, it is discarded; otherwise it is added to the output.
Args: results: List of SearchResult objects, ordered by relevance rank.
Returns: Deduplicated list in the same relative order.
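The keep-first pass described above can be sketched independently of `SearchResult`. In this sketch `dedupe_keep_first` is a hypothetical generic helper, and `word_overlap` is a word-level Jaccard stand-in chosen for brevity. It is not the character n-gram measure the class uses.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def dedupe_keep_first(
    items: List[T],
    similarity: Callable[[T, T], float],
    threshold: float = 0.85,
) -> List[T]:
    """Keep an item only if it is not too similar to any item kept before it."""
    kept: List[T] = []
    for item in items:  # items are assumed to be ordered best-first
        if not any(similarity(item, existing) >= threshold for existing in kept):
            kept.append(item)
    return kept


def word_overlap(a: str, b: str) -> float:
    """Word-level Jaccard similarity (illustrative stand-in only)."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    if not wa and not wb:
        return 1.0
    return len(wa & wb) / len(wa | wb)
```

Each candidate is compared against every kept item, so the cost grows with the number of results times the number of survivors. That is usually acceptable for the few dozen passages a retriever returns.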
class ContextReranker:
    """
    Reranks retrieved documents by LLM relevance scoring.

    Sends the query and all retrieved document passages to the LLM and asks it
    to order them from most to least relevant. Returns a new list of
    SearchResult objects in the LLM-determined order, with scores reassigned
    to reflect the new ranking.

    The reranker is most valuable when:
    - Vector similarity scores are clustered closely (hard to choose top-k)
    - The query requires reasoning rather than lexical overlap
    - Final context window is small and every slot matters

    Example:
        ```python
        from gmf_forge_ai_data.context import ContextReranker

        reranker = ContextReranker(llm_gateway)
        reranked = await reranker.rerank(
            query="What penalties apply for antitrust violations?",
            results=retrieved_results,
            top_k=3,
        )
        # reranked[0] is now the most relevant doc by LLM judgement
        ```
    """

    _RERANK_PROMPT = (
        "You are a document relevance ranking assistant for a RAG pipeline.\n\n"
        "Query: {query}\n\n"
        "Below are {count} retrieved documents.\n"
        "Return ONLY a comma-separated list of document indices ordered from MOST "
        "to LEAST relevant to the query. Include every index exactly once.\n"
        "Example for 4 documents: 2, 0, 3, 1\n\n"
        "Documents:\n{documents}\n\n"
        "Ranked indices (most to least relevant):"
    )

    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
        """
        Args:
            llm_gateway: LLM gateway for relevance assessment.
            temperature: Sampling temperature (default 0.0 for deterministic
                         ranking). Keep low; ranking should be consistent.
        """
        self.llm_gateway = llm_gateway
        self.temperature = temperature

    async def rerank(
        self,
        query: str,
        results: List[SearchResult],
        top_k: Optional[int] = None,
    ) -> List[SearchResult]:
        """
        Rerank retrieved documents by LLM relevance to the query.

        If the LLM returns an unparseable response, the original order is
        preserved so the pipeline does not break.

        Args:
            query: The original user query, used as context for the
                   relevance judgement.
            results: List of SearchResult objects from retrieval.
            top_k: If provided, return only the top-k results after reranking.

        Returns:
            SearchResult list in new LLM-determined relevance order, with
            scores reassigned as 1.0, (n-1)/n, (n-2)/n, ... reflecting rank.
        """
        if not results:
            return results

        documents_text = "\n\n".join(
            f"[{i}] {r.document.content}"
            for i, r in enumerate(results)
        )

        prompt = self._RERANK_PROMPT.format(
            query=query,
            count=len(results),
            documents=documents_text,
        )

        response = await self.llm_gateway.complete(
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=100,
        )

        ranked_indices = self._parse_ranked_indices(response.content, len(results))

        n = len(ranked_indices)
        reranked: List[SearchResult] = []
        for new_rank, idx in enumerate(ranked_indices):
            original = results[idx]
            score = round(1.0 - (new_rank / n), 4) if n > 1 else 1.0
            reranked.append(SearchResult(
                document=original.document,
                score=score,
                rank=new_rank,
            ))

        return reranked[:top_k] if top_k is not None else reranked

    @staticmethod
    def _parse_ranked_indices(text: str, count: int) -> List[int]:
        """
        Parse a comma-separated index list from LLM output.

        Out-of-range and repeated indices are ignored, and indices the LLM
        omitted are appended in their original order. If nothing can be
        parsed, the result is the original order [0, 1, 2, ...].
        """
        if not isinstance(text, str):
            return list(range(count))

        indices: List[int] = []
        seen = set()
        for part in text.strip().split(","):
            token = part.strip()
            if not (token.isascii() and token.isdigit()):
                continue
            idx = int(token)
            # Skip indices that would raise IndexError in rerank() or that
            # would return the same document twice
            if idx < count and idx not in seen:
                indices.append(idx)
                seen.add(idx)

        # Partial list: include missing indices at the end
        indices.extend(i for i in range(count) if i not in seen)
        return indices
Reranks retrieved documents by LLM relevance scoring.
Sends the query and all retrieved document passages to the LLM and asks it to order them from most to least relevant. Returns a new list of SearchResult objects in the LLM-determined order, with scores reassigned to reflect the new ranking.
The reranker is most valuable when:
- Vector similarity scores are clustered closely (hard to choose top-k)
- The query requires reasoning rather than lexical overlap
- Final context window is small and every slot matters
Example:
from gmf_forge_ai_data.context import ContextReranker
reranker = ContextReranker(llm_gateway)
reranked = await reranker.rerank(
query="What penalties apply for antitrust violations?",
results=retrieved_results,
top_k=3,
)
# reranked[0] is now the most relevant doc by LLM judgement
Args:
- llm_gateway: LLM gateway for relevance assessment.
- temperature: Sampling temperature (default 0.0 for deterministic ranking). Keep low; ranking should be consistent.
Rerank retrieved documents by LLM relevance to the query.
If the LLM returns an unparseable response, the original order is preserved so the pipeline does not break.
Args:
- query: The original user query, used as context for the relevance judgement.
- results: List of SearchResult objects from retrieval.
- top_k: If provided, return only the top-k results after reranking.
Returns: SearchResult list in new LLM-determined relevance order, with scores reassigned as 1.0, (n-1)/n, (n-2)/n, ... reflecting rank.
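The two mechanical steps described above, turning the LLM's reply into an ordering and turning ranks into scores, can be sketched on their own. This is a defensive variant written for illustration: `parse_ranked_indices` and `rank_scores` are hypothetical standalone functions, and discarding out-of-range or repeated indices is an assumption about the desired behaviour.

```python
from typing import List


def parse_ranked_indices(text: str, count: int) -> List[int]:
    """Turn a reply such as '2, 0, 3, 1' into a full permutation of range(count)."""
    ordered: List[int] = []
    seen = set()
    for part in text.split(","):
        token = part.strip()
        if token.isascii() and token.isdigit():
            idx = int(token)
            # Ignore indices that do not exist or were already listed.
            if idx < count and idx not in seen:
                ordered.append(idx)
                seen.add(idx)
    # Anything the reply left out keeps its original relative order at the end.
    ordered.extend(i for i in range(count) if i not in seen)
    return ordered


def rank_scores(n: int) -> List[float]:
    """Scores 1.0, (n-1)/n, (n-2)/n, ... for ranks 0..n-1."""
    return [round(1.0 - rank / n, 4) for rank in range(n)]
```

For four documents the scores are 1.0, 0.75, 0.5 and 0.25. They encode rank only and are not comparable with the original vector similarity scores.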
class ContextCompressor:
    """
    Compresses retrieved passages to only the query-relevant sentences.

    Calls the LLM once per retrieved document, asking it to extract only the
    sentences that directly help answer the query. The original SearchResult
    is preserved; only the `content` field is replaced with the compressed
    version. Score and rank are unchanged.

    When to use:
    - Chunks are long (>500 tokens) and only partially relevant
    - Context window budget is tight
    - You want to reduce LLM distraction from off-topic content in chunks

    Example:
        ```python
        from gmf_forge_ai_data.context import ContextCompressor

        compressor = ContextCompressor(llm_gateway)
        compressed = await compressor.compress(
            query="What is self-attention in transformers?",
            results=reranked_results,
        )
        # Each result.document.content now contains only the relevant sentences
        ```
    """

    _COMPRESS_PROMPT = (
        "You are a document compression assistant for a RAG pipeline.\n\n"
        "Extract ONLY the sentences from the passage below that are directly "
        "relevant to answering the query. Drop redundant, tangential, or clearly "
        "off-topic sentences.\n"
        "Return ONLY the extracted text \u2014 no explanation, no bullet points, no markdown.\n"
        "If the entire passage is already focused and relevant, return it unchanged.\n\n"
        "STRICT RULES:\n"
        "- Copy sentences EXACTLY as they appear in the passage.\n"
        "- Do NOT paraphrase, summarise, reword, or generate any new text.\n"
        "- Only output sentences that exist verbatim in the passage.\n\n"
        "Query: {query}\n\n"
        "Passage:\n{content}\n\n"
        "Relevant sentences from the passage (verbatim):"
    )

    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
        """
        Args:
            llm_gateway: LLM gateway for compression.
            temperature: Sampling temperature (default 0.0 for deterministic
                         extraction). Keep low to avoid hallucination.
        """
        self.llm_gateway = llm_gateway
        self.temperature = temperature

    async def compress(
        self,
        query: str,
        results: List[SearchResult],
        min_length: int = 20,
    ) -> List[SearchResult]:
        """
        Compress each retrieved passage to query-relevant content only.

        Calls the LLM once per result. If the LLM returns content shorter
        than `min_length` characters (likely an error), the original passage
        is kept unchanged.

        Args:
            query: The user query to guide extraction.
            results: List of SearchResult objects to compress.
            min_length: Minimum character length for a valid compressed result.
                        If compression produces less, the original is kept.

        Returns:
            New list of SearchResult objects with compressed document content.
        """
        compressed: List[SearchResult] = []
        for result in results:
            compressed_content = await self._compress_one(
                query, result.document.content, min_length
            )
            # Shallow-copy the document so the caller's results are not mutated
            new_doc = copy.copy(result.document)
            new_doc.content = compressed_content
            compressed.append(SearchResult(
                document=new_doc,
                score=result.score,
                rank=result.rank,
            ))
        return compressed

    async def _compress_one(
        self, query: str, content: str, min_length: int
    ) -> str:
        prompt = self._COMPRESS_PROMPT.format(query=query, content=content)
        response = await self.llm_gateway.complete(
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=500,
        )
        # Strip whitespace and any quotes the LLM wrapped around the extract
        extracted = response.content.strip().strip('"').strip("'")
        return extracted if len(extracted) >= min_length else content
Compresses retrieved passages to only the query-relevant sentences.
Calls the LLM once per retrieved document, asking it to extract only the
sentences that directly help answer the query. The original SearchResult
is preserved — only the content field is replaced with the compressed
version. Score and rank are unchanged.
When to use:
- Chunks are long (>500 tokens) and only partially relevant
- Context window budget is tight
- You want to reduce LLM distraction from off-topic content in chunks
Example:
from gmf_forge_ai_data.context import ContextCompressor
compressor = ContextCompressor(llm_gateway)
compressed = await compressor.compress(
query="What is self-attention in transformers?",
results=reranked_results,
)
# Each result.document.content now contains only the relevant sentences
Args:
- llm_gateway: LLM gateway for compression.
- temperature: Sampling temperature (default 0.0 for deterministic extraction). Keep low to avoid hallucination.
Compress each retrieved passage to query-relevant content only.
Calls the LLM once per result. If the LLM returns content shorter
than min_length characters (likely an error), the original passage
is kept unchanged.
Args:
- query: The user query to guide extraction.
- results: List of SearchResult objects to compress.
- min_length: Minimum character length for a valid compressed result. If compression produces less, the original is kept.
Returns: New list of SearchResult objects with compressed document content.
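The `min_length` fallback can be exercised without a real LLM. In this sketch `FakeGateway` and `FakeResponse` are hypothetical test doubles. The only thing assumed about the real gateway is the `complete(prompt=..., temperature=..., max_tokens=...)` call and the `.content` attribute used in the source above. `compress_one` is a standalone function written for illustration, with a shortened prompt.

```python
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class FakeResponse:
    content: str


class FakeGateway:
    """Hypothetical stand-in for the LLM gateway: returns canned replies in order."""

    def __init__(self, replies: List[str]):
        self._replies = list(replies)

    async def complete(self, prompt: str, temperature: float, max_tokens: int) -> FakeResponse:
        return FakeResponse(self._replies.pop(0))


async def compress_one(gateway, query: str, content: str, min_length: int = 20) -> str:
    """Ask the gateway for an extract; keep the original if the extract is too short."""
    response = await gateway.complete(
        prompt=f"Query: {query}\n\nPassage:\n{content}",
        temperature=0.0,
        max_tokens=500,
    )
    extracted = response.content.strip().strip('"').strip("'")
    return extracted if len(extracted) >= min_length else content
```

A very short reply such as "N/A" falls below the threshold, so the caller receives the original passage and does not lose the context.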
class ContextWindowManager:
    """
    Fits retrieved documents into a maximum token budget.

    Works in a single pass over the ranked results, applying two rules:
    1. Adds each full document that still fits in the remaining budget.
    2. If `allow_truncation=True`, the first document that does not fit is
       partially included using as many characters as the remaining budget
       allows. Later documents are dropped.

    Token estimation uses 4 characters ≈ 1 token (standard rule-of-thumb
    for English text with GPT tokenizers).

    Example:
        ```python
        from gmf_forge_ai_data.context import ContextWindowManager

        manager = ContextWindowManager(max_tokens=2000)
        window = manager.fit(reranked_and_compressed_results)

        print(f"Using {window.total_tokens} / {window.budget} tokens")
        print(f"Docs included: {len(window.results)}, truncated: {window.truncated}")

        # Build the prompt context string
        context_str = "\\n\\n".join(r.document.content for r in window.results)
        ```
    """

    _CHARS_PER_TOKEN: float = 4.0

    def __init__(
        self,
        max_tokens: int = 3000,
        allow_truncation: bool = True,
    ):
        """
        Args:
            max_tokens: Maximum number of tokens for all document content
                        combined. Does not include the prompt template
                        overhead, so subtract your template size from the
                        model's context limit before setting this.
            allow_truncation: If True (default), the first document that would
                              overflow is truncated to fit remaining space.
                              If False, it is dropped entirely.
        """
        if max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {max_tokens}")
        self.max_tokens = max_tokens
        self.allow_truncation = allow_truncation

    def fit(self, results: List[SearchResult]) -> WindowedContext:
        """
        Select and optionally truncate documents to fit within the token budget.

        Args:
            results: List of SearchResult objects, ordered by priority
                     (highest priority first, e.g. after reranking).

        Returns:
            WindowedContext with the selected results and budget metadata.
        """
        kept: List[SearchResult] = []
        tokens_used = 0
        truncated = False
        dropped = 0

        for result in results:
            doc_tokens = self._estimate_tokens(result.document.content)
            remaining = self.max_tokens - tokens_used

            if doc_tokens <= remaining:
                kept.append(result)
                tokens_used += doc_tokens
            elif self.allow_truncation and remaining > 0 and not truncated:
                # Partially include this document to fill remaining budget
                max_chars = int(remaining * self._CHARS_PER_TOKEN)
                truncated_content = result.document.content[:max_chars]
                # Shallow-copy so the caller's document is not mutated
                new_doc = copy.copy(result.document)
                new_doc.content = truncated_content
                kept.append(SearchResult(
                    document=new_doc,
                    score=result.score,
                    rank=result.rank,
                ))
                tokens_used += self._estimate_tokens(truncated_content)
                truncated = True
            else:
                dropped += 1

        return WindowedContext(
            results=kept,
            total_tokens=tokens_used,
            budget=self.max_tokens,
            truncated=truncated,
            dropped=dropped,
        )

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using 4 characters-per-token heuristic."""
        return max(1, round(len(text) / self._CHARS_PER_TOKEN))
Fits retrieved documents into a maximum token budget.
Works in a single pass over the ranked results, applying two rules:
- Adds each full document that still fits in the remaining budget.
- If allow_truncation=True, the first document that does not fit is partially included using as many characters as the remaining budget allows. Later documents are dropped.
Token estimation uses 4 characters ≈ 1 token (standard rule-of-thumb for English text with GPT tokenizers).
Example:
from gmf_forge_ai_data.context import ContextWindowManager
manager = ContextWindowManager(max_tokens=2000)
window = manager.fit(reranked_and_compressed_results)
print(f"Using {window.total_tokens} / {window.budget} tokens")
print(f"Docs included: {len(window.results)}, truncated: {window.truncated}")
# Build the prompt context string
context_str = "\n\n".join(r.document.content for r in window.results)
Args:
- max_tokens: Maximum number of tokens for all document content combined. Does not include the prompt template overhead, so subtract your template size from the model's context limit before setting this.
- allow_truncation: If True (default), the first document that would overflow is truncated to fit remaining space. If False, it is dropped entirely.
Select and optionally truncate documents to fit within the token budget.
Args: results: List of SearchResult objects, ordered by priority (highest priority first — e.g. after reranking).
Returns: WindowedContext with the selected results and budget metadata.
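To see how the budget arithmetic plays out, the selection logic can be sketched on plain strings. `estimate_tokens` and `fit_texts` are hypothetical standalone helpers that follow the 4-characters-per-token heuristic described above. They return a tuple in place of a WindowedContext.

```python
from typing import List, Tuple

CHARS_PER_TOKEN = 4.0  # rule-of-thumb for English text, not an exact tokenizer


def estimate_tokens(text: str) -> int:
    """Estimate tokens as len/4, with a floor of 1 even for empty text."""
    return max(1, round(len(text) / CHARS_PER_TOKEN))


def fit_texts(
    texts: List[str], max_tokens: int, allow_truncation: bool = True
) -> Tuple[List[str], int, bool, int]:
    """Return (kept texts, tokens used, truncated flag, dropped count)."""
    kept: List[str] = []
    used = 0
    truncated = False
    dropped = 0
    for text in texts:
        need = estimate_tokens(text)
        remaining = max_tokens - used
        if need <= remaining:
            kept.append(text)
            used += need
        elif allow_truncation and remaining > 0 and not truncated:
            # Cut the first overflowing text down to the remaining budget.
            piece = text[: int(remaining * CHARS_PER_TOKEN)]
            kept.append(piece)
            used += estimate_tokens(piece)
            truncated = True
        else:
            dropped += 1
    return kept, used, truncated, dropped
```

With three 40-character texts (about 10 tokens each) and a 15-token budget, the first text is kept whole, the second is cut to 20 characters, and the third is dropped. Because the estimate is a heuristic, it is worth leaving some headroom below the model's real context limit.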
@dataclass
class WindowedContext:
    """
    Output of ContextWindowManager.fit().

    Attributes:
        results: Documents that fit within the token budget, in order.
        total_tokens: Estimated token count of all included content.
        budget: The configured maximum token budget.
        truncated: True if the last document's content was truncated to fit.
        dropped: Number of documents that did not fit at all.
    """
    results: List[SearchResult]
    total_tokens: int
    budget: int
    truncated: bool = False
    dropped: int = 0
Output of ContextWindowManager.fit().
Attributes:
- results: Documents that fit within the token budget, in order.
- total_tokens: Estimated token count of all included content.
- budget: The configured maximum token budget.
- truncated: True if the last document's content was truncated to fit.
- dropped: Number of documents that did not fit at all.
class QueryDecomposer:
    """
    Decomposes complex multi-part queries into focused sub-queries using LLM.

    Multi-part queries ("What are the antitrust laws and what cases were filed in 2024?")
    are split into individual queries for better retrieval precision per component.
    The sub-queries can then be run in parallel with a retriever and results merged,
    similar to query expansion but targeting distinct question atoms rather than synonyms.

    Example:
        ```python
        from gmf_forge_ai_data.query import QueryDecomposer
        from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway

        gateway = UnifiedLLMGateway(default_provider=azure_provider)
        decomposer = QueryDecomposer(gateway)

        result = await decomposer.decompose(
            "What are the antitrust laws and what cases were filed in 2024?"
        )
        # result.sub_queries = [
        #     "What are the antitrust laws?",
        #     "What cases were filed in 2024?",
        # ]
        ```
    """

    _DECOMPOSE_PROMPT = (
        "You are a query decomposition assistant for a retrieval system.\n\n"
        "Break the following complex query into {max_sub_queries} or fewer "
        "focused sub-queries.\n"
        "Each sub-query must be self-contained and independently answerable.\n"
        "Return ONLY a numbered list, one sub-query per line. "
        "Do not add explanations.\n\n"
        "Query: {query}\n\n"
        "Sub-queries:"
    )

    def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
        """
        Args:
            llm_gateway: LLM gateway used for intelligent decomposition.
            temperature: Sampling temperature passed to the LLM (default 0.0 for
                         deterministic decomposition). Raise slightly (e.g. 0.2)
                         to get more varied sub-query boundaries.
        """
        self.llm_gateway = llm_gateway
        self.temperature = temperature

    async def decompose(
        self,
        query: str,
        max_sub_queries: int = 3,
    ) -> DecomposedQuery:
        """
        Decompose a complex query into focused sub-queries using LLM.

        Args:
            query: The complex query to break apart.
            max_sub_queries: Maximum number of sub-queries to produce.

        Returns:
            DecomposedQuery containing the original and list of sub-queries.
        """
        prompt = self._DECOMPOSE_PROMPT.format(
            query=query,
            max_sub_queries=max_sub_queries,
        )

        response = await self.llm_gateway.complete(
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=300,
        )

        sub_queries = self._parse_numbered_list(response.content)

        # Nothing parseable: fall back to the original query as the only sub-query
        if not sub_queries:
            return DecomposedQuery(
                original=query,
                sub_queries=[query],
                reasoning=response.content,
            )

        return DecomposedQuery(
            original=query,
            sub_queries=sub_queries[:max_sub_queries],
            reasoning=response.content,
        )

    @staticmethod
    def _parse_numbered_list(text: str) -> List[str]:
        """Parse '1. item\\n2. item', '1) item', '- item', '* item' from LLM output."""
        lines = text.strip().split("\n")
        results: List[str] = []
        for line in lines:
            match = re.match(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)", line)
            if match:
                results.append(match.group(1).strip())
        return results
Decomposes complex multi-part queries into focused sub-queries using LLM.
Multi-part queries ("What are the antitrust laws and what cases were filed in 2024?") are split into individual queries for better retrieval precision per component. The sub-queries can then be run in parallel with a retriever and results merged, similar to query expansion but targeting distinct question atoms rather than synonyms.
Example:
from gmf_forge_ai_data.query import QueryDecomposer
from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway
gateway = UnifiedLLMGateway(default_provider=azure_provider)
decomposer = QueryDecomposer(gateway)
result = await decomposer.decompose(
"What are the antitrust laws and what cases were filed in 2024?"
)
# result.sub_queries = [
# "What are the antitrust laws?",
# "What cases were filed in 2024?",
# ]
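The class description mentions running the sub-queries in parallel and merging the results, but does not show how. One possible sketch uses `asyncio.gather`. Here `retrieve` is a hypothetical coroutine returning `(document id, score)` pairs, and keeping the highest score per document is an assumed merge strategy, not something the package prescribes.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, Tuple

Hit = Tuple[str, float]  # (document id, score); hypothetical result shape


async def retrieve_all(
    sub_queries: List[str],
    retrieve: Callable[[str], Awaitable[List[Hit]]],
) -> List[Hit]:
    """Run one retrieval per sub-query concurrently and merge by best score."""
    batches = await asyncio.gather(*(retrieve(q) for q in sub_queries))
    best: Dict[str, float] = {}
    for batch in batches:
        for doc_id, score in batch:
            # A document found by several sub-queries keeps its highest score.
            if score > best.get(doc_id, float("-inf")):
                best[doc_id] = score
    return sorted(best.items(), key=lambda hit: hit[1], reverse=True)
```

Scores from different sub-queries are only roughly comparable, so a reranking step after the merge may be worthwhile.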
Args:
- llm_gateway: LLM gateway used for intelligent decomposition.
- temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic decomposition). Raise slightly (e.g. 0.2) to get more varied sub-query boundaries.
Decompose a complex query into focused sub-queries using LLM.
Args:
- query: The complex query to break apart.
- max_sub_queries: Maximum number of sub-queries to produce.
Returns: DecomposedQuery containing the original and list of sub-queries.
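The list parsing and the fallback to the original query can be tried in isolation. The regular expression below is the one shown in the class source. `parse_numbered_list` and `sub_queries_or_original` are standalone names introduced here for illustration.

```python
import re
from typing import List

# Matches '1. item', '1) item', '- item' and '* item'; captures the item text.
_ITEM = re.compile(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)")


def parse_numbered_list(text: str) -> List[str]:
    """Collect the item text of every numbered or bulleted line."""
    items: List[str] = []
    for line in text.strip().split("\n"):
        match = _ITEM.match(line)
        if match:
            items.append(match.group(1).strip())
    return items


def sub_queries_or_original(query: str, llm_text: str, max_sub_queries: int = 3) -> List[str]:
    """Cap the parsed list, or fall back to the original query if nothing parsed."""
    parsed = parse_numbered_list(llm_text)
    return parsed[:max_sub_queries] if parsed else [query]
```

Lines without a list marker, such as an introductory sentence from the LLM, are skipped, so a reply with no list at all yields the original query unchanged.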
@dataclass
class DecomposedQuery:
    """
    Result of query decomposition.

    Attributes:
        original: The original complex query string.
        sub_queries: List of focused sub-queries derived from the original.
        reasoning: Raw LLM response explaining the decomposition.
    """
    original: str
    sub_queries: List[str]
    reasoning: Optional[str] = None
Result of query decomposition.
Attributes:
- original: The original complex query string.
- sub_queries: List of focused sub-queries derived from the original.
- reasoning: Raw LLM response explaining the decomposition.
class QueryRouter:
    """
    Routes queries to the appropriate retriever or index using LLM.

    Each route has a name and a plain-English description of its content.
    The LLM selects the best-matching route for each incoming query.

    Typical use in a multi-index RAG system: create one route per Azure AI
    Search index and let the router automatically direct queries without
    searching all indexes every time.

    Example:
        ```python
        from gmf_forge_ai_data.query import QueryRouter

        routes = {
            "legal_documents": "Legal cases, court decisions, jurisdiction, antitrust, patent",
            "products": "Products, prices, inventory, electronics, furniture, camera",
            "financial_reports": "Earnings, revenue, fiscal year, company financials, SEC filings",
            "ai_ml_knowledge": "Machine learning, AI, neural networks, deep learning, NLP",
        }

        router = QueryRouter(routes=routes, llm_gateway=gateway)
        decision = await router.route("What antitrust cases were filed in 2024?")
        # decision.target = "legal_documents"
        # decision.confidence = 0.9
        ```
    """

    _ROUTE_PROMPT = (
        "You are a query routing assistant for a multi-domain retrieval system.\n\n"
        "Available indexes and what they contain:\n"
        "{routes_description}\n\n"
        "Given the user query below, output ONLY the name of the single best index "
        "to search. Do not add any explanation or punctuation.\n\n"
        "Query: {query}\n\n"
        "Best index:"
    )

    def __init__(
        self,
        routes: Dict[str, str],
        llm_gateway: UnifiedLLMGateway,
        temperature: float = 0.0,
    ):
        """
        Args:
            routes: Dict mapping route name → plain-English description of content.
            llm_gateway: LLM gateway for intelligent routing.
            temperature: Sampling temperature passed to the LLM (default 0.0 for
                         deterministic routing). Keep low; routing should be consistent.
        """
        self.routes = routes
        self.llm_gateway = llm_gateway
        self.temperature = temperature

    async def route(self, query: str) -> RouteDecision:
        """
        Route a query to the best-matching index using LLM.

        Args:
            query: The user query to route.

        Returns:
            RouteDecision with the chosen target and confidence score.

        Raises:
            ValueError: If the LLM returns an unknown route name.
        """
        routes_description = "\n".join(
            f"- {name}: {desc}" for name, desc in self.routes.items()
        )
        prompt = self._ROUTE_PROMPT.format(
            routes_description=routes_description,
            query=query,
        )

        response = await self.llm_gateway.complete(
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=50,
        )

        # Strip whitespace and any quotes the LLM wrapped around the name
        target = response.content.strip().strip('"').strip("'")

        if target not in self.routes:
            raise ValueError(
                f"LLM returned unknown route '{target}'. "
                f"Valid routes: {list(self.routes.keys())}"
            )

        # The LLM returns only a route name, so the confidence (0.9) and the
        # alternative scores (0.0) are fixed placeholders, not model estimates
        alternatives = [(name, 0.0) for name in self.routes if name != target]

        return RouteDecision(
            query=query,
            target=target,
            confidence=0.9,
            reasoning=response.content,
            alternatives=alternatives,
        )
Routes queries to the appropriate retriever or index using LLM.
Each route has a name and a plain-English description of its content. The LLM selects the best-matching route for each incoming query.
Typical use in a multi-index RAG system: create one route per Azure AI Search index and let the router automatically direct queries without searching all indexes every time.
Example:
from gmf_forge_ai_data.query import QueryRouter

routes = {
    "legal_documents": "Legal cases, court decisions, jurisdiction, antitrust, patent",
    "products": "Products, prices, inventory, electronics, furniture, camera",
    "financial_reports": "Earnings, revenue, fiscal year, company financials, SEC filings",
    "ai_ml_knowledge": "Machine learning, AI, neural networks, deep learning, NLP",
}

router = QueryRouter(routes=routes, llm_gateway=gateway)
decision = await router.route("What antitrust cases were filed in 2024?")
# decision.target = "legal_documents"
# decision.confidence = 0.9
def __init__(
    self,
    routes: Dict[str, str],
    llm_gateway: UnifiedLLMGateway,
    temperature: float = 0.0,
):
    """
    Args:
        routes: Dict mapping route name → plain-English description of content.
        llm_gateway: LLM gateway for intelligent routing.
        temperature: Sampling temperature passed to the LLM (default 0.0 for
            deterministic routing). Keep low: routing should be consistent.
    """
    self.routes = routes
    self.llm_gateway = llm_gateway
    self.temperature = temperature
Args:
- routes: Dict mapping route name → plain-English description of content.
- llm_gateway: LLM gateway for intelligent routing.
- temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic routing). Keep low: routing should be consistent.
async def route(self, query: str) -> RouteDecision:
    """
    Route a query to the best-matching index using an LLM.

    Args:
        query: The user query to route.

    Returns:
        RouteDecision with the chosen target and confidence score.

    Raises:
        ValueError: If the LLM returns an unknown route name.
    """
    routes_description = "\n".join(
        f"- {name}: {desc}" for name, desc in self.routes.items()
    )
    prompt = self._ROUTE_PROMPT.format(
        routes_description=routes_description,
        query=query,
    )

    response = await self.llm_gateway.complete(
        prompt=prompt,
        temperature=self.temperature,
        max_tokens=50,
    )

    # Strip whitespace and surrounding quotes the LLM may add around the name.
    target = response.content.strip().strip('"').strip("'")

    if target not in self.routes:
        raise ValueError(
            f"LLM returned unknown route '{target}'. "
            f"Valid routes: {list(self.routes.keys())}"
        )

    # Non-selected routes get a placeholder score of 0.0.
    alternatives = [(name, 0.0) for name in self.routes if name != target]

    return RouteDecision(
        query=query,
        target=target,
        confidence=0.9,
        reasoning=response.content,
        alternatives=alternatives,
    )
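Route a query to the best-matching index using an LLM.
Args:
- query: The user query to route.
Returns: RouteDecision with the chosen target. The confidence is a fixed placeholder (0.9), not a score produced by the LLM, and every entry in alternatives carries 0.0.
Raises: ValueError: If the LLM returns a name that is not a key of routes.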
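Because route() raises ValueError whenever the model's reply is not one of the configured route names, callers may want a fallback path. The following is a minimal sketch of one way to do that, not part of the package: `route_with_fallback`, `_StubRouter`, and the `default_route` parameter are hypothetical names, and the stub stands in for QueryRouter so the snippet runs without an LLM gateway.

```python
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class RouteDecision:
    """Local copy of the documented RouteDecision fields, so the sketch is self-contained."""
    query: str
    target: str
    confidence: float
    reasoning: Optional[str] = None
    alternatives: List[Tuple[str, float]] = field(default_factory=list)


class _StubRouter:
    """Hypothetical stand-in for QueryRouter that fails like route() does on an unknown name."""

    async def route(self, query: str) -> RouteDecision:
        raise ValueError("LLM returned unknown route 'misc'.")


async def route_with_fallback(router, query: str, default_route: str) -> RouteDecision:
    """Try the router; on ValueError, fall back to a caller-chosen default route."""
    try:
        return await router.route(query)
    except ValueError as exc:
        # Confidence 0.0 marks the decision as a fallback rather than an LLM choice.
        return RouteDecision(
            query=query,
            target=default_route,
            confidence=0.0,
            reasoning=f"fallback: {exc}",
        )


decision = asyncio.run(
    route_with_fallback(_StubRouter(), "What is RAG?", default_route="ai_ml_knowledge")
)
```

Whether a default index, a search across all indexes, or a retry is the right fallback depends on the application; the wrapper only shows where that decision would live.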
@dataclass
class RouteDecision:
    """
    Result of query routing.

    Attributes:
        query: The original query string.
        target: Name of the chosen route (retriever or index).
        confidence: Confidence score in [0, 1] for the chosen route.
        reasoning: Raw LLM output.
        alternatives: Other routes with placeholder confidence scores.
    """
    query: str
    target: str
    confidence: float
    reasoning: Optional[str] = None
    alternatives: List[Tuple[str, float]] = field(default_factory=list)
Result of query routing.
Attributes:
- query: The original query string.
- target: Name of the chosen route (retriever or index).
- confidence: Confidence score in [0, 1] for the chosen route. QueryRouter.route() currently sets this to a fixed 0.9.
- reasoning: Raw LLM output.
- alternatives: Other routes with placeholder confidence scores (0.0 when produced by QueryRouter.route()).
34class QueryExpander: 35 """ 36 Generates query variations to improve retrieval recall using LLM. 37 38 Uses an LLM to produce semantically equivalent re-phrasings of the original 39 query. Expanded queries are intended to run in parallel with the original query 40 via separate retriever calls, then merged with Reciprocal Rank Fusion (RRF) 41 using EnsembleRetriever for best results. 42 43 Example: 44 ```python 45 from gmf_forge_ai_data.query import QueryExpander 46 47 expander = QueryExpander(llm_gateway) 48 result = await expander.expand("antitrust violations", num_expansions=3) 49 # result.expansions = [ 50 # "competition law breaches", 51 # "monopoly infringement cases", 52 # "anti-competitive conduct", 53 # ] 54 ``` 55 """ 56 57 _EXPAND_PROMPT = ( 58 "You are a search query expansion assistant.\n\n" 59 "Generate {num_expansions} alternative phrasings for the search query below.\n" 60 "Use synonyms, related terms, and different wording that conveys the same intent.\n" 61 "Return ONLY a numbered list, one variation per line. " 62 "Do NOT repeat the original query.\n\n" 63 "Original query: {query}\n\n" 64 "Alternative phrasings:" 65 ) 66 67 def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.3): 68 """ 69 Args: 70 llm_gateway: LLM gateway for generating query variations. 71 temperature: Sampling temperature passed to the LLM (default 0.3 for 72 creative variation). Raise toward 0.7 for more diverse 73 phrasings; lower toward 0.0 for tighter paraphrases. 74 """ 75 self.llm_gateway = llm_gateway 76 self.temperature = temperature 77 78 async def expand( 79 self, 80 query: str, 81 num_expansions: int = 3, 82 ) -> ExpandedQuery: 83 """ 84 Expand a query into multiple variations using LLM. 85 86 Args: 87 query: The original query to expand. 88 num_expansions: Number of alternative phrasings to generate. 89 90 Returns: 91 ExpandedQuery with original and list of variation strings. 
92 """ 93 prompt = self._EXPAND_PROMPT.format( 94 query=query, 95 num_expansions=num_expansions, 96 ) 97 98 response = await self.llm_gateway.complete( 99 prompt=prompt, 100 temperature=self.temperature, 101 max_tokens=300, 102 ) 103 104 expansions = self._parse_numbered_list(response.content) 105 106 return ExpandedQuery( 107 original=query, 108 expansions=expansions[:num_expansions], 109 ) 110 111 @staticmethod 112 def _parse_numbered_list(text: str) -> List[str]: 113 """Parse '1. item\\n2. item', '1) item', '- item', '* item' from LLM output.""" 114 lines = text.strip().split("\n") 115 results: List[str] = [] 116 for line in lines: 117 match = re.match(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)", line) 118 if match: 119 results.append(match.group(1).strip()) 120 return results
Generates query variations to improve retrieval recall using an LLM.
Uses an LLM to produce semantically equivalent rephrasings of the original query. Run each expansion as a separate retriever call in parallel with the original query, then merge the result lists with Reciprocal Rank Fusion (RRF) using EnsembleRetriever.
Example:
from gmf_forge_ai_data.query import QueryExpander

expander = QueryExpander(llm_gateway)
result = await expander.expand("antitrust violations", num_expansions=3)
# result.expansions = [
#     "competition law breaches",
#     "monopoly infringement cases",
#     "anti-competitive conduct",
# ]
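The class description mentions merging per-query results with Reciprocal Rank Fusion. How EnsembleRetriever implements this is not shown in this section, so the sketch below is only the textbook RRF formula, score(d) = Σ 1 / (k + rank), with the commonly used constant k = 60 taken as an assumption. The function name and the document IDs are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def reciprocal_rank_fusion(rankings: Sequence[Sequence[str]], k: int = 60) -> List[str]:
    """Fuse several ranked lists of document IDs into one list, best first."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            # Each list contributes 1 / (k + rank); documents ranked high in many lists win.
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=lambda doc_id: scores[doc_id], reverse=True)


# Hypothetical result lists: one per query variant (original plus two expansions).
fused = reciprocal_rank_fusion([
    ["a", "b", "c"],
    ["b", "c", "d"],
    ["b", "a"],
])
# fused == ["b", "a", "c", "d"]
```

Document "b" comes first because it appears near the top of all three lists, which is the effect expansion is meant to exploit: results that several phrasings agree on rise above results only one phrasing found.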
def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.3):
    """
    Args:
        llm_gateway: LLM gateway for generating query variations.
        temperature: Sampling temperature passed to the LLM (default 0.3 for
            creative variation). Raise toward 0.7 for more diverse
            phrasings; lower toward 0.0 for tighter paraphrases.
    """
    self.llm_gateway = llm_gateway
    self.temperature = temperature
Args:
- llm_gateway: LLM gateway for generating query variations.
- temperature: Sampling temperature passed to the LLM (default 0.3 for creative variation). Raise toward 0.7 for more diverse phrasings; lower toward 0.0 for tighter paraphrases.
async def expand(
    self,
    query: str,
    num_expansions: int = 3,
) -> ExpandedQuery:
    """
    Expand a query into multiple variations using an LLM.

    Args:
        query: The original query to expand.
        num_expansions: Number of alternative phrasings to generate.

    Returns:
        ExpandedQuery with original and list of variation strings.
    """
    prompt = self._EXPAND_PROMPT.format(
        query=query,
        num_expansions=num_expansions,
    )

    response = await self.llm_gateway.complete(
        prompt=prompt,
        temperature=self.temperature,
        max_tokens=300,
    )

    expansions = self._parse_numbered_list(response.content)

    # Truncate in case the LLM returned more items than requested.
    return ExpandedQuery(
        original=query,
        expansions=expansions[:num_expansions],
    )
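Expand a query into multiple variations using an LLM.
Args:
- query: The original query to expand.
- num_expansions: Number of alternative phrasings to request. The result is truncated to at most this many and may contain fewer if the LLM output has fewer parseable list items.
Returns: ExpandedQuery with the original query and the list of variation strings.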
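The expansions are extracted from the LLM reply by the private `_parse_numbered_list` helper in the class source. The standalone sketch below reuses the same regular expression to show which line shapes are accepted; the function name `parse_numbered_list` and the sample reply are illustrative only.

```python
import re
from typing import List

# Same pattern as QueryExpander._parse_numbered_list in the class source:
# optional indent, then "1." / "1)" / "-" / "*", then the item text.
_LIST_ITEM = re.compile(r"^\s*(?:\d+[.)]\s*|[-*]\s*)(.+)")


def parse_numbered_list(text: str) -> List[str]:
    """Return the item text of every line that looks like a list entry."""
    items: List[str] = []
    for line in text.strip().split("\n"):
        match = _LIST_ITEM.match(line)
        if match:
            items.append(match.group(1).strip())
    return items


# Made-up LLM reply mixing a preamble line with three list styles.
raw = "Here are some options:\n1. competition law breaches\n2) monopoly cases\n- cartel conduct\n"
parsed = parse_numbered_list(raw)
# parsed == ["competition law breaches", "monopoly cases", "cartel conduct"]
```

Lines without a list marker, such as the preamble, are silently dropped, which is why expand() can return fewer than `num_expansions` items.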
@dataclass
class ExpandedQuery:
    """
    Result of query expansion.

    Attributes:
        original: The original query string (not included in expansions list).
        expansions: Alternative phrasings to run alongside the original query.
    """
    original: str
    expansions: List[str]
Result of query expansion.
Attributes:
- original: The original query string (not included in expansions list).
- expansions: Alternative phrasings to run alongside the original query.
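Since `original` is kept separate from `expansions`, a caller has to combine the two before issuing retriever calls. One possible way is sketched below; `queries_to_run` is a hypothetical helper, and the case-insensitive de-duplication is an added precaution, since the prompt only asks the LLM not to repeat the original query.

```python
from typing import List


def queries_to_run(original: str, expansions: List[str]) -> List[str]:
    """Original query first, then expansions, skipping blanks and case-insensitive repeats."""
    seen = set()
    ordered: List[str] = []
    for candidate in [original, *expansions]:
        key = candidate.strip().lower()
        if key and key not in seen:
            seen.add(key)
            ordered.append(candidate.strip())
    return ordered


batch = queries_to_run(
    "antitrust violations",
    ["competition law breaches", "Antitrust Violations", ""],
)
# batch == ["antitrust violations", "competition law breaches"]
```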
31class QueryRewriter: 32 """ 33 Improves query quality before retrieval using LLM. 34 35 Handles: 36 - Grammar and spelling fixes 37 - Replacement of vague terms with specific, domain-appropriate ones 38 - Removal of conversational filler ("tell me about", "can you find") 39 - Clarification of ambiguous intent using optional domain context 40 41 Example: 42 ```python 43 from gmf_forge_ai_data.query import QueryRewriter 44 45 rewriter = QueryRewriter(llm_gateway) 46 47 result = await rewriter.rewrite( 48 "tell me the stuff about that apple patent thing", 49 context="legal documents database" 50 ) 51 # result.rewritten = "Apple Inc. patent infringement case details" 52 # result.changes = ["LLM rewrote: '...' → '...'"] 53 ``` 54 """ 55 56 _REWRITE_PROMPT = ( 57 "You are a search query optimization assistant for a document retrieval system.\n\n" 58 "Rewrite the following query to make it more precise and effective for retrieval:\n" 59 "- Fix grammar and spelling errors\n" 60 "- Replace vague or colloquial terms with specific, domain-appropriate ones\n" 61 "- Remove conversational filler (e.g., 'tell me about', 'can you find')\n" 62 "- Preserve the original semantic intent\n" 63 "- Return ONLY the rewritten query — no explanation, no extra text\n\n" 64 "{context_line}" 65 "Query: {query}\n\n" 66 "Rewritten query:" 67 ) 68 69 def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0): 70 """ 71 Args: 72 llm_gateway: LLM gateway for intelligent query rewriting. 73 temperature: Sampling temperature passed to the LLM (default 0.0 for 74 deterministic rewrites). Keep low — rewriting should 75 produce consistent, reproducible output. 76 """ 77 self.llm_gateway = llm_gateway 78 self.temperature = temperature 79 80 async def rewrite( 81 self, 82 query: str, 83 context: Optional[str] = None, 84 ) -> RewrittenQuery: 85 """ 86 Rewrite a query for better retrieval using LLM. 87 88 Args: 89 query: The original query to improve. 
90 context: Optional domain hint passed to the LLM 91 (e.g., "legal documents", "financial filings database"). 92 93 Returns: 94 RewrittenQuery with improved text and list of changes made. 95 """ 96 context_line = f"Domain context: {context}\n\n" if context else "" 97 prompt = self._REWRITE_PROMPT.format( 98 query=query, 99 context_line=context_line, 100 ) 101 102 response = await self.llm_gateway.complete( 103 prompt=prompt, 104 temperature=self.temperature, 105 max_tokens=150, 106 ) 107 108 rewritten = response.content.strip().strip('"').strip("'") 109 110 if not rewritten or rewritten.lower() == query.lower(): 111 return RewrittenQuery( 112 original=query, 113 rewritten=query, 114 changes=["No rewrite needed"], 115 ) 116 117 return RewrittenQuery( 118 original=query, 119 rewritten=rewritten, 120 changes=[f"LLM rewrote: '{query}' → '{rewritten}'"], 121 )
Improves query quality before retrieval using an LLM.
Handles:
- Grammar and spelling fixes
- Replacement of vague terms with specific, domain-appropriate ones
- Removal of conversational filler ("tell me about", "can you find")
- Clarification of ambiguous intent using optional domain context
Example:
from gmf_forge_ai_data.query import QueryRewriter

rewriter = QueryRewriter(llm_gateway)

result = await rewriter.rewrite(
    "tell me the stuff about that apple patent thing",
    context="legal documents database"
)
# result.rewritten = "Apple Inc. patent infringement case details"
# result.changes = ["LLM rewrote: '...' → '...'"]
def __init__(self, llm_gateway: UnifiedLLMGateway, temperature: float = 0.0):
    """
    Args:
        llm_gateway: LLM gateway for intelligent query rewriting.
        temperature: Sampling temperature passed to the LLM (default 0.0 for
            deterministic rewrites). Keep low: rewriting should
            produce consistent, reproducible output.
    """
    self.llm_gateway = llm_gateway
    self.temperature = temperature
Args:
- llm_gateway: LLM gateway for intelligent query rewriting.
- temperature: Sampling temperature passed to the LLM (default 0.0 for deterministic rewrites). Keep low: rewriting should produce consistent, reproducible output.
async def rewrite(
    self,
    query: str,
    context: Optional[str] = None,
) -> RewrittenQuery:
    """
    Rewrite a query for better retrieval using an LLM.

    Args:
        query: The original query to improve.
        context: Optional domain hint passed to the LLM
            (e.g., "legal documents", "financial filings database").

    Returns:
        RewrittenQuery with improved text and list of changes made.
    """
    context_line = f"Domain context: {context}\n\n" if context else ""
    prompt = self._REWRITE_PROMPT.format(
        query=query,
        context_line=context_line,
    )

    response = await self.llm_gateway.complete(
        prompt=prompt,
        temperature=self.temperature,
        max_tokens=150,
    )

    # Strip whitespace and surrounding quotes the LLM may add around the query.
    rewritten = response.content.strip().strip('"').strip("'")

    # An empty reply or a case-insensitive match means nothing changed.
    if not rewritten or rewritten.lower() == query.lower():
        return RewrittenQuery(
            original=query,
            rewritten=query,
            changes=["No rewrite needed"],
        )

    return RewrittenQuery(
        original=query,
        rewritten=rewritten,
        changes=[f"LLM rewrote: '{query}' → '{rewritten}'"],
    )
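Rewrite a query for better retrieval using an LLM.
Args:
- query: The original query to improve.
- context: Optional domain hint passed to the LLM (e.g., "legal documents", "financial filings database").
Returns: RewrittenQuery with the improved text and a changes list. If the LLM reply is empty or matches the original query case-insensitively, rewritten is the original query and changes is ["No rewrite needed"].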
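The post-processing that rewrite() applies to the LLM reply can be exercised without a gateway. The sketch below mirrors that logic as a plain function; `build_rewrite_result` is a hypothetical name, the dataclass is a local copy of the documented fields, and the sample replies are made up.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RewrittenQuery:
    """Local copy of the documented RewrittenQuery fields."""
    original: str
    rewritten: str
    changes: List[str] = field(default_factory=list)


def build_rewrite_result(query: str, llm_output: str) -> RewrittenQuery:
    """Clean the raw LLM reply and detect the no-op case, as rewrite() does."""
    rewritten = llm_output.strip().strip('"').strip("'")
    if not rewritten or rewritten.lower() == query.lower():
        return RewrittenQuery(original=query, rewritten=query, changes=["No rewrite needed"])
    return RewrittenQuery(
        original=query,
        rewritten=rewritten,
        changes=[f"LLM rewrote: '{query}' → '{rewritten}'"],
    )


# The reply differs from the query only by quotes, case, and a trailing newline: a no-op.
unchanged = build_rewrite_result("Apple patent cases", '"apple patent cases"\n')
# A genuinely different reply produces a one-line before/after summary.
changed = build_rewrite_result("that apple patent thing", "Apple Inc. patent infringement cases")
```

In both branches `changes` holds exactly one entry, so it is better read as a summary of the rewrite than as an itemised list of edits.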
@dataclass
class RewrittenQuery:
    """
    Result of query rewriting.

    Attributes:
        original: The original query string before rewriting.
        rewritten: The improved query string after rewriting.
        changes: Human-readable list of transformations applied.
    """
    original: str
    rewritten: str
    changes: List[str] = field(default_factory=list)
Result of query rewriting.
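Attributes:
- original: The original query string before rewriting.
- rewritten: The improved query string after rewriting.
- changes: Human-readable list of transformations applied. QueryRewriter.rewrite() emits a single entry: either "No rewrite needed" or a one-line before/after summary.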
41class HyDEGenerator: 42 """ 43 Hypothetical Document Embeddings (HyDE) generator. 44 45 Why this works: 46 --------------- 47 Short query strings ("antitrust cases 2024") and full answer passages live 48 in very different regions of an embedding space. A hypothetical passage that 49 ANSWERS the query occupies the same region as real answer documents, so 50 cosine similarity between the HyDE embedding and indexed document embeddings 51 is substantially higher than query-vs-document similarity. 52 53 Usage pattern: 54 -------------- 55 1. Call generate_and_embed(query) → HypotheticalDocument (with embedding set). 56 2. Feed the embedding into VectorRetriever via RetrievalQuery(embedding=...). 57 3. Compare results against standard VectorRetriever on the same query. 58 59 Example: 60 ```python 61 from gmf_forge_ai_data.query import HyDEGenerator 62 from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery 63 64 hyde = HyDEGenerator(llm_gateway=gateway, embedder=embedder) 65 66 # Generate hypothetical doc and embed it 67 hypo = await hyde.generate_and_embed( 68 "What are the penalties for antitrust violations?", 69 domain="legal documents" 70 ) 71 72 # Use HyDE embedding for retrieval 73 query = RetrievalQuery(embedding=hypo.embedding, top_k=5) 74 results = vector_retriever.retrieve(query) 75 ``` 76 """ 77 78 _HYDE_PROMPT = ( 79 "Write a concise, authoritative passage that directly answers the question below.\n" 80 "Write it as if it were an excerpt from a reference document or knowledge base.\n" 81 "{domain_line}" 82 "Keep the passage under 150 words. " 83 "Do not include meta-commentary or mention that this is hypothetical.\n\n" 84 "Question: {query}\n\n" 85 "Passage:" 86 ) 87 88 def __init__( 89 self, 90 llm_gateway: UnifiedLLMGateway, 91 embedder: Optional[EmbeddingProvider] = None, 92 ): 93 """ 94 Initialize the HyDE generator. 95 96 Args: 97 llm_gateway: LLM gateway used to generate the hypothetical document. 
98 embedder: Embedding provider used to vectorize the hypothetical doc. 99 Required only for generate_and_embed(); optional for generate(). 100 """ 101 self.llm_gateway = llm_gateway 102 self.embedder = embedder 103 104 async def generate( 105 self, 106 query: str, 107 domain: Optional[str] = None, 108 ) -> HypotheticalDocument: 109 """ 110 Generate a hypothetical document that would answer the query. 111 112 The returned HypotheticalDocument has embedding=None. Call 113 generate_and_embed() to also produce a vector in one step. 114 115 Args: 116 query: The retrieval query to generate a passage for. 117 domain: Optional domain hint to guide the LLM style 118 (e.g., "legal documents", "financial reports", "AI/ML knowledge base"). 119 120 Returns: 121 HypotheticalDocument with hypothetical_doc text, embedding=None. 122 """ 123 domain_line = f"Domain: {domain}\n" if domain else "" 124 prompt = self._HYDE_PROMPT.format(query=query, domain_line=domain_line) 125 126 response = await self.llm_gateway.complete( 127 prompt=prompt, 128 temperature=0.5, 129 max_tokens=200, 130 ) 131 132 return HypotheticalDocument( 133 query=query, 134 hypothetical_doc=response.content.strip(), 135 domain=domain, 136 ) 137 138 async def generate_and_embed( 139 self, 140 query: str, 141 domain: Optional[str] = None, 142 ) -> HypotheticalDocument: 143 """ 144 Generate a hypothetical document and embed it in a single step. 145 146 Calls generate() then uses the configured embedder to vectorize the 147 resulting passage. The embedding can be passed directly to VectorRetriever 148 via RetrievalQuery(embedding=result.embedding, ...). 149 150 Args: 151 query: The retrieval query. 152 domain: Optional domain hint for generation style. 153 154 Returns: 155 HypotheticalDocument with both hypothetical_doc and embedding populated. 156 157 Raises: 158 ValueError: If no embedder was provided at construction time. 
159 """ 160 if not self.embedder: 161 raise ValueError( 162 "An EmbeddingProvider is required for generate_and_embed(). " 163 "Pass embedder= to HyDEGenerator.__init__()." 164 ) 165 166 result = await self.generate(query, domain) 167 result.embedding = self.embedder.embed_text(result.hypothetical_doc) 168 return result
Hypothetical Document Embeddings (HyDE) generator.
Why this works:
Short query strings ("antitrust cases 2024") and full answer passages tend to fall in different regions of an embedding space. A hypothetical passage that answers the query lies closer to real answer documents, so the cosine similarity between the HyDE embedding and indexed document embeddings is typically higher than query-versus-document similarity.
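The claim above is stated in terms of cosine similarity. For reference, a plain-Python version of that metric is sketched here; `cosine_similarity` is a hypothetical helper, not a function exported by the package, and it makes no assumption about which embedding model produced the vectors.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors (1.0 means same direction)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)
```

To check the effect on real data, one could compare `cosine_similarity(query_embedding, doc_embedding)` with `cosine_similarity(hypo.embedding, doc_embedding)` for documents known to be relevant.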
Usage pattern:
1. Call generate_and_embed(query) → HypotheticalDocument (with embedding set).
2. Feed the embedding into VectorRetriever via RetrievalQuery(embedding=...).
3. Compare results against standard VectorRetriever on the same query.
Example:
from gmf_forge_ai_data.query import HyDEGenerator
from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery

hyde = HyDEGenerator(llm_gateway=gateway, embedder=embedder)

# Generate hypothetical doc and embed it
hypo = await hyde.generate_and_embed(
    "What are the penalties for antitrust violations?",
    domain="legal documents"
)

# Use HyDE embedding for retrieval
query = RetrievalQuery(embedding=hypo.embedding, top_k=5)
results = vector_retriever.retrieve(query)
def __init__(
    self,
    llm_gateway: UnifiedLLMGateway,
    embedder: Optional[EmbeddingProvider] = None,
):
    """
    Initialize the HyDE generator.

    Args:
        llm_gateway: LLM gateway used to generate the hypothetical document.
        embedder: Embedding provider used to vectorize the hypothetical doc.
            Required only for generate_and_embed(); optional for generate().
    """
    self.llm_gateway = llm_gateway
    self.embedder = embedder
Initialize the HyDE generator.
Args:
- llm_gateway: LLM gateway used to generate the hypothetical document.
- embedder: Embedding provider used to vectorize the hypothetical doc. Required only for generate_and_embed(); optional for generate().
async def generate(
    self,
    query: str,
    domain: Optional[str] = None,
) -> HypotheticalDocument:
    """
    Generate a hypothetical document that would answer the query.

    The returned HypotheticalDocument has embedding=None. Call
    generate_and_embed() to also produce a vector in one step.

    Args:
        query: The retrieval query to generate a passage for.
        domain: Optional domain hint to guide the LLM style
            (e.g., "legal documents", "financial reports", "AI/ML knowledge base").

    Returns:
        HypotheticalDocument with hypothetical_doc text, embedding=None.
    """
    domain_line = f"Domain: {domain}\n" if domain else ""
    prompt = self._HYDE_PROMPT.format(query=query, domain_line=domain_line)

    response = await self.llm_gateway.complete(
        prompt=prompt,
        temperature=0.5,
        max_tokens=200,
    )

    return HypotheticalDocument(
        query=query,
        hypothetical_doc=response.content.strip(),
        domain=domain,
    )
Generate a hypothetical document that would answer the query.
The returned HypotheticalDocument has embedding=None. Call generate_and_embed() to also produce a vector in one step.
Args:
- query: The retrieval query to generate a passage for.
- domain: Optional domain hint to guide the LLM style (e.g., "legal documents", "financial reports", "AI/ML knowledge base").
Returns: HypotheticalDocument with hypothetical_doc text, embedding=None.
async def generate_and_embed(
    self,
    query: str,
    domain: Optional[str] = None,
) -> HypotheticalDocument:
    """
    Generate a hypothetical document and embed it in a single step.

    Calls generate() then uses the configured embedder to vectorize the
    resulting passage. The embedding can be passed directly to VectorRetriever
    via RetrievalQuery(embedding=result.embedding, ...).

    Args:
        query: The retrieval query.
        domain: Optional domain hint for generation style.

    Returns:
        HypotheticalDocument with both hypothetical_doc and embedding populated.

    Raises:
        ValueError: If no embedder was provided at construction time.
    """
    if not self.embedder:
        raise ValueError(
            "An EmbeddingProvider is required for generate_and_embed(). "
            "Pass embedder= to HyDEGenerator.__init__()."
        )

    result = await self.generate(query, domain)
    result.embedding = self.embedder.embed_text(result.hypothetical_doc)
    return result
Generate a hypothetical document and embed it in a single step.
Calls generate() then uses the configured embedder to vectorize the resulting passage. The embedding can be passed directly to VectorRetriever via RetrievalQuery(embedding=result.embedding, ...).
Args:
- query: The retrieval query.
- domain: Optional domain hint for generation style.
Returns: HypotheticalDocument with both hypothetical_doc and embedding populated.
Raises: ValueError: If no embedder was provided at construction time.
@dataclass
class HypotheticalDocument:
    """
    Result of HyDE generation.

    Attributes:
        query: The original retrieval query.
        hypothetical_doc: LLM-generated passage that would answer the query.
        embedding: Vector embedding of hypothetical_doc (None until embedded).
        domain: Optional domain hint that was passed during generation.
    """
    query: str
    hypothetical_doc: str
    embedding: Optional[List[float]] = None
    domain: Optional[str] = None
Result of HyDE generation.
Attributes:
- query: The original retrieval query.
- hypothetical_doc: LLM-generated passage that would answer the query.
- embedding: Vector embedding of hypothetical_doc (None until embedded).
- domain: Optional domain hint that was passed during generation.
89class DocumentIntelligenceLayout: 90 """ 91 Wraps the Azure Document Intelligence ``prebuilt-layout`` model. 92 93 Converts PDF, DOCX, PPTX, XLSX, and image files into structured markdown 94 that preserves headings, tables, lists, and page structure. The markdown 95 output is designed to be chunked with ``MarkdownChunker``. 96 97 Typical usage:: 98 99 from gmf_forge_ai_data.layout import DocumentIntelligenceLayout 100 from gmf_forge_ai_data.chunkers import MarkdownChunker 101 102 # Managed identity (omit api_key) — required for Multiservices accounts 103 layout = DocumentIntelligenceLayout( 104 endpoint="https://my-resource.cognitiveservices.azure.com", 105 ) 106 107 # API key — for standalone Document Intelligence resources 108 layout = DocumentIntelligenceLayout( 109 endpoint="https://my-resource.cognitiveservices.azure.com", 110 api_key="your-api-key", 111 ) 112 113 # From a local file 114 result = layout.analyze_file("annual_report.pdf") 115 116 # From raw bytes (e.g. downloaded from BlobStorageConnector) 117 result = layout.analyze_bytes(blob_content, filename="report.pdf") 118 119 # From a URL 120 result = layout.analyze_url("https://example.com/policy.pdf") 121 122 # Chunk the markdown directly 123 chunker = MarkdownChunker(max_chunk_size=1500, min_header_level=2) 124 chunks = chunker.chunk(result.markdown, metadata=result.metadata) 125 126 Args: 127 endpoint: Azure Document Intelligence service endpoint URL. 128 api_key: Azure Document Intelligence API key. When provided, 129 ``AzureKeyCredential`` is used. When omitted (``None``), 130 ``DefaultAzureCredential`` is used instead — required for 131 Cognitive Services Multiservices accounts that do not 132 expose API keys. 133 model_id: Model to use (default: ``prebuilt-layout``). 134 api_version: REST API version (default: ``2024-11-30``). 135 logger: Optional ``BasicLogger`` instance for structured logging. 
136 137 Raises: 138 ImportError: If ``azure-ai-documentintelligence`` is not installed, or 139 if ``api_key`` is omitted and ``azure-identity`` is not 140 installed. 141 ValueError: If ``endpoint`` is empty. 142 """ 143 144 _DEFAULT_MODEL = "prebuilt-layout" 145 _DEFAULT_API_VERSION = "2024-11-30" 146 147 def __init__( 148 self, 149 endpoint: str, 150 api_key: Optional[str] = None, 151 model_id: str = _DEFAULT_MODEL, 152 api_version: str = _DEFAULT_API_VERSION, 153 logger: Optional[BasicLogger] = None, 154 ) -> None: 155 if not _SDK_AVAILABLE: 156 raise ImportError( 157 "azure-ai-documentintelligence is required: " 158 "pip install azure-ai-documentintelligence" 159 ) 160 if not endpoint or not endpoint.strip(): 161 raise ValueError("endpoint must not be empty") 162 163 self.endpoint = endpoint.rstrip("/") 164 self.model_id = model_id 165 self.api_version = api_version 166 self.logger = logger or _logger 167 168 if api_key: 169 # Standalone Document Intelligence resource with key-based auth enabled 170 credential = AzureKeyCredential(api_key) 171 auth_method = "api_key" 172 else: 173 # Cognitive Services Multiservices account — no API key exposed. 
174 # DefaultAzureCredential resolves the auth chain automatically: 175 # In Azure (AKS, VM, App Service): managed identity / workload identity 176 # Locally: az login or VS Code Azure account extension 177 if not _IDENTITY_AVAILABLE: 178 raise ImportError( 179 "azure-identity is required when api_key is not provided: " 180 "pip install azure-identity" 181 ) 182 credential = DefaultAzureCredential() 183 auth_method = "managed_identity" 184 185 self._client = DocumentIntelligenceClient( 186 endpoint=self.endpoint, 187 credential=credential, 188 api_version=self.api_version, 189 ) 190 191 self.logger.info( 192 "DocumentIntelligenceLayout initialised", 193 endpoint=self.endpoint, 194 model_id=self.model_id, 195 api_version=self.api_version, 196 auth=auth_method, 197 ) 198 199 # ------------------------------------------------------------------ 200 # Public API 201 # ------------------------------------------------------------------ 202 203 def analyze_file(self, file_path: str | Path) -> LayoutResult: 204 """ 205 Analyse a local file and return its content as markdown. 206 207 Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF. 208 209 Args: 210 file_path: Path to the local document file. 211 212 Returns: 213 :class:`LayoutResult` with markdown content and metadata. 214 215 Raises: 216 FileNotFoundError: If the file does not exist. 217 ValueError: If the file is empty. 
218 """ 219 path = Path(file_path) 220 if not path.exists(): 221 raise FileNotFoundError(f"File not found: {file_path}") 222 if path.stat().st_size == 0: 223 raise ValueError(f"File is empty: {file_path}") 224 225 self.logger.info("Analysing file", file=str(path), model_id=self.model_id) 226 227 with open(path, "rb") as fh: 228 content = fh.read() 229 230 result = self._analyze_bytes_content(content) 231 result.metadata.update({ 232 "source": str(path.resolve()), 233 "file_name": path.name, 234 }) 235 return result 236 237 def analyze_bytes(self, content: bytes, filename: str = "") -> LayoutResult: 238 """ 239 Analyse raw bytes and return content as markdown. 240 241 Use this when you already have the document bytes in memory — for 242 example, content downloaded via ``BlobStorageConnector`` or 243 ``SharePointConnector``. 244 245 Args: 246 content: Raw document bytes. 247 filename: Original filename (used for metadata only, e.g. "report.pdf"). 248 249 Returns: 250 :class:`LayoutResult` with markdown content and metadata. 251 252 Raises: 253 ValueError: If ``content`` is empty. 254 """ 255 if not content: 256 raise ValueError("content must not be empty") 257 258 content_hash = hashlib.sha256(content).hexdigest()[:12] 259 source = f"bytes:{content_hash}" 260 261 self.logger.info( 262 "Analysing bytes", 263 size_bytes=len(content), 264 filename=filename or "(unnamed)", 265 model_id=self.model_id, 266 ) 267 268 result = self._analyze_bytes_content(content) 269 result.metadata.update({ 270 "source": source, 271 "file_name": filename, 272 }) 273 return result 274 275 def analyze_url(self, url: str) -> LayoutResult: 276 """ 277 Analyse a document at a publicly accessible URL. 278 279 The Azure Document Intelligence service fetches the document directly 280 from the URL — the bytes never pass through your application. 281 282 Args: 283 url: Publicly accessible URL pointing to a supported document. 
        Returns:
            :class:`LayoutResult` with markdown content and metadata.

        Raises:
            ValueError: If ``url`` is empty.
        """
        if not url or not url.strip():
            raise ValueError("url must not be empty")

        self.logger.info("Analysing URL", url=url, model_id=self.model_id)

        poller = self._client.begin_analyze_document(
            self.model_id,
            AnalyzeDocumentRequest(url_source=url),
            output_content_format=DocumentContentFormat.MARKDOWN,
            features=[DocumentAnalysisFeature.QUERY_FIELDS],
        )
        response = poller.result()

        markdown = response.content or ""
        page_count = len(response.pages) if response.pages else 0

        self.logger.info(
            "URL analysis complete",
            url=url,
            page_count=page_count,
            markdown_length=len(markdown),
        )

        return LayoutResult(
            markdown=markdown,
            page_count=page_count,
            metadata={
                "source": url,
                "file_name": url.split("/")[-1],
                "model_id": self.model_id,
                "analyzed_at": datetime.now(timezone.utc).isoformat(),
            },
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _analyze_bytes_content(self, content: bytes) -> LayoutResult:
        """Send raw bytes to the Document Intelligence service and return LayoutResult."""
        poller = self._client.begin_analyze_document(
            self.model_id,
            AnalyzeDocumentRequest(bytes_source=content),
            output_content_format=DocumentContentFormat.MARKDOWN,
            features=[DocumentAnalysisFeature.QUERY_FIELDS],
        )
        response = poller.result()

        markdown = response.content or ""
        page_count = len(response.pages) if response.pages else 0

        self.logger.info(
            "Analysis complete",
            model_id=self.model_id,
            page_count=page_count,
            markdown_length=len(markdown),
        )

        return LayoutResult(
            markdown=markdown,
            page_count=page_count,
            metadata={
                "model_id": self.model_id,
                "analyzed_at": datetime.now(timezone.utc).isoformat(),
            },
        )
Wraps the Azure Document Intelligence prebuilt-layout model.
Converts PDF, DOCX, PPTX, XLSX, and image files into structured markdown
that preserves headings, tables, lists, and page structure. The markdown
output is designed to be chunked with MarkdownChunker.
Typical usage::

    from gmf_forge_ai_data.layout import DocumentIntelligenceLayout
    from gmf_forge_ai_data.chunkers import MarkdownChunker

    # Managed identity (omit api_key): required for Multiservices accounts
    layout = DocumentIntelligenceLayout(
        endpoint="https://my-resource.cognitiveservices.azure.com",
    )

    # API key: for standalone Document Intelligence resources
    layout = DocumentIntelligenceLayout(
        endpoint="https://my-resource.cognitiveservices.azure.com",
        api_key="your-api-key",
    )

    # From a local file
    result = layout.analyze_file("annual_report.pdf")

    # From raw bytes (e.g. downloaded from BlobStorageConnector)
    result = layout.analyze_bytes(blob_content, filename="report.pdf")

    # From a URL
    result = layout.analyze_url("https://example.com/policy.pdf")

    # Chunk the markdown directly
    chunker = MarkdownChunker(max_chunk_size=1500, min_header_level=2)
    chunks = chunker.chunk(result.markdown, metadata=result.metadata)

Args:
endpoint: Azure Document Intelligence service endpoint URL.
api_key: Azure Document Intelligence API key. When provided,
AzureKeyCredential is used. When omitted (None),
DefaultAzureCredential is used instead — required for
Cognitive Services Multiservices accounts that do not
expose API keys.
model_id: Model to use (default: prebuilt-layout).
api_version: REST API version (default: 2024-11-30).
logger: Optional BasicLogger instance for structured logging.
Raises:
ImportError: If azure-ai-documentintelligence is not installed, or
if api_key is omitted and azure-identity is not
installed.
ValueError: If endpoint is empty.
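The credential choice described under Args comes down to a truthiness check on api_key. The following is a minimal sketch of that decision in isolation, without the Azure SDK. `select_auth_method` is a hypothetical helper name, not part of the package; the returned labels mirror the `auth` value that the constructor logs.

```python
from typing import Optional


def select_auth_method(api_key: Optional[str] = None) -> str:
    """Return the auth label implied by ``api_key`` (hypothetical helper).

    Uses a plain truthiness check, so both None and "" fall through to the
    DefaultAzureCredential path.
    """
    if api_key:
        return "api_key"
    return "managed_identity"


if __name__ == "__main__":
    for candidate in ("your-api-key", None, ""):
        print(repr(candidate), "->", select_auth_method(candidate))
```

One consequence worth noting: an empty string behaves like an omitted key, so a blank environment variable would select identity-based authentication rather than raise an error.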
    def __init__(
        self,
        endpoint: str,
        api_key: Optional[str] = None,
        model_id: str = _DEFAULT_MODEL,
        api_version: str = _DEFAULT_API_VERSION,
        logger: Optional[BasicLogger] = None,
    ) -> None:
        if not _SDK_AVAILABLE:
            raise ImportError(
                "azure-ai-documentintelligence is required: "
                "pip install azure-ai-documentintelligence"
            )
        if not endpoint or not endpoint.strip():
            raise ValueError("endpoint must not be empty")

        self.endpoint = endpoint.rstrip("/")
        self.model_id = model_id
        self.api_version = api_version
        self.logger = logger or _logger

        if api_key:
            # Standalone Document Intelligence resource with key-based auth enabled
            credential = AzureKeyCredential(api_key)
            auth_method = "api_key"
        else:
            # Cognitive Services Multiservices account: no API key exposed.
            # DefaultAzureCredential resolves the auth chain automatically:
            #   In Azure (AKS, VM, App Service): managed identity / workload identity
            #   Locally: az login or VS Code Azure account extension
            if not _IDENTITY_AVAILABLE:
                raise ImportError(
                    "azure-identity is required when api_key is not provided: "
                    "pip install azure-identity"
                )
            credential = DefaultAzureCredential()
            auth_method = "managed_identity"

        self._client = DocumentIntelligenceClient(
            endpoint=self.endpoint,
            credential=credential,
            api_version=self.api_version,
        )

        self.logger.info(
            "DocumentIntelligenceLayout initialised",
            endpoint=self.endpoint,
            model_id=self.model_id,
            api_version=self.api_version,
            auth=auth_method,
        )

    def analyze_file(self, file_path: str | Path) -> LayoutResult:
        """
        Analyse a local file and return its content as markdown.

        Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF.

        Args:
            file_path: Path to the local document file.

        Returns:
            :class:`LayoutResult` with markdown content and metadata.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is empty.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        if path.stat().st_size == 0:
            raise ValueError(f"File is empty: {file_path}")

        self.logger.info("Analysing file", file=str(path), model_id=self.model_id)

        with open(path, "rb") as fh:
            content = fh.read()

        result = self._analyze_bytes_content(content)
        result.metadata.update({
            "source": str(path.resolve()),
            "file_name": path.name,
        })
        return result

Analyse a local file and return its content as markdown.
Supported file types: PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF.
Args:
file_path: Path to the local document file.
Returns:
LayoutResult with markdown content and metadata.
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file is empty.
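Because analyze_file only checks that the file exists and is non-empty, a caller may want to screen paths by extension before sending them to the service. The sketch below is one possible pre-check and is not part of the package: `SUPPORTED_EXTENSIONS` and `is_supported_document` are hypothetical names, and the extension spellings chosen for each documented file type are assumptions.

```python
from pathlib import Path

# Assumed extension spellings for the documented file types
# (PDF, DOCX, PPTX, XLSX, JPEG, PNG, BMP, TIFF, HEIF); adjust as needed.
SUPPORTED_EXTENSIONS = {
    ".pdf", ".docx", ".pptx", ".xlsx",
    ".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".heif",
}


def is_supported_document(file_path) -> bool:
    """Return True if the file extension matches a documented file type.

    Only the suffix is inspected; the file is not opened or validated.
    """
    return Path(file_path).suffix.lower() in SUPPORTED_EXTENSIONS


if __name__ == "__main__":
    candidates = ["annual_report.pdf", "notes.txt", "scan.TIFF"]
    print([name for name in candidates if is_supported_document(name)])
```

A suffix check is cheap but shallow: a mislabelled file still reaches the service, which remains the authority on what it can parse.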
    def analyze_bytes(self, content: bytes, filename: str = "") -> LayoutResult:
        """
        Analyse raw bytes and return content as markdown.

        Use this when you already have the document bytes in memory, for
        example, content downloaded via ``BlobStorageConnector`` or
        ``SharePointConnector``.

        Args:
            content: Raw document bytes.
            filename: Original filename (used for metadata only, e.g. "report.pdf").

        Returns:
            :class:`LayoutResult` with markdown content and metadata.

        Raises:
            ValueError: If ``content`` is empty.
        """
        if not content:
            raise ValueError("content must not be empty")

        content_hash = hashlib.sha256(content).hexdigest()[:12]
        source = f"bytes:{content_hash}"

        self.logger.info(
            "Analysing bytes",
            size_bytes=len(content),
            filename=filename or "(unnamed)",
            model_id=self.model_id,
        )

        result = self._analyze_bytes_content(content)
        result.metadata.update({
            "source": source,
            "file_name": filename,
        })
        return result

Analyse raw bytes and return content as markdown.
Use this when you already have the document bytes in memory — for
example, content downloaded via BlobStorageConnector or
SharePointConnector.
Args:
content: Raw document bytes.
filename: Original filename (used for metadata only, e.g. "report.pdf").
Returns:
LayoutResult with markdown content and metadata.
Raises:
ValueError: If content is empty.
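The analyze_bytes listing above labels in-memory documents with a truncated SHA-256 digest, which it stores as the `source` metadata value. The following standalone sketch reproduces only that labelling step, so the identifier can be computed without calling the service. `bytes_source_id` is a hypothetical name introduced here for illustration.

```python
import hashlib


def bytes_source_id(content: bytes) -> str:
    """Build a "bytes:<hash>" label from the first 12 hex digits of SHA-256.

    Hypothetical helper that mirrors the labelling step in the listing above.
    """
    if not content:
        raise ValueError("content must not be empty")
    return f"bytes:{hashlib.sha256(content).hexdigest()[:12]}"


if __name__ == "__main__":
    print(bytes_source_id(b"hello"))
```

Since the label depends only on the bytes, identical content always yields the same `source` value, which could serve as a simple deduplication key. Twelve hex digits keep the label short at the cost of a small collision risk across very large corpora.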
    def analyze_url(self, url: str) -> LayoutResult:
        """
        Analyse a document at a publicly accessible URL.

        The Azure Document Intelligence service fetches the document directly
        from the URL; the bytes never pass through your application.

        Args:
            url: Publicly accessible URL pointing to a supported document.

        Returns:
            :class:`LayoutResult` with markdown content and metadata.

        Raises:
            ValueError: If ``url`` is empty.
        """
        if not url or not url.strip():
            raise ValueError("url must not be empty")

        self.logger.info("Analysing URL", url=url, model_id=self.model_id)

        poller = self._client.begin_analyze_document(
            self.model_id,
            AnalyzeDocumentRequest(url_source=url),
            output_content_format=DocumentContentFormat.MARKDOWN,
            features=[DocumentAnalysisFeature.QUERY_FIELDS],
        )
        response = poller.result()

        markdown = response.content or ""
        page_count = len(response.pages) if response.pages else 0

        self.logger.info(
            "URL analysis complete",
            url=url,
            page_count=page_count,
            markdown_length=len(markdown),
        )

        return LayoutResult(
            markdown=markdown,
            page_count=page_count,
            metadata={
                "source": url,
                "file_name": url.split("/")[-1],
                "model_id": self.model_id,
                "analyzed_at": datetime.now(timezone.utc).isoformat(),
            },
        )

Analyse a document at a publicly accessible URL.
The Azure Document Intelligence service fetches the document directly from the URL — the bytes never pass through your application.
Args:
url: Publicly accessible URL pointing to a supported document.
Returns:
LayoutResult with markdown content and metadata.
Raises:
ValueError: If url is empty.
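The analyze_url listing above derives `file_name` by taking the text after the last slash, so any query string or fragment stays attached to the name. If a cleaner basename is wanted, one hedged alternative is to parse the URL first. The sketch below uses only the standard library and is not part of the package; `file_name_from_url` is a hypothetical helper name.

```python
from pathlib import PurePosixPath
from urllib.parse import unquote, urlsplit


def file_name_from_url(url: str) -> str:
    """Return the decoded basename of the URL path, ignoring query and fragment.

    Hypothetical helper; returns "" when the path has no final component.
    """
    path = urlsplit(url).path
    return unquote(PurePosixPath(path).name)


if __name__ == "__main__":
    url = "https://example.com/docs/policy.pdf?version=2"
    print(url.split("/")[-1])       # naive split keeps the query string
    print(file_name_from_url(url))  # parsed path drops it
```

This matters most for signed URLs, where the query string carries an access token that would otherwise be copied into chunk metadata.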
@dataclass
class LayoutResult:
    """
    Result of a Document Intelligence layout analysis.

    Attributes:
        markdown: Full document content as markdown. Headers, tables, lists,
                  page breaks (``<!-- PageBreak -->``) and figure captions
                  (``<!-- FigureCaption -->``) are preserved exactly as
                  produced by Azure Document Intelligence. Pass directly to
                  ``MarkdownChunker`` for header-aware chunking.
        page_count: Number of pages in the analysed document.
        metadata: Source information and analysis details:
                  ``source`` - file path, URL, or "bytes:<hash>"
                  ``file_name`` - basename of the source file (if known)
                  ``model_id`` - Document Intelligence model used
                  ``analyzed_at`` - ISO-8601 UTC timestamp of the analysis
    """
    markdown: str
    page_count: int
    metadata: Dict[str, Any] = field(default_factory=dict)

Result of a Document Intelligence layout analysis.
Attributes:
markdown: Full document content as markdown. Headers, tables, lists,
page breaks (<!-- PageBreak -->) and figure captions
(<!-- FigureCaption -->) are preserved exactly as
produced by Azure Document Intelligence. Pass directly to
MarkdownChunker for header-aware chunking.
page_count: Number of pages in the analysed document.
metadata: Source information and analysis details:
source - file path, URL, or "bytes:<hash>"
file_name - basename of the source file (if known)
model_id - Document Intelligence model used
analyzed_at - ISO-8601 UTC timestamp of the analysis
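Because page breaks are kept in the markdown as `<!-- PageBreak -->` comments, the text can be split back into per-page segments when page-level processing is needed. The following is a minimal sketch under that assumption and is not part of the package; `PAGE_BREAK` and `split_pages` are hypothetical names.

```python
PAGE_BREAK = "<!-- PageBreak -->"


def split_pages(markdown: str) -> list[str]:
    """Split markdown on page-break markers, dropping empty segments.

    Hypothetical helper; each returned page is stripped of surrounding
    whitespace.
    """
    return [page.strip() for page in markdown.split(PAGE_BREAK) if page.strip()]


if __name__ == "__main__":
    sample = "# Summary\n\nFirst page text.\n<!-- PageBreak -->\n# Details\n\nSecond page."
    for number, page in enumerate(split_pages(sample), start=1):
        print(number, repr(page))
```

The number of segments need not equal `page_count`, since blank pages produce no text and are dropped here.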