gmf_forge_ai_data.retrieval
Retrieval strategies for GMF Forge AI Data Layer.
This module provides various retrieval strategies for finding relevant documents from vector stores. All retrievers implement the BaseRetriever interface.
Available Retrievers:
Basic Retrievers (wrappers around vector store search):
- VectorRetriever: Pure vector similarity search
- KeywordRetriever: Pure keyword/BM25 search
- HybridRetriever: Combined vector + keyword search
Advanced Retrievers (production patterns):
- MMRRetriever: Maximal Marginal Relevance for diverse results
- ParentDocumentRetriever: Search chunks, return parent documents
- EnsembleRetriever: Combine multiple retrievers with score fusion
- HierarchicalRetriever: Two-stage retrieval (coarse → fine)
- GraphRetriever: Graph-based retrieval with entity relationships
- SQLRetriever: Structured data retrieval via SQL
- MultiIndexRetriever: Retrieve from multiple indices/sources
Core Classes:
- BaseRetriever: Abstract interface for all retrievers
- RetrievalQuery: Container for query parameters
"""
Retrieval strategies for GMF Forge AI Data Layer.

This module provides various retrieval strategies for finding relevant documents
from vector stores. All retrievers implement the BaseRetriever interface.

Available Retrievers:

**Basic Retrievers** (wrappers around vector store search):
- VectorRetriever: Pure vector similarity search
- KeywordRetriever: Pure keyword/BM25 search
- HybridRetriever: Combined vector + keyword search

**Advanced Retrievers** (production patterns):
- MMRRetriever: Maximal Marginal Relevance for diverse results
- ParentDocumentRetriever: Search chunks, return parent documents
- EnsembleRetriever: Combine multiple retrievers with score fusion
- HierarchicalRetriever: Two-stage retrieval (coarse → fine)
- GraphRetriever: Graph-based retrieval with entity relationships
- SQLRetriever: Structured data retrieval via SQL
- MultiIndexRetriever: Retrieve from multiple indices/sources

Core Classes:
- BaseRetriever: Abstract interface for all retrievers
- RetrievalQuery: Container for query parameters
"""

from .base_retriever import BaseRetriever, RetrievalQuery
from .basic_retrievers import VectorRetriever, KeywordRetriever, HybridRetriever
from .mmr_retriever import MMRRetriever
from .parent_document_retriever import ParentDocumentRetriever
from .ensemble_retriever import EnsembleRetriever
from .hierarchical_retriever import HierarchicalRetriever
from .graph_retriever import GraphRetriever
from .sql_retriever import SQLRetriever, SQLSchema, SQLQuery
from .multi_index_retriever import MultiIndexRetriever, SourceConfig

__all__ = [
    # Base classes
    "BaseRetriever",
    "RetrievalQuery",

    # Basic retrievers
    "VectorRetriever",
    "KeywordRetriever",
    "HybridRetriever",

    # Advanced retrievers
    "MMRRetriever",
    "ParentDocumentRetriever",
    "EnsembleRetriever",
    "HierarchicalRetriever",
    "GraphRetriever",
    "SQLRetriever",
    "MultiIndexRetriever",

    # Helper classes
    "SQLSchema",
    "SQLQuery",
    "SourceConfig",
]

__version__ = "1.0.0"
class BaseRetriever(ABC):
    """
    Abstract base class for all retriever implementations.

    Retrievers provide various strategies for finding relevant documents
    from a vector store or other data source.

    All implementations must provide:
    - retrieve(): Main retrieval method
    """

    @abstractmethod
    def retrieve(
        self,
        query: RetrievalQuery
    ) -> List[SearchResult]:
        """
        Retrieve relevant documents based on the query.

        Args:
            query: RetrievalQuery containing query parameters

        Returns:
            List of SearchResult objects, ordered by relevance

        Raises:
            ValueError: If required query parameters are missing
        """
        pass

    def retrieve_text(
        self,
        text: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for text-based retrieval.

        Args:
            text: Query text
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            text=text,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)

    def retrieve_embedding(
        self,
        embedding: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for embedding-based retrieval.

        Args:
            embedding: Query embedding vector
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            embedding=embedding,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)
Abstract base class for all retriever implementations.
Retrievers provide various strategies for finding relevant documents from a vector store or other data source.
All implementations must provide:
- retrieve(): Main retrieval method
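To make the contract concrete, here is a minimal sketch of a custom retriever. It is self-contained, so `Document`, `SearchResult`, `RetrievalQuery`, and `BaseRetriever` below are simplified stand-ins written for this example rather than the library's own classes, and `SubstringRetriever` is a hypothetical retriever invented for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


# Stand-in types (assumptions): simplified mirrors of the library's classes.
@dataclass
class Document:
    id: str
    content: str


@dataclass
class SearchResult:
    document: Document
    score: float
    rank: int


@dataclass
class RetrievalQuery:
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None


class BaseRetriever(ABC):
    @abstractmethod
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        ...

    def retrieve_text(self, text: str, top_k: int = 5) -> List[SearchResult]:
        # Convenience wrapper: build the query object, then delegate.
        return self.retrieve(RetrievalQuery(text=text, top_k=top_k))


class SubstringRetriever(BaseRetriever):
    """Hypothetical retriever: returns documents whose content contains the query text."""

    def __init__(self, documents: List[Document]):
        self.documents = documents

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        if query.text is None:
            raise ValueError("SubstringRetriever requires query.text")
        needle = query.text.lower()
        hits = [d for d in self.documents if needle in d.content.lower()]
        return [
            SearchResult(document=d, score=1.0, rank=i)
            for i, d in enumerate(hits[: query.top_k])
        ]


docs = [
    Document("a", "Machine learning basics"),
    Document("b", "Cooking with cast iron"),
    Document("c", "Deep learning for vision"),
]
retriever = SubstringRetriever(docs)
results = retriever.retrieve_text("learning", top_k=5)
print([r.document.id for r in results])  # ['a', 'c']
```

Only `retrieve()` has to be implemented; the convenience wrapper comes from the base class and simply forwards to it.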
    @abstractmethod
    def retrieve(
        self,
        query: RetrievalQuery
    ) -> List[SearchResult]:
        """
        Retrieve relevant documents based on the query.

        Args:
            query: RetrievalQuery containing query parameters

        Returns:
            List of SearchResult objects, ordered by relevance

        Raises:
            ValueError: If required query parameters are missing
        """
        pass
Retrieve relevant documents based on the query.
Args: query: RetrievalQuery containing query parameters
Returns: List of SearchResult objects, ordered by relevance
Raises: ValueError: If required query parameters are missing
    def retrieve_text(
        self,
        text: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for text-based retrieval.

        Args:
            text: Query text
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            text=text,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)
Convenience method for text-based retrieval.
Args:
- text: Query text
- top_k: Number of results to retrieve
- filters: Optional metadata filters
Returns: List of SearchResult objects
    def retrieve_embedding(
        self,
        embedding: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for embedding-based retrieval.

        Args:
            embedding: Query embedding vector
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            embedding=embedding,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)
Convenience method for embedding-based retrieval.
Args:
- embedding: Query embedding vector
- top_k: Number of results to retrieve
- filters: Optional metadata filters
Returns: List of SearchResult objects
@dataclass
class RetrievalQuery:
    """
    Represents a retrieval query with optional query text and embedding.

    Attributes:
        text: Query text (required for keyword/hybrid search)
        embedding: Query embedding vector (required for vector search)
        top_k: Number of results to retrieve
        filters: Metadata filters to apply
        metadata: Additional query metadata
    """
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None
Represents a retrieval query with optional query text and embedding.
Attributes:
- text: Query text (required for keyword/hybrid search)
- embedding: Query embedding vector (required for vector search)
- top_k: Number of results to retrieve
- filters: Metadata filters to apply
- metadata: Additional query metadata
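Because both `text` and `embedding` are optional, which one a query must carry depends on the search type. The sketch below illustrates that rule under stated assumptions: `RetrievalQuery` is re-declared here as a stand-in copy so the snippet runs on its own, `missing_fields` is a hypothetical helper (not part of the library), and the `"source"` filter key is a made-up example.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class RetrievalQuery:
    """Stand-in copy of the documented dataclass, so the sketch runs alone."""
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None


def missing_fields(query: RetrievalQuery, search_type: str) -> List[str]:
    """Hypothetical helper: fields the search type needs but the query lacks."""
    missing = []
    if search_type in ("vector", "hybrid") and query.embedding is None:
        missing.append("embedding")
    if search_type in ("keyword", "hybrid") and query.text is None:
        missing.append("text")
    return missing


keyword_query = RetrievalQuery(text="machine learning", top_k=3)
vector_query = RetrievalQuery(embedding=[0.1, 0.2, 0.3])
# The filter key "source" is an invented example of a metadata filter.
hybrid_query = RetrievalQuery(
    text="machine learning",
    embedding=[0.1, 0.2, 0.3],
    filters={"source": "wiki"},
)

print(missing_fields(keyword_query, "keyword"))  # []
print(missing_fields(keyword_query, "hybrid"))   # ['embedding']
print(missing_fields(vector_query, "hybrid"))    # ['text']
print(missing_fields(hybrid_query, "hybrid"))    # []
```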
class VectorRetriever(BaseRetriever):
    """
    Vector similarity retriever using cosine similarity.

    Performs pure vector search using embeddings. Requires query embeddings to be
    provided externally (e.g., using an embeddings provider).

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        # Setup
        embedder = AzureOpenAIEmbeddings(...)
        retriever = VectorRetriever(vector_store)

        # Retrieve
        query_text = "What is machine learning?"
        query_embedding = embedder.embed_text(query_text)

        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize vector retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using vector similarity search.

        Args:
            query: RetrievalQuery with embedding and parameters

        Returns:
            List of SearchResult objects ordered by similarity

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("VectorRetriever requires query.embedding")

        return self.vector_store.search(
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="vector"
        )
Vector similarity retriever using cosine similarity.
Performs pure vector search using embeddings. Requires query embeddings to be provided externally (e.g., using an embeddings provider).
Example:
from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
# Setup
embedder = AzureOpenAIEmbeddings(...)
retriever = VectorRetriever(vector_store)
# Retrieve
query_text = "What is machine learning?"
query_embedding = embedder.embed_text(query_text)
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
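For readers unfamiliar with the ranking step, the following is a dependency-free sketch of cosine-similarity ranking. It illustrates the general technique only; how a given vector store computes or indexes similarity is not shown in this documentation, and `cosine_similarity` and `rank_by_cosine` are names invented for this example.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def rank_by_cosine(
    query_embedding: Sequence[float],
    doc_embeddings: List[Sequence[float]],
    top_k: int = 5,
) -> List[Tuple[int, float]]:
    """Return (document index, similarity) pairs, most similar first."""
    scored = [
        (i, cosine_similarity(query_embedding, emb))
        for i, emb in enumerate(doc_embeddings)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


ranked = rank_by_cosine(
    [1.0, 0.0],
    [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]],
    top_k=2,
)
print([index for index, _ in ranked])  # [1, 2]
```

Note that document 1 ranks first even though its vector is twice as long as the query: cosine similarity compares direction, not magnitude.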
    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize vector retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store
Initialize vector retriever.
Args: vector_store: Vector store to search
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using vector similarity search.

        Args:
            query: RetrievalQuery with embedding and parameters

        Returns:
            List of SearchResult objects ordered by similarity

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("VectorRetriever requires query.embedding")

        return self.vector_store.search(
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="vector"
        )
Retrieve documents using vector similarity search.
Args: query: RetrievalQuery with embedding and parameters
Returns: List of SearchResult objects ordered by similarity
Raises: ValueError: If query.embedding is None
class KeywordRetriever(BaseRetriever):
    """
    Keyword/BM25 retriever using text matching.

    Performs traditional keyword-based search without using embeddings.
    Uses BM25 for Azure AI Search, Jaccard similarity for in-memory.

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery

        retriever = KeywordRetriever(vector_store)

        results = retriever.retrieve_text(
            text="machine learning algorithms",
            top_k=5
        )
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize keyword retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using keyword search.

        Args:
            query: RetrievalQuery with text and parameters

        Returns:
            List of SearchResult objects ordered by keyword relevance

        Raises:
            ValueError: If query.text is None
        """
        if query.text is None:
            raise ValueError("KeywordRetriever requires query.text")

        return self.vector_store.search(
            query=query.text,
            top_k=query.top_k,
            filters=query.filters,
            search_type="keyword"
        )
Keyword/BM25 retriever using text matching.
Performs traditional keyword-based search without using embeddings. Uses BM25 for Azure AI Search, Jaccard similarity for in-memory.
Example:
from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery
retriever = KeywordRetriever(vector_store)
results = retriever.retrieve_text(
    text="machine learning algorithms",
    top_k=5
)
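As a rough illustration of the Jaccard scoring mentioned for the in-memory case, the sketch below scores a document by the overlap between its word set and the query's word set. The lowercase, whitespace-based tokenization is an assumption made for this example; the store's actual tokenization is not documented here, and `jaccard_similarity` is an invented name.

```python
def jaccard_similarity(query: str, document: str) -> float:
    """Size of the shared word set divided by the size of the combined word set."""
    query_terms = set(query.lower().split())
    doc_terms = set(document.lower().split())
    if not query_terms or not doc_terms:
        return 0.0
    return len(query_terms & doc_terms) / len(query_terms | doc_terms)


score = jaccard_similarity(
    "machine learning algorithms",
    "Machine learning is fun",
)
print(score)  # 0.4 (2 shared terms out of 5 distinct terms)
```

Unlike BM25, this measure ignores term frequency and document length weighting, so scores from the two backends are not directly comparable.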
    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize keyword retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store
Initialize keyword retriever.
Args: vector_store: Vector store to search
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using keyword search.

        Args:
            query: RetrievalQuery with text and parameters

        Returns:
            List of SearchResult objects ordered by keyword relevance

        Raises:
            ValueError: If query.text is None
        """
        if query.text is None:
            raise ValueError("KeywordRetriever requires query.text")

        return self.vector_store.search(
            query=query.text,
            top_k=query.top_k,
            filters=query.filters,
            search_type="keyword"
        )
Retrieve documents using keyword search.
Args: query: RetrievalQuery with text and parameters
Returns: List of SearchResult objects ordered by keyword relevance
Raises: ValueError: If query.text is None
class HybridRetriever(BaseRetriever):
    """
    Hybrid retriever combining vector and keyword search.

    Combines vector similarity with keyword matching for comprehensive retrieval.
    Scores are normalized and combined (typically 50/50 weighting).

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)
        retriever = HybridRetriever(vector_store)

        query_text = "machine learning"
        query_embedding = embedder.embed_text(query_text)

        # Hybrid search using both text and embedding
        query = RetrievalQuery(
            text=query_text,
            embedding=query_embedding,
            top_k=5
        )
        results = retriever.retrieve(query)
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize hybrid retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using hybrid search (vector + keyword).

        Args:
            query: RetrievalQuery with both text and embedding

        Returns:
            List of SearchResult objects ordered by combined relevance

        Raises:
            ValueError: If query.text or query.embedding is None
        """
        if query.text is None:
            raise ValueError("HybridRetriever requires query.text")
        if query.embedding is None:
            raise ValueError("HybridRetriever requires query.embedding")

        return self.vector_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="hybrid"
        )
Hybrid retriever combining vector and keyword search.
Combines vector similarity with keyword matching for comprehensive retrieval. Scores are normalized and combined (typically 50/50 weighting).
Example:
from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
embedder = AzureOpenAIEmbeddings(...)
retriever = HybridRetriever(vector_store)
query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)
# Hybrid search using both text and embedding
query = RetrievalQuery(
    text=query_text,
    embedding=query_embedding,
    top_k=5
)
results = retriever.retrieve(query)
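The actual score fusion happens inside the vector store and is not shown in this documentation. The sketch below is one plausible reading of "normalized and combined (typically 50/50 weighting)": min-max normalization of each score set followed by a weighted sum. Both the normalization scheme and the names `min_max_normalize` and `combine_scores` are assumptions made for illustration.

```python
from typing import Dict, List, Tuple


def min_max_normalize(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to [0, 1]; if all scores are equal, map each to 1.0."""
    if not scores:
        return {}
    low, high = min(scores.values()), max(scores.values())
    if high == low:
        return {doc_id: 1.0 for doc_id in scores}
    return {doc_id: (s - low) / (high - low) for doc_id, s in scores.items()}


def combine_scores(
    vector_scores: Dict[str, float],
    keyword_scores: Dict[str, float],
    vector_weight: float = 0.5,
) -> List[Tuple[str, float]]:
    """Weighted sum of normalized scores; a document missing from one side scores 0 there."""
    v = min_max_normalize(vector_scores)
    k = min_max_normalize(keyword_scores)
    combined = {
        doc_id: vector_weight * v.get(doc_id, 0.0)
        + (1 - vector_weight) * k.get(doc_id, 0.0)
        for doc_id in set(v) | set(k)
    }
    return sorted(combined.items(), key=lambda item: item[1], reverse=True)


fused = combine_scores(
    vector_scores={"a": 1.0, "b": 0.5, "c": 0.0},   # e.g. cosine similarities
    keyword_scores={"b": 9.0, "c": 0.0, "d": 3.0},  # e.g. unbounded keyword scores
)
print([doc_id for doc_id, _ in fused])  # ['b', 'a', 'd', 'c']
```

Document "b" wins because it scores reasonably on both signals, which is the behavior hybrid search is meant to reward.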
    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize hybrid retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store
Initialize hybrid retriever.
Args: vector_store: Vector store to search
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using hybrid search (vector + keyword).

        Args:
            query: RetrievalQuery with both text and embedding

        Returns:
            List of SearchResult objects ordered by combined relevance

        Raises:
            ValueError: If query.text or query.embedding is None
        """
        if query.text is None:
            raise ValueError("HybridRetriever requires query.text")
        if query.embedding is None:
            raise ValueError("HybridRetriever requires query.embedding")

        return self.vector_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="hybrid"
        )
Retrieve documents using hybrid search (vector + keyword).
Args: query: RetrievalQuery with both text and embedding
Returns: List of SearchResult objects ordered by combined relevance
Raises: ValueError: If query.text or query.embedding is None
class MMRRetriever(BaseRetriever):
    """
    Maximal Marginal Relevance (MMR) retriever for diverse results.

    MMR reranks initial retrieval results to balance relevance with diversity:
    - High lambda (λ → 1.0): Prioritize relevance
    - Low lambda (λ → 0.0): Prioritize diversity
    - λ = 0.5: Balanced (default)

    Algorithm:
        1. Retrieve initial candidate documents (fetch_k results)
        2. Select most relevant document first
        3. For each remaining selection:
           - Score = λ * relevance - (1-λ) * max_similarity_to_selected
           - Choose document with highest score
        4. Return top_k diverse results

    Benefits:
        - Reduces redundancy in results
        - Improves coverage of different aspects
        - Useful for exploration and summarization

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)
        retriever = MMRRetriever(
            vector_store=vector_store,
            lambda_param=0.5,  # Balanced relevance/diversity
            fetch_k=20  # Retrieve 20 candidates, return top_k diverse
        )

        query_text = "machine learning algorithms"
        query_embedding = embedder.embed_text(query_text)

        # Returns 5 diverse results from 20 candidates
        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        ```

    References:
        Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based
        reranking for reordering documents and producing summaries.
    """

    def __init__(
        self,
        vector_store: BaseVectorStore,
        lambda_param: float = 0.5,
        fetch_k: int = 20
    ):
        """
        Initialize MMR retriever.

        Args:
            vector_store: Vector store to search
            lambda_param: Balance between relevance (1.0) and diversity (0.0)
            fetch_k: Number of initial candidates to fetch (should be >= top_k)

        Raises:
            ValueError: If lambda_param not in [0, 1] or fetch_k < 1
        """
        if not 0 <= lambda_param <= 1:
            raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
        if fetch_k < 1:
            raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")

        self.vector_store = vector_store
        self.lambda_param = lambda_param
        self.fetch_k = fetch_k

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve diverse documents using MMR reranking.

        Args:
            query: RetrievalQuery with embedding (required)

        Returns:
            List of SearchResult objects with diverse, relevant results

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")

        # Fetch initial candidates (more than top_k)
        fetch_k = max(self.fetch_k, query.top_k)
        candidates = self.vector_store.search(
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type="vector"
        )

        if len(candidates) == 0:
            return []

        if len(candidates) <= query.top_k:
            # Not enough candidates for MMR, return as-is
            return candidates[:query.top_k]

        # Extract embeddings and relevance scores
        query_embedding = np.array(query.embedding, dtype=np.float32)
        candidate_embeddings = []
        candidate_scores = []

        for result in candidates:
            if result.document.embedding is None:
                raise ValueError(
                    f"Document {result.document.id} has no embedding. "
                    "MMR requires all documents to have embeddings."
                )
            candidate_embeddings.append(result.document.embedding)
            candidate_scores.append(result.score)

        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
        candidate_scores = np.array(candidate_scores, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
        candidate_norms = candidate_embeddings / (
            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
        )

        # MMR selection
        selected_indices = []
        selected_embeddings = []

        for _ in range(min(query.top_k, len(candidates))):
            if len(selected_indices) == 0:
                # First selection: most relevant
                best_idx = int(np.argmax(candidate_scores))
            else:
                # Subsequent selections: balance relevance and diversity
                mmr_scores = []

                for idx in range(len(candidates)):
                    if idx in selected_indices:
                        mmr_scores.append(-np.inf)
                        continue

                    # Relevance score (already normalized 0-1 from vector store)
                    relevance = candidate_scores[idx]

                    # Diversity score: max similarity to selected documents
                    doc_embedding = candidate_norms[idx]
                    similarities = [
                        np.dot(doc_embedding, selected_embeddings[i])
                        for i in range(len(selected_embeddings))
                    ]
                    max_similarity = max(similarities)

                    # MMR score: λ * relevance - (1-λ) * max_similarity
                    mmr_score = (
                        self.lambda_param * relevance -
                        (1 - self.lambda_param) * max_similarity
                    )
                    mmr_scores.append(mmr_score)

                best_idx = int(np.argmax(mmr_scores))

            selected_indices.append(best_idx)
            selected_embeddings.append(candidate_norms[best_idx])

        # Build results with updated ranks
        mmr_results = []
        for rank, idx in enumerate(selected_indices):
            result = candidates[idx]
            mmr_results.append(SearchResult(
                document=result.document,
                score=result.score,  # Keep original relevance score
                rank=rank
            ))

        return mmr_results
Maximal Marginal Relevance (MMR) retriever for diverse results.
MMR reranks initial retrieval results to balance relevance with diversity:
- High lambda (λ → 1.0): Prioritize relevance
- Low lambda (λ → 0.0): Prioritize diversity
- λ = 0.5: Balanced (default)
Algorithm:
1. Retrieve initial candidate documents (fetch_k results)
2. Select most relevant document first
3. For each remaining selection:
   - Score = λ * relevance - (1-λ) * max_similarity_to_selected
   - Choose document with highest score
4. Return top_k diverse results
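The effect of λ is easiest to see on a toy input. The following is a compact, dependency-free sketch of the selection loop described in the algorithm above, not the library's NumPy implementation; `cosine` and `mmr_select` are names invented for this example, and the relevance scores and embeddings are made-up values.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def mmr_select(
    relevance: List[float],
    embeddings: List[Sequence[float]],
    top_k: int,
    lambda_param: float = 0.5,
) -> List[int]:
    """Return candidate indices in MMR selection order."""
    selected: List[int] = []
    remaining = list(range(len(relevance)))

    while remaining and len(selected) < top_k:
        if not selected:
            # First pick: the most relevant candidate.
            best = max(remaining, key=lambda i: relevance[i])
        else:
            def mmr_score(i: int) -> float:
                # Penalize similarity to the closest already-selected document.
                max_sim = max(cosine(embeddings[i], embeddings[j]) for j in selected)
                return lambda_param * relevance[i] - (1 - lambda_param) * max_sim

            best = max(remaining, key=mmr_score)
        selected.append(best)
        remaining.remove(best)

    return selected


# Candidates 0 and 1 are near-duplicates; candidate 2 covers a different direction.
relevance = [0.90, 0.85, 0.60]
embeddings = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]

print(mmr_select(relevance, embeddings, top_k=2, lambda_param=0.5))  # [0, 2]
print(mmr_select(relevance, embeddings, top_k=2, lambda_param=1.0))  # [0, 1]
```

With λ = 0.5 the near-duplicate is skipped in favor of the less relevant but dissimilar candidate; with λ = 1.0 the ranking reduces to plain relevance order.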
Benefits:
- Reduces redundancy in results
- Improves coverage of different aspects
- Useful for exploration and summarization
Example:
from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
embedder = AzureOpenAIEmbeddings(...)
retriever = MMRRetriever(
    vector_store=vector_store,
    lambda_param=0.5,  # Balanced relevance/diversity
    fetch_k=20  # Retrieve 20 candidates, return top_k diverse
)
query_text = "machine learning algorithms"
query_embedding = embedder.embed_text(query_text)
# Returns 5 diverse results from 20 candidates
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
References: Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based reranking for reordering documents and producing summaries.
    def __init__(
        self,
        vector_store: BaseVectorStore,
        lambda_param: float = 0.5,
        fetch_k: int = 20
    ):
        """
        Initialize MMR retriever.

        Args:
            vector_store: Vector store to search
            lambda_param: Balance between relevance (1.0) and diversity (0.0)
            fetch_k: Number of initial candidates to fetch (should be >= top_k)

        Raises:
            ValueError: If lambda_param not in [0, 1] or fetch_k < 1
        """
        if not 0 <= lambda_param <= 1:
            raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
        if fetch_k < 1:
            raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")

        self.vector_store = vector_store
        self.lambda_param = lambda_param
        self.fetch_k = fetch_k
Initialize MMR retriever.
Args:
- vector_store: Vector store to search
- lambda_param: Balance between relevance (1.0) and diversity (0.0)
- fetch_k: Number of initial candidates to fetch (should be >= top_k)
Raises: ValueError: If lambda_param not in [0, 1] or fetch_k < 1
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve diverse documents using MMR reranking.

        Args:
            query: RetrievalQuery with embedding (required)

        Returns:
            List of SearchResult objects with diverse, relevant results

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")

        # Fetch initial candidates (more than top_k)
        fetch_k = max(self.fetch_k, query.top_k)
        candidates = self.vector_store.search(
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type="vector"
        )

        if len(candidates) == 0:
            return []

        if len(candidates) <= query.top_k:
            # Not enough candidates for MMR, return as-is
            return candidates[:query.top_k]

        # Extract embeddings and relevance scores
        query_embedding = np.array(query.embedding, dtype=np.float32)
        candidate_embeddings = []
        candidate_scores = []

        for result in candidates:
            if result.document.embedding is None:
                raise ValueError(
                    f"Document {result.document.id} has no embedding. "
                    "MMR requires all documents to have embeddings."
                )
            candidate_embeddings.append(result.document.embedding)
            candidate_scores.append(result.score)

        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
        candidate_scores = np.array(candidate_scores, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
        candidate_norms = candidate_embeddings / (
            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
        )

        # MMR selection
        selected_indices = []
        selected_embeddings = []

        for _ in range(min(query.top_k, len(candidates))):
            if len(selected_indices) == 0:
                # First selection: most relevant
                best_idx = int(np.argmax(candidate_scores))
            else:
                # Subsequent selections: balance relevance and diversity
                mmr_scores = []

                for idx in range(len(candidates)):
                    if idx in selected_indices:
                        mmr_scores.append(-np.inf)
                        continue

                    # Relevance score (already normalized 0-1 from vector store)
                    relevance = candidate_scores[idx]

                    # Diversity score: max similarity to selected documents
                    doc_embedding = candidate_norms[idx]
                    similarities = [
                        np.dot(doc_embedding, selected_embeddings[i])
                        for i in range(len(selected_embeddings))
                    ]
                    max_similarity = max(similarities)

                    # MMR score: λ * relevance - (1-λ) * max_similarity
                    mmr_score = (
                        self.lambda_param * relevance -
                        (1 - self.lambda_param) * max_similarity
                    )
                    mmr_scores.append(mmr_score)

                best_idx = int(np.argmax(mmr_scores))

            selected_indices.append(best_idx)
            selected_embeddings.append(candidate_norms[best_idx])

        # Build results with updated ranks
        mmr_results = []
        for rank, idx in enumerate(selected_indices):
            result = candidates[idx]
            mmr_results.append(SearchResult(
                document=result.document,
                score=result.score,  # Keep original relevance score
                rank=rank
            ))

        return mmr_results
Retrieve diverse documents using MMR reranking.
Args: query: RetrievalQuery with embedding (required)
Returns: List of SearchResult objects with diverse, relevant results
Raises: ValueError: If query.embedding is None
class ParentDocumentRetriever(BaseRetriever):
    """
    Retriever that searches child chunks but returns parent documents.

    This pattern is useful when:
    - Documents are chunked into small pieces for precise embedding
    - But you want to return full parent documents for better context
    - Multiple chunks from the same parent might match the query

    Architecture:
        - Child store: Contains small chunks with embeddings (searchable)
        - Parent store: Contains full parent documents (for retrieval)
        - Mapping: Children have metadata["parent_id"] pointing to parent

    Workflow:
        1. Search child store for relevant chunks
        2. Extract parent_ids from child metadata
        3. Retrieve parent documents from parent store
        4. Deduplicate and rank by best child score

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
        from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)

        # Setup stores
        child_store = InMemoryVectorStore()
        parent_store = InMemoryVectorStore()

        # Create parent document
        parent = Document(
            id="doc_1",
            content="Full document with multiple sections...",
            embedding=embedder.embed_text("Full document...")
        )
        parent_store.add_documents([parent])

        # Create child chunks
        child1 = Document(
            id="doc_1_chunk_0",
            content="Introduction section...",
            embedding=embedder.embed_text("Introduction..."),
            metadata={"parent_id": "doc_1"}  # Link to parent
        )
        child2 = Document(
            id="doc_1_chunk_1",
            content="Methods section...",
            embedding=embedder.embed_text("Methods..."),
            metadata={"parent_id": "doc_1"}
        )
        child_store.add_documents([child1, child2])

        # Setup retriever
        retriever = ParentDocumentRetriever(
            child_store=child_store,
            parent_store=parent_store,
            parent_id_key="parent_id"
        )

        # Search chunks, return parent
        query_embedding = embedder.embed_text("introduction")
        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        # Returns full parent documents, not chunks
        ```

    Benefits:
        - Precise search (small chunks)
        - Rich context (full parents)
        - Automatic deduplication
    """

    def __init__(
        self,
        child_store: BaseVectorStore,
        parent_store: BaseVectorStore,
        parent_id_key: str = "parent_id",
        search_type: str = "vector"
    ):
        """
        Initialize parent document retriever.

        Args:
            child_store: Vector store containing child chunks with embeddings
            parent_store: Vector store containing parent documents
            parent_id_key: Metadata key in children pointing to parent ID
            search_type: Search type for child store ("vector", "keyword", "hybrid")

        Raises:
            ValueError: If search_type is invalid
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        self.child_store = child_store
        self.parent_store = parent_store
        self.parent_id_key = parent_id_key
        self.search_type = search_type

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve parent documents by searching child chunks.

        Args:
            query: RetrievalQuery with appropriate parameters for search_type

        Returns:
            List of SearchResult with parent documents, deduplicated and ranked

        Raises:
            ValueError: If required query parameters are missing
        """
        # Validate query based on search type
        if self.search_type in ["vector", "hybrid"] and query.embedding is None:
            raise ValueError(f"{self.search_type} search requires query.embedding")
        if self.search_type in ["keyword", "hybrid"] and query.text is None:
            raise ValueError(f"{self.search_type} search requires query.text")

        # Search child store (fetch more candidates to account for deduplication)
        fetch_k = query.top_k * 3  # Heuristic: fetch 3x to handle duplicates
        child_results = self.child_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type=self.search_type
        )

        if len(child_results) == 0:
            return []

        # Extract parent IDs and track best scores
        parent_scores: Dict[str, float] = {}  # parent_id -> best_score
        parent_order: OrderedDict[str, None] = OrderedDict()  # Track first occurrence

        for child_result in child_results:
            parent_id = child_result.document.metadata.get(self.parent_id_key)

            if parent_id is None:
                # Skip children without parent link
                continue

            # Track best score for each parent (in case multiple children match)
            if parent_id not in parent_scores:
                parent_scores[parent_id] = child_result.score
                parent_order[parent_id] = None  # Track insertion order
            else:
                # Update if this child has better score
                parent_scores[parent_id] = max(
                    parent_scores[parent_id],
                    child_result.score
                )

        if len(parent_scores) == 0:
            return []

        # Retrieve parent documents
        parent_results = []
        for rank, parent_id in enumerate(parent_order.keys()):
            if rank >= query.top_k:
                break

            parent_doc = self.parent_store.get_document(parent_id)
            if parent_doc is None:
                # Parent not found in store, skip
                continue

            parent_results.append(SearchResult(
                document=parent_doc,
                score=parent_scores[parent_id],  # Best child score
                rank=rank
            ))

        return parent_results
Retriever that searches child chunks but returns parent documents.
This pattern is useful when:
- Documents are chunked into small pieces for precise embedding
- But you want to return full parent documents for better context
- Multiple chunks from the same parent might match the query
Architecture:
- Child store: Contains small chunks with embeddings (searchable)
- Parent store: Contains full parent documents (for retrieval)
- Mapping: Children have metadata["parent_id"] pointing to parent
Workflow:
- Search child store for relevant chunks
- Extract parent_ids from child metadata
- Retrieve parent documents from parent store
- Deduplicate and rank by best child score
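The deduplication step in the workflow above can be sketched in isolation. This is a minimal illustration, assuming child hits arrive as plain `(parent_id, score)` pairs rather than `SearchResult` objects. The helper name `best_parent_scores` and the sample hits are hypothetical, and the sketch sorts parents explicitly by their best child score instead of relying on the order in which the store returns hits.

```python
from typing import Dict, List, Optional, Tuple


def best_parent_scores(
    child_hits: List[Tuple[Optional[str], float]],
    top_k: int,
) -> List[Tuple[str, float]]:
    """Collapse (parent_id, score) child hits into ranked, unique parents."""
    best: Dict[str, float] = {}
    for parent_id, score in child_hits:
        if parent_id is None:
            continue  # chunk carries no parent link, so it cannot be mapped
        # Keep the highest score seen among all children of this parent
        best[parent_id] = max(best.get(parent_id, score), score)
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


# Hypothetical child hits: two chunks of doc_1, one orphan chunk
hits = [
    ("doc_1", 0.92),
    ("doc_2", 0.88),
    ("doc_1", 0.75),
    (None, 0.70),
    ("doc_3", 0.41),
]
top_parents = best_parent_scores(hits, top_k=2)
```

Because several chunks can collapse into one parent, fetching more child hits than `top_k` (the class uses a 3x heuristic) makes it more likely that `top_k` distinct parents survive deduplication.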
Example:
from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Setup stores
child_store = InMemoryVectorStore()
parent_store = InMemoryVectorStore()

# Create parent document
parent = Document(
    id="doc_1",
    content="Full document with multiple sections...",
    embedding=embedder.embed_text("Full document...")
)
parent_store.add_documents([parent])

# Create child chunks
child1 = Document(
    id="doc_1_chunk_0",
    content="Introduction section...",
    embedding=embedder.embed_text("Introduction..."),
    metadata={"parent_id": "doc_1"}  # Link to parent
)
child2 = Document(
    id="doc_1_chunk_1",
    content="Methods section...",
    embedding=embedder.embed_text("Methods..."),
    metadata={"parent_id": "doc_1"}
)
child_store.add_documents([child1, child2])

# Setup retriever
retriever = ParentDocumentRetriever(
    child_store=child_store,
    parent_store=parent_store,
    parent_id_key="parent_id"
)

# Search chunks, return parent
query_embedding = embedder.embed_text("introduction")
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
# Returns full parent documents, not chunks
Benefits:
- Precise search (small chunks)
- Rich context (full parents)
- Automatic deduplication
Initialize parent document retriever.
Args:
- child_store: Vector store containing child chunks with embeddings
- parent_store: Vector store containing parent documents
- parent_id_key: Metadata key in children pointing to parent ID
- search_type: Search type for child store ("vector", "keyword", "hybrid")
Raises:
- ValueError: If search_type is invalid
Retrieve parent documents by searching child chunks.
Args: query: RetrievalQuery with appropriate parameters for search_type
Returns: List of SearchResult with parent documents, deduplicated and ranked
Raises: ValueError: If required query parameters are missing
class EnsembleRetriever(BaseRetriever):
    """
    Ensemble retriever combining multiple retrieval strategies.

    Combines results from multiple retrievers using score fusion techniques:
    - Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
    - Weighted Average: Score-based fusion with configurable weights
    - Max Score: Keeps each document's best normalized score across retrievers

    Benefits:
    - Improved recall (combine different strategies)
    - Robustness (less sensitive to individual retriever failures)
    - Flexibility (combine vector, keyword, hybrid, MMR, etc.)

    Example:
    ```python
    from gmf_forge_ai_data.retrieval import (
        EnsembleRetriever,
        VectorRetriever,
        KeywordRetriever,
        MMRRetriever,
        RetrievalQuery
    )
    from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

    embedder = AzureOpenAIEmbeddings(...)

    # Create multiple retrievers
    vector_retriever = VectorRetriever(vector_store)
    keyword_retriever = KeywordRetriever(vector_store)
    mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

    # Combine with RRF fusion
    ensemble = EnsembleRetriever(
        retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
        weights=[0.5, 0.3, 0.2],  # Optional weights
        fusion_strategy="rrf"  # Reciprocal Rank Fusion
    )

    # Query requires parameters for all retrievers
    query_text = "machine learning"
    query_embedding = embedder.embed_text(query_text)

    query = RetrievalQuery(
        text=query_text,  # For keyword retriever
        embedding=query_embedding,  # For vector/MMR retrievers
        top_k=5
    )

    results = ensemble.retrieve(query)
    # Returns fused results from all retrievers
    ```

    Fusion Strategies:

    1. **rrf** (Reciprocal Rank Fusion):
       - score(doc) = Σ weights[i] / (k + rank_i)
       - k = 60 (default)
       - Robust to different score scales
       - Rank-based, not score-based

    2. **weighted_avg** (Weighted Average):
       - Normalize scores to [0, 1]
       - score(doc) = Σ weights[i] * normalized_score_i
       - Score-based fusion

    3. **max_score** (Maximum Score):
       - score(doc) = max(normalized_scores)
       - A high score from any single retriever is enough; weights are not used

    References:
        Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009).
        Reciprocal rank fusion outperforms Condorcet and individual
        rank learning methods. ACM SIGIR.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
        weights: List[float] = None,
        fusion_strategy: str = "rrf",
        rrf_k: int = 60
    ):
        """
        Initialize ensemble retriever.

        Args:
            retrievers: List of retrievers to combine
            weights: Optional weights for each retriever (default: equal weights)
            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
            rrf_k: k parameter for RRF (default: 60)

        Raises:
            ValueError: If retrievers list is empty or params are invalid
        """
        if not retrievers:
            raise ValueError("Must provide at least one retriever")

        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
            raise ValueError(
                f"Invalid fusion_strategy: {fusion_strategy}. "
                "Must be 'rrf', 'weighted_avg', or 'max_score'"
            )

        if weights is None:
            # Equal weights
            weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            if len(weights) != len(retrievers):
                raise ValueError(
                    f"Number of weights ({len(weights)}) must match "
                    f"number of retrievers ({len(retrievers)})"
                )
            # Normalize weights to sum to 1.0
            total = sum(weights)
            weights = [w / total for w in weights]

        self.retrievers = retrievers
        self.weights = weights
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using ensemble fusion.

        Args:
            query: RetrievalQuery with parameters for all retrievers

        Returns:
            List of SearchResult with fused scores and ranks

        Note:
            All retrievers are called with the same query. Ensure the query
            contains all required parameters (text, embedding) for your retrievers.
        """
        # Retrieve from all retrievers
        all_results: List[List[SearchResult]] = []

        for retriever in self.retrievers:
            try:
                results = retriever.retrieve(query)
                all_results.append(results)
            except ValueError as e:
                # Retriever might not support this query type, skip it
                all_results.append([])

        if not any(all_results):
            return []

        # Apply fusion strategy
        if self.fusion_strategy == "rrf":
            fused_results = self._reciprocal_rank_fusion(all_results)
        elif self.fusion_strategy == "weighted_avg":
            fused_results = self._weighted_average_fusion(all_results)
        else:  # max_score
            fused_results = self._max_score_fusion(all_results)

        # Sort by fused score and take top_k
        fused_results.sort(key=lambda x: x.score, reverse=True)
        top_results = fused_results[:query.top_k]

        # Update ranks
        for rank, result in enumerate(top_results):
            result.rank = rank

        return top_results

    def _reciprocal_rank_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using Reciprocal Rank Fusion (RRF).

        RRF score for document d:
            score(d) = Σ weight_i / (k + rank_i(d))

        Where rank_i(d) is the rank of d in retriever i's results.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            weight = self.weights[retriever_idx]

            for result in results:
                doc_id = result.document.id

                # RRF score contribution
                rrf_score = weight / (self.rrf_k + result.rank)
                doc_scores[doc_id] += rrf_score

                # Keep document object (use first occurrence)
                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0  # Will be set later
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _weighted_average_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using weighted average of normalized scores.

        Scores are normalized to [0, 1] within each retriever, then averaged.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            weight = self.weights[retriever_idx]

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id
                doc_scores[doc_id] += weight * norm_score

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _max_score_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using maximum normalized score across retrievers.

        A document needs a high score from at least one retriever.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id

                # Take maximum normalized score
                doc_scores[doc_id] = max(doc_scores[doc_id], norm_score)

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results
Ensemble retriever combining multiple retrieval strategies.
Combines results from multiple retrievers using score fusion techniques:
- Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
- Weighted Average: Score-based fusion with configurable weights
- Max Score: Keeps each document's best normalized score across retrievers
Benefits:
- Improved recall (combine different strategies)
- Robustness (less sensitive to individual retriever failures)
- Flexibility (combine vector, keyword, hybrid, MMR, etc.)
Example:
from gmf_forge_ai_data.retrieval import (
    EnsembleRetriever,
    VectorRetriever,
    KeywordRetriever,
    MMRRetriever,
    RetrievalQuery
)
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Create multiple retrievers
vector_retriever = VectorRetriever(vector_store)
keyword_retriever = KeywordRetriever(vector_store)
mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

# Combine with RRF fusion
ensemble = EnsembleRetriever(
    retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
    weights=[0.5, 0.3, 0.2],  # Optional weights
    fusion_strategy="rrf"  # Reciprocal Rank Fusion
)

# Query requires parameters for all retrievers
query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)

query = RetrievalQuery(
    text=query_text,  # For keyword retriever
    embedding=query_embedding,  # For vector/MMR retrievers
    top_k=5
)

results = ensemble.retrieve(query)
# Returns fused results from all retrievers
Fusion Strategies:
rrf (Reciprocal Rank Fusion):
- score(doc) = Σ weights[i] / (k + rank_i)
- k = 60 (default)
- Robust to different score scales
- Rank-based, not score-based
weighted_avg (Weighted Average):
- Normalize scores to [0, 1]
- score(doc) = Σ weights[i] * normalized_score_i
- Score-based fusion
max_score (Maximum Score):
- score(doc) = max(normalized_scores)
- A high score from any single retriever is enough; weights are not used
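The three fusion formulas can be compared on a toy input. The following is a sketch under stated assumptions, not the class's implementation: results are plain `(doc_id, score)` lists sorted best-first, ranks are 0-based, and the helper names (`rrf`, `min_max`, `weighted_avg`, `max_score`) and sample hits are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Ranked = List[Tuple[str, float]]  # (doc_id, raw_score), best first


def rrf(result_lists: List[Ranked], weights: List[float], k: int = 60) -> Dict[str, float]:
    """Reciprocal Rank Fusion: sum of weight / (k + rank), using 0-based ranks."""
    fused: Dict[str, float] = defaultdict(float)
    for weight, results in zip(weights, result_lists):
        for rank, (doc_id, _score) in enumerate(results):
            fused[doc_id] += weight / (k + rank)
    return dict(fused)


def min_max(results: Ranked) -> Dict[str, float]:
    """Scale one retriever's scores to [0, 1]; identical scores all map to 1.0."""
    scores = [score for _, score in results]
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return {doc_id: 1.0 for doc_id, _ in results}
    return {doc_id: (score - lo) / (hi - lo) for doc_id, score in results}


def weighted_avg(result_lists: List[Ranked], weights: List[float]) -> Dict[str, float]:
    """Weighted sum of min-max normalized scores."""
    fused: Dict[str, float] = defaultdict(float)
    for weight, results in zip(weights, result_lists):
        if not results:
            continue
        for doc_id, norm in min_max(results).items():
            fused[doc_id] += weight * norm
    return dict(fused)


def max_score(result_lists: List[Ranked]) -> Dict[str, float]:
    """Best normalized score per document; weights play no role."""
    fused: Dict[str, float] = defaultdict(float)
    for results in result_lists:
        if not results:
            continue
        for doc_id, norm in min_max(results).items():
            fused[doc_id] = max(fused[doc_id], norm)
    return dict(fused)


# Hypothetical hits on different score scales (cosine similarity vs. BM25)
vector_hits = [("a", 0.90), ("b", 0.80), ("c", 0.10)]
keyword_hits = [("b", 12.0), ("d", 7.0)]
weights = [0.5, 0.5]

rrf_scores = rrf([vector_hits, keyword_hits], weights)
avg_scores = weighted_avg([vector_hits, keyword_hits], weights)
max_scores = max_score([vector_hits, keyword_hits])
```

In this toy input, document `b` wins under `rrf` and `weighted_avg` because both retrievers return it, while under `max_score` document `a` ties with `b` even though only one retriever found it.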
References: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). Reciprocal rank fusion outperforms Condorcet and individual rank learning methods. ACM SIGIR.
Initialize ensemble retriever.
Args:
- retrievers: List of retrievers to combine
- weights: Optional weights for each retriever (default: equal weights)
- fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
- rrf_k: k parameter for RRF (default: 60)
Raises:
- ValueError: If retrievers list is empty or params are invalid
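The weight handling described above can be sketched as a standalone helper. This is an illustration only: `normalize_weights` is a hypothetical name, and the guard against a non-positive sum is an added assumption that the documented arguments do not mention.

```python
from typing import List, Optional


def normalize_weights(weights: Optional[List[float]], n_retrievers: int) -> List[float]:
    """Return one weight per retriever, scaled so the weights sum to 1.0."""
    if weights is None:
        # Default: every retriever contributes equally
        return [1.0 / n_retrievers] * n_retrievers
    if len(weights) != n_retrievers:
        raise ValueError(
            f"Number of weights ({len(weights)}) must match "
            f"number of retrievers ({n_retrievers})"
        )
    total = sum(weights)
    if total <= 0:
        # Assumption: reject weights that cannot be normalized
        raise ValueError("weights must sum to a positive value")
    return [w / total for w in weights]


equal = normalize_weights(None, 4)
scaled = normalize_weights([2.0, 1.0, 1.0], 3)
```

Because weights are rescaled, only their ratios matter: `[2, 1, 1]` and `[0.5, 0.25, 0.25]` describe the same ensemble.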
Retrieve documents using ensemble fusion.
Args: query: RetrievalQuery with parameters for all retrievers
Returns: List of SearchResult with fused scores and ranks
Note: All retrievers are called with the same query. Ensure the query contains all required parameters (text, embedding) for your retrievers.
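The note above matters because a retriever that rejects the query is skipped rather than failing the whole call. A minimal sketch of that behavior, assuming retrievers are plain callables that raise `ValueError` on a missing parameter (the names `gather`, `keyword_only`, and `vector_only` are hypothetical stand-ins):

```python
from typing import Callable, Dict, List, Optional

Query = Dict[str, Optional[object]]


def gather(retrievers: List[Callable[[Query], List[str]]], query: Query) -> List[List[str]]:
    """Call every retriever with the same query; a ValueError yields an empty list."""
    all_results: List[List[str]] = []
    for retrieve in retrievers:
        try:
            all_results.append(retrieve(query))
        except ValueError:
            # Keep the slot so result lists stay aligned with the weights
            all_results.append([])
    return all_results


def keyword_only(query: Query) -> List[str]:
    if query.get("text") is None:
        raise ValueError("keyword search requires query text")
    return ["kw-doc"]


def vector_only(query: Query) -> List[str]:
    if query.get("embedding") is None:
        raise ValueError("vector search requires a query embedding")
    return ["vec-doc"]


# The embedding is missing, so the vector retriever contributes nothing
partial = gather([keyword_only, vector_only], {"text": "machine learning", "embedding": None})
```

An incomplete query therefore degrades the ensemble silently to the retrievers that can still answer, which is why supplying both `text` and `embedding` is worthwhile.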
class HierarchicalRetriever(BaseRetriever):
    """
    Two-stage hierarchical retrieval for efficient search in large collections.

    Stage 1: Retrieve document summaries to identify relevant documents
    Stage 2: Retrieve detailed chunks from top documents only

    This reduces computational cost by focusing detailed retrieval on relevant documents.

    Example:
    ```python
    # Create summary and chunk retrievers
    summary_retriever = VectorRetriever(summary_store, embedder)
    chunk_retriever = VectorRetriever(chunk_store, embedder)

    # Create hierarchical retriever
    retriever = HierarchicalRetriever(
        summary_retriever=summary_retriever,
        chunk_retriever=chunk_retriever,
        stage1_top_k=10,  # Get top 10 documents
        stage2_top_k=5,   # Get 5 chunks per document
        document_id_field="document_id"
    )

    # Retrieve
    query = RetrievalQuery(text="machine learning", top_k=20)
    results = retriever.retrieve(query)
    ```
    """

    def __init__(
        self,
        summary_retriever: BaseRetriever,
        chunk_retriever: BaseRetriever,
        stage1_top_k: int = 10,
        stage2_top_k: int = 5,
        document_id_field: str = "document_id",
        combine_scores: bool = True,
        stage1_weight: float = 0.3,
        stage2_weight: float = 0.7
    ):
        """
        Initialize hierarchical retriever.

        Args:
            summary_retriever: Retriever for document summaries (stage 1)
            chunk_retriever: Retriever for detailed chunks (stage 2)
            stage1_top_k: Number of documents to retrieve in stage 1
            stage2_top_k: Number of chunks to retrieve per document in stage 2
            document_id_field: Metadata field linking chunks to documents
            combine_scores: If True, combine stage1 and stage2 scores
            stage1_weight: Weight for stage 1 score (document relevance)
            stage2_weight: Weight for stage 2 score (chunk relevance)
        """
        self.summary_retriever = summary_retriever
        self.chunk_retriever = chunk_retriever
        self.stage1_top_k = stage1_top_k
        self.stage2_top_k = stage2_top_k
        self.document_id_field = document_id_field
        self.combine_scores = combine_scores
        self.stage1_weight = stage1_weight
        self.stage2_weight = stage2_weight

        # Normalize weights
        total_weight = stage1_weight + stage2_weight
        if total_weight > 0:
            self.stage1_weight = stage1_weight / total_weight
            self.stage2_weight = stage2_weight / total_weight

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform two-stage hierarchical retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from stage 2, optionally with combined scores
        """
        # Stage 1: Retrieve document summaries
        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")

        stage1_query = RetrievalQuery(
            text=query.text,
            embedding=query.embedding,
            top_k=self.stage1_top_k,
            filters=query.filters,
            metadata=query.metadata
        )

        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)

        if not summary_results:
            logger.warning("Stage 1 returned no results")
            return []

        # Extract document IDs from summaries
        document_ids = []
        document_scores = {}

        for result in summary_results:
            doc_id = result.document.metadata.get(self.document_id_field)
            if doc_id:
                document_ids.append(doc_id)
                document_scores[doc_id] = result.score
            else:
                # If no document_id field, use the document id itself
                document_ids.append(result.document.id)
                document_scores[result.document.id] = result.score

        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")

        # Stage 2: Retrieve detailed chunks from top documents
        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")

        all_chunks: List[SearchResult] = []

        for doc_id in document_ids:
            # Create filter for this document
            doc_filter = {self.document_id_field: doc_id}

            # Merge with existing filters if any
            if query.filters:
                doc_filter.update(query.filters)

            # Retrieve chunks for this document
            stage2_query = RetrievalQuery(
                text=query.text,
                embedding=query.embedding,
                top_k=self.stage2_top_k,
                filters=doc_filter,
                metadata=query.metadata
            )

            try:
                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)

                if self.combine_scores and doc_id in document_scores:
                    # Combine stage 1 and stage 2 scores
                    for result in chunk_results:
                        combined_score = (
                            self.stage1_weight * document_scores[doc_id] +
                            self.stage2_weight * result.score
                        )
                        # Overwrite the chunk's score in place with the combined score
                        result.score = combined_score

                all_chunks.extend(chunk_results)

            except ValueError as e:
                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
                continue

        if not all_chunks:
            logger.warning("Stage 2 returned no chunks")
            return []

        # Sort by score and limit to top_k
        all_chunks.sort(key=lambda x: x.score, reverse=True)
        final_results = all_chunks[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
            f"{len(document_ids)} documents"
        )

        return final_results
Two-stage hierarchical retrieval for efficient search in large collections.
Stage 1: Retrieve document summaries to identify relevant documents
Stage 2: Retrieve detailed chunks from top documents only
This reduces computational cost by focusing detailed retrieval on relevant documents.
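The two stages can be sketched with in-memory data. This is a simplified illustration, assuming precomputed scores in place of real retrievers. `SUMMARY_SCORES`, `CHUNKS`, and `hierarchical_search` are hypothetical names; the 0.3/0.7 weights mirror the class defaults.

```python
from typing import Dict, List, Tuple

# Hypothetical stage 1 scores (document summaries) and stage 2 scores (chunks)
SUMMARY_SCORES: Dict[str, float] = {"doc_a": 0.9, "doc_b": 0.6, "doc_c": 0.2}
CHUNKS: Dict[str, List[Tuple[str, float]]] = {
    "doc_a": [("a1", 0.5), ("a2", 0.4)],
    "doc_b": [("b1", 0.95)],
    "doc_c": [("c1", 0.99)],
}


def hierarchical_search(
    stage1_top_k: int,
    stage2_top_k: int,
    top_k: int,
    stage1_weight: float = 0.3,
    stage2_weight: float = 0.7,
) -> List[Tuple[str, float]]:
    """Pick the best documents first, then score chunks only inside them."""
    total = stage1_weight + stage2_weight
    w1, w2 = stage1_weight / total, stage2_weight / total

    # Stage 1: keep only the highest-scoring documents
    top_docs = sorted(SUMMARY_SCORES.items(), key=lambda kv: kv[1], reverse=True)
    top_docs = top_docs[:stage1_top_k]

    # Stage 2: look at chunks of the surviving documents only
    combined: List[Tuple[str, float]] = []
    for doc_id, doc_score in top_docs:
        for chunk_id, chunk_score in CHUNKS.get(doc_id, [])[:stage2_top_k]:
            combined.append((chunk_id, w1 * doc_score + w2 * chunk_score))

    combined.sort(key=lambda item: item[1], reverse=True)
    return combined[:top_k]


results = hierarchical_search(stage1_top_k=2, stage2_top_k=1, top_k=5)
```

Chunk `c1` has the highest chunk score in the toy data but is never considered, because `doc_c` does not survive stage 1. That is the trade-off of this pattern: less work in stage 2, at the risk of missing chunks from documents whose summaries score poorly.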
Example:
# Create summary and chunk retrievers
summary_retriever = VectorRetriever(summary_store, embedder)
chunk_retriever = VectorRetriever(chunk_store, embedder)

# Create hierarchical retriever
retriever = HierarchicalRetriever(
    summary_retriever=summary_retriever,
    chunk_retriever=chunk_retriever,
    stage1_top_k=10,  # Get top 10 documents
    stage2_top_k=5,   # Get 5 chunks per document
    document_id_field="document_id"
)

# Retrieve
query = RetrievalQuery(text="machine learning", top_k=20)
results = retriever.retrieve(query)
Initialize hierarchical retriever.
Args:
- summary_retriever: Retriever for document summaries (stage 1)
- chunk_retriever: Retriever for detailed chunks (stage 2)
- stage1_top_k: Number of documents to retrieve in stage 1
- stage2_top_k: Number of chunks to retrieve per document in stage 2
- document_id_field: Metadata field linking chunks to documents
- combine_scores: If True, combine stage1 and stage2 scores
- stage1_weight: Weight for stage 1 score (document relevance)
- stage2_weight: Weight for stage 2 score (chunk relevance)
Perform two-stage hierarchical retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments
Returns: List of SearchResult objects from stage 2, optionally with combined scores
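As a rough illustration of the optional score combination, the sketch below blends a stage 1 document score with stage 2 chunk scores and re-sorts the chunks. The `combine` helper, the tuple layout, and the 0.3/0.7 weights are assumptions made for this example; they are not defaults of HierarchicalRetriever.

```python
from typing import Dict, List, Tuple


def combine(
    doc_scores: Dict[str, float],
    chunks: List[Tuple[str, str, float]],
    stage1_weight: float = 0.3,  # hypothetical weight for the summary score
    stage2_weight: float = 0.7,  # hypothetical weight for the chunk score
) -> List[Tuple[str, float]]:
    """Blend document-level and chunk-level scores, best first.

    `chunks` holds (chunk_id, parent_doc_id, chunk_score) tuples.
    """
    blended = [
        (chunk_id, stage1_weight * doc_scores[doc_id] + stage2_weight * score)
        for chunk_id, doc_id, score in chunks
    ]
    return sorted(blended, key=lambda item: item[1], reverse=True)


doc_scores = {"doc_a": 0.9, "doc_b": 0.5}
chunks = [("a1", "doc_a", 0.6), ("b1", "doc_b", 0.8), ("a2", "doc_a", 0.4)]
ranked = combine(doc_scores, chunks)
```

With these numbers the strong chunk from the weaker document (`b1`) still ranks first, which shows how the stage 2 weight controls how much a good chunk can outweigh its parent's summary score.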
class GraphRetriever(BaseRetriever):
    """
    Graph-based retrieval using entity relationships.

    This retriever:
    1. Extracts entities from the query
    2. Finds matching entities in the knowledge graph
    3. Traverses the graph to find related entities
    4. Retrieves documents associated with relevant entities

    Example:
        ```python
        # Create knowledge graph
        graph = nx.DiGraph()
        graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
        graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

        # Create entity-document mapping
        entity_docs = {
            "Python": ["doc1", "doc2"],
            "Machine Learning": ["doc3", "doc4"],
            "Neural Networks": ["doc5"]
        }

        # Create retriever
        retriever = GraphRetriever(
            vector_store=store,
            knowledge_graph=graph,
            entity_document_mapping=entity_docs,
            embedder=embedder,
            max_hops=2
        )

        # Retrieve
        query = RetrievalQuery(text="Python for deep learning", top_k=5)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        vector_store: BaseVectorStore,
        knowledge_graph: Optional['nx.DiGraph'] = None,
        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
        embedder: Optional[Any] = None,
        max_hops: int = 2,
        min_relation_weight: float = 0.5,
        combine_vector_scores: bool = True,
        graph_weight: float = 0.4,
        vector_weight: float = 0.6
    ):
        """
        Initialize graph retriever.

        Args:
            vector_store: Vector store for document retrieval
            knowledge_graph: NetworkX DiGraph with entity relationships
            entity_document_mapping: Dict mapping entity IDs to document IDs
            embedder: Embedding provider for query encoding
            max_hops: Maximum number of hops in graph traversal
            min_relation_weight: Minimum weight for relationship edges
            combine_vector_scores: If True, combine graph and vector scores
            graph_weight: Weight for graph-based relevance
            vector_weight: Weight for vector similarity
        """
        if not NETWORKX_AVAILABLE:
            raise ImportError(
                "NetworkX is required for GraphRetriever. "
                "Install with: pip install networkx"
            )

        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph or nx.DiGraph()
        self.entity_document_mapping = entity_document_mapping or {}
        self.embedder = embedder
        self.max_hops = max_hops
        self.min_relation_weight = min_relation_weight
        self.combine_vector_scores = combine_vector_scores
        self.graph_weight = graph_weight
        self.vector_weight = vector_weight

        # Normalize weights
        total_weight = graph_weight + vector_weight
        if total_weight > 0:
            self.graph_weight = graph_weight / total_weight
            self.vector_weight = vector_weight / total_weight

    def extract_entities(self, query_text: str) -> List[str]:
        """
        Extract entities from query text.

        Simple implementation: looks for entity names in the knowledge graph.
        For production, use NER models (spaCy, Hugging Face, etc.).

        Args:
            query_text: Query text

        Returns:
            List of entity IDs found in the query
        """
        query_lower = query_text.lower()
        entities = []

        # Simple matching: check if entity names appear in query
        for entity_id in self.knowledge_graph.nodes():
            entity_name = str(entity_id).lower()
            if entity_name in query_lower:
                entities.append(entity_id)

        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
        return entities

    def traverse_graph(
        self,
        seed_entities: List[str],
        max_hops: int
    ) -> Dict[str, float]:
        """
        Traverse knowledge graph from seed entities.

        Args:
            seed_entities: Starting entities
            max_hops: Maximum number of hops

        Returns:
            Dict mapping entity IDs to relevance scores (0-1)
        """
        entity_scores: Dict[str, float] = {}

        # Initialize seed entities with score 1.0
        for entity in seed_entities:
            if entity in self.knowledge_graph:
                entity_scores[entity] = 1.0

        if not entity_scores:
            return entity_scores

        # BFS traversal
        visited: Set[str] = set()
        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)

        while current_level:
            next_level = []

            for entity_id, score, hop in current_level:
                if entity_id in visited or hop >= max_hops:
                    continue

                visited.add(entity_id)

                # Get neighbors
                if entity_id not in self.knowledge_graph:
                    continue

                for neighbor in self.knowledge_graph.neighbors(entity_id):
                    # Get edge weight
                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0

                    # Skip weak relationships
                    if edge_weight < self.min_relation_weight:
                        continue

                    # Calculate neighbor score (decay with hops)
                    decay_factor = 0.7 ** (hop + 1)
                    neighbor_score = score * edge_weight * decay_factor

                    # Update score if better
                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
                        entity_scores[neighbor] = neighbor_score
                        next_level.append((neighbor, neighbor_score, hop + 1))

            current_level = next_level

        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
        return entity_scores

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform graph-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects ranked by combined graph + vector scores
        """
        if not query.text:
            raise ValueError("GraphRetriever requires query.text")

        # Step 1: Extract entities from query
        seed_entities = self.extract_entities(query.text)

        if not seed_entities:
            logger.warning("No entities found in query, falling back to vector search")
            # Fallback to pure vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 2: Traverse graph to find related entities
        entity_scores = self.traverse_graph(seed_entities, self.max_hops)

        # Step 3: Collect documents associated with relevant entities
        doc_graph_scores: Dict[str, float] = {}

        for entity_id, entity_score in entity_scores.items():
            doc_ids = self.entity_document_mapping.get(entity_id, [])
            for doc_id in doc_ids:
                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
                    doc_graph_scores[doc_id] = entity_score

        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")

        if not doc_graph_scores:
            logger.warning("No documents found via graph, falling back to vector search")
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 4: Get vector scores if combining
        if self.combine_vector_scores:
            # Get embeddings for vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            # Retrieve more documents for scoring
            vector_results = self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k * 3,  # Get more for better coverage
                filters=query.filters
            )

            # Create document ID to vector score mapping
            doc_vector_scores: Dict[str, float] = {}
            for result in vector_results:
                doc_vector_scores[result.document.id] = result.score

            # Combine scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                # Get document from vector store
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                # Get vector score (0 if not found)
                vector_score = doc_vector_scores.get(doc_id, 0.0)

                # Combine scores
                combined_score = (
                    self.graph_weight * graph_score +
                    self.vector_weight * vector_score
                )

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=combined_score,
                        rank=0  # Will be set later
                    )
                )
        else:
            # Use only graph scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=graph_score,
                        rank=0
                    )
                )

        # Sort by score and limit to top_k
        combined_results.sort(key=lambda x: x.score, reverse=True)
        final_results = combined_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(f"Graph retrieval complete: {len(final_results)} results")

        return final_results
Graph-based retrieval using entity relationships.
This retriever:
- Extracts entities from the query
- Finds matching entities in the knowledge graph
- Traverses the graph to find related entities
- Retrieves documents associated with relevant entities
Example:
# Create knowledge graph
graph = nx.DiGraph()
graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

# Create entity-document mapping
entity_docs = {
    "Python": ["doc1", "doc2"],
    "Machine Learning": ["doc3", "doc4"],
    "Neural Networks": ["doc5"]
}

# Create retriever
retriever = GraphRetriever(
    vector_store=store,
    knowledge_graph=graph,
    entity_document_mapping=entity_docs,
    embedder=embedder,
    max_hops=2
)

# Retrieve
query = RetrievalQuery(text="Python for deep learning", top_k=5)
results = retriever.retrieve(query)
    def __init__(
        self,
        vector_store: BaseVectorStore,
        knowledge_graph: Optional['nx.DiGraph'] = None,
        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
        embedder: Optional[Any] = None,
        max_hops: int = 2,
        min_relation_weight: float = 0.5,
        combine_vector_scores: bool = True,
        graph_weight: float = 0.4,
        vector_weight: float = 0.6
    ):
        """
        Initialize graph retriever.

        Args:
            vector_store: Vector store for document retrieval
            knowledge_graph: NetworkX DiGraph with entity relationships
            entity_document_mapping: Dict mapping entity IDs to document IDs
            embedder: Embedding provider for query encoding
            max_hops: Maximum number of hops in graph traversal
            min_relation_weight: Minimum weight for relationship edges
            combine_vector_scores: If True, combine graph and vector scores
            graph_weight: Weight for graph-based relevance
            vector_weight: Weight for vector similarity
        """
        if not NETWORKX_AVAILABLE:
            raise ImportError(
                "NetworkX is required for GraphRetriever. "
                "Install with: pip install networkx"
            )

        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph or nx.DiGraph()
        self.entity_document_mapping = entity_document_mapping or {}
        self.embedder = embedder
        self.max_hops = max_hops
        self.min_relation_weight = min_relation_weight
        self.combine_vector_scores = combine_vector_scores
        self.graph_weight = graph_weight
        self.vector_weight = vector_weight

        # Normalize weights
        total_weight = graph_weight + vector_weight
        if total_weight > 0:
            self.graph_weight = graph_weight / total_weight
            self.vector_weight = vector_weight / total_weight
Initialize graph retriever.
Args:
    vector_store: Vector store for document retrieval
    knowledge_graph: NetworkX DiGraph with entity relationships
    entity_document_mapping: Dict mapping entity IDs to document IDs
    embedder: Embedding provider for query encoding
    max_hops: Maximum number of hops in graph traversal
    min_relation_weight: Minimum weight for relationship edges
    combine_vector_scores: If True, combine graph and vector scores
    graph_weight: Weight for graph-based relevance
    vector_weight: Weight for vector similarity
    def extract_entities(self, query_text: str) -> List[str]:
        """
        Extract entities from query text.

        Simple implementation: looks for entity names in the knowledge graph.
        For production, use NER models (spaCy, Hugging Face, etc.).

        Args:
            query_text: Query text

        Returns:
            List of entity IDs found in the query
        """
        query_lower = query_text.lower()
        entities = []

        # Simple matching: check if entity names appear in query
        for entity_id in self.knowledge_graph.nodes():
            entity_name = str(entity_id).lower()
            if entity_name in query_lower:
                entities.append(entity_id)

        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
        return entities
Extract entities from query text.
Simple implementation: looks for entity names in the knowledge graph. For production, use NER models (spaCy, Hugging Face, etc.).
Args: query_text: Query text
Returns: List of entity IDs found in the query
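Because the simple implementation uses plain substring matching, a short entity name such as "Go" would also match inside "good". The sketch below shows one possible alternative that matches whole words only; `match_entities` is a hypothetical helper written for this example, not part of GraphRetriever.

```python
import re
from typing import Iterable, List


def match_entities(query_text: str, entity_names: Iterable[str]) -> List[str]:
    """Return the entity names that occur in the query as whole words or phrases."""
    found = []
    for name in entity_names:
        # \b anchors stop "Go" from matching inside "good" or "algorithm".
        pattern = r"\b" + re.escape(str(name)) + r"\b"
        if re.search(pattern, query_text, flags=re.IGNORECASE):
            found.append(name)
    return found


names = ["Go", "Python", "Machine Learning"]
query = "Is Python good for machine learning?"
hits = match_entities(query, names)

# The plain substring test used by the simple implementation would accept "Go":
substring_hit = "go" in query.lower()
```

Word-boundary anchors only behave as expected for names that begin and end with a letter, digit, or underscore; a name such as "C++" would need different handling, which is one reason the docstring points to NER models for production use.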
    def traverse_graph(
        self,
        seed_entities: List[str],
        max_hops: int
    ) -> Dict[str, float]:
        """
        Traverse knowledge graph from seed entities.

        Args:
            seed_entities: Starting entities
            max_hops: Maximum number of hops

        Returns:
            Dict mapping entity IDs to relevance scores (0-1)
        """
        entity_scores: Dict[str, float] = {}

        # Initialize seed entities with score 1.0
        for entity in seed_entities:
            if entity in self.knowledge_graph:
                entity_scores[entity] = 1.0

        if not entity_scores:
            return entity_scores

        # BFS traversal
        visited: Set[str] = set()
        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)

        while current_level:
            next_level = []

            for entity_id, score, hop in current_level:
                if entity_id in visited or hop >= max_hops:
                    continue

                visited.add(entity_id)

                # Get neighbors
                if entity_id not in self.knowledge_graph:
                    continue

                for neighbor in self.knowledge_graph.neighbors(entity_id):
                    # Get edge weight
                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0

                    # Skip weak relationships
                    if edge_weight < self.min_relation_weight:
                        continue

                    # Calculate neighbor score (decay with hops)
                    decay_factor = 0.7 ** (hop + 1)
                    neighbor_score = score * edge_weight * decay_factor

                    # Update score if better
                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
                        entity_scores[neighbor] = neighbor_score
                        next_level.append((neighbor, neighbor_score, hop + 1))

            current_level = next_level

        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
        return entity_scores
Traverse knowledge graph from seed entities.
Args:
    seed_entities: Starting entities
    max_hops: Maximum number of hops
Returns: Dict mapping entity IDs to relevance scores (0-1)
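To make the scoring concrete, the sketch below re-implements the same breadth-first scoring rule on a plain dictionary instead of a NetworkX graph, so it runs without extra dependencies. `score_entities` and the `Graph` alias are names invented for this illustration; the 0.7 decay base and the 0.5 minimum edge weight mirror the values in the listing above.

```python
from typing import Dict, List, Tuple

Graph = Dict[str, Dict[str, float]]  # node -> {neighbor: edge weight}


def score_entities(
    graph: Graph,
    seeds: List[str],
    max_hops: int = 2,
    min_weight: float = 0.5,
    decay: float = 0.7,
) -> Dict[str, float]:
    """Score nodes reachable from the seeds; seeds start at 1.0."""
    scores = {seed: 1.0 for seed in seeds if seed in graph}
    visited = set()
    level: List[Tuple[str, float, int]] = [(seed, 1.0, 0) for seed in scores]

    while level:
        next_level = []
        for node, score, hop in level:
            if node in visited or hop >= max_hops:
                continue
            visited.add(node)
            for neighbor, weight in graph.get(node, {}).items():
                if weight < min_weight:
                    continue  # skip weak relationships
                # The parent score already carries earlier decay, so the
                # penalty compounds with every additional hop.
                candidate = score * weight * decay ** (hop + 1)
                if neighbor not in scores or candidate > scores[neighbor]:
                    scores[neighbor] = candidate
                    next_level.append((neighbor, candidate, hop + 1))
        level = next_level
    return scores


graph = {
    "Python": {"Machine Learning": 0.9},
    "Machine Learning": {"Neural Networks": 0.8},
    "Neural Networks": {},
}
scores = score_entities(graph, ["Python"])
```

For the seed "Python", "Machine Learning" receives 1.0 × 0.9 × 0.7 = 0.63, and "Neural Networks" receives 0.63 × 0.8 × 0.7² ≈ 0.247. Second-hop entities are therefore penalized more steeply than a single factor of 0.7 per hop would suggest.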
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform graph-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects ranked by combined graph + vector scores
        """
        if not query.text:
            raise ValueError("GraphRetriever requires query.text")

        # Step 1: Extract entities from query
        seed_entities = self.extract_entities(query.text)

        if not seed_entities:
            logger.warning("No entities found in query, falling back to vector search")
            # Fallback to pure vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 2: Traverse graph to find related entities
        entity_scores = self.traverse_graph(seed_entities, self.max_hops)

        # Step 3: Collect documents associated with relevant entities
        doc_graph_scores: Dict[str, float] = {}

        for entity_id, entity_score in entity_scores.items():
            doc_ids = self.entity_document_mapping.get(entity_id, [])
            for doc_id in doc_ids:
                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
                    doc_graph_scores[doc_id] = entity_score

        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")

        if not doc_graph_scores:
            logger.warning("No documents found via graph, falling back to vector search")
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 4: Get vector scores if combining
        if self.combine_vector_scores:
            # Get embeddings for vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            # Retrieve more documents for scoring
            vector_results = self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k * 3,  # Get more for better coverage
                filters=query.filters
            )

            # Create document ID to vector score mapping
            doc_vector_scores: Dict[str, float] = {}
            for result in vector_results:
                doc_vector_scores[result.document.id] = result.score

            # Combine scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                # Get document from vector store
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                # Get vector score (0 if not found)
                vector_score = doc_vector_scores.get(doc_id, 0.0)

                # Combine scores
                combined_score = (
                    self.graph_weight * graph_score +
                    self.vector_weight * vector_score
                )

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=combined_score,
                        rank=0  # Will be set later
                    )
                )
        else:
            # Use only graph scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=graph_score,
                        rank=0
                    )
                )

        # Sort by score and limit to top_k
        combined_results.sort(key=lambda x: x.score, reverse=True)
        final_results = combined_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(f"Graph retrieval complete: {len(final_results)} results")

        return final_results
Perform graph-based retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments
Returns: List of SearchResult objects ranked by combined graph + vector scores
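The final ranking step can be sketched in isolation. In the example below, `blend_scores` is a hypothetical stand-alone function that mirrors the weighting described above: weights are normalized to sum to 1, only documents reached through the graph are candidates, and a candidate missing from the vector results contributes a vector score of 0.

```python
from typing import Dict, List, Tuple


def blend_scores(
    graph_scores: Dict[str, float],
    vector_scores: Dict[str, float],
    graph_weight: float = 0.4,
    vector_weight: float = 0.6,
) -> List[Tuple[str, float]]:
    """Rank graph-derived documents by a weighted graph + vector score."""
    total = graph_weight + vector_weight
    if total > 0:
        graph_weight, vector_weight = graph_weight / total, vector_weight / total

    blended = {
        doc_id: graph_weight * g_score + vector_weight * vector_scores.get(doc_id, 0.0)
        for doc_id, g_score in graph_scores.items()
    }
    return sorted(blended.items(), key=lambda item: item[1], reverse=True)


graph_scores = {"doc1": 1.0, "doc3": 0.63, "doc5": 0.25}
vector_scores = {"doc3": 0.9, "doc9": 0.95}  # doc9 was never reached via the graph
ranked = blend_scores(graph_scores, vector_scores)
```

Here `doc9` does not appear in the output despite its high vector score, because the candidate set comes from the graph alone, and `doc1` drops below `doc3` because it has no vector score to add.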
class SQLRetriever(BaseRetriever):
    """
    Retrieval from structured databases using SQL queries.

    This retriever:
    1. Converts natural language queries to SQL (via LLM or rule-based)
    2. Executes SQL against a database
    3. Converts results to Document objects
    4. Returns SearchResult objects with relevance scores

    Example:
        ```python
        import sqlite3

        # Create database connection
        conn = sqlite3.connect("products.db")

        # Define schema
        schema = SQLSchema(
            table_name="products",
            columns=[
                {"name": "id", "type": "INTEGER", "description": "Product ID"},
                {"name": "name", "type": "TEXT", "description": "Product name"},
                {"name": "price", "type": "REAL", "description": "Price in USD"},
                {"name": "category", "type": "TEXT", "description": "Product category"}
            ],
            primary_key="id",
            description="E-commerce product catalog"
        )

        # Create retriever
        retriever = SQLRetriever(
            db_connection=conn,
            schema=schema,
            text_to_sql_fn=my_text_to_sql_function,
            db_type="sqlite"
        )

        # Retrieve
        query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score

    def _simple_text_to_sql(self, query_text: str, schema: SQLSchema) -> SQLQuery:
        """
        Simple rule-based text-to-SQL conversion.

        For production, use LLM-based conversion or specialized libraries.

        Args:
            query_text: Natural language query
            schema: Database schema

        Returns:
            SQLQuery object
        """
        import re

        # Extract keywords for filtering
        query_lower = query_text.lower()

        # Build SELECT clause
        if self.content_columns:
            columns = ", ".join(self.content_columns)
        else:
            columns = "*"

        # Add score column if specified
        if self.score_column:
            columns = f"{columns}, {self.score_column}"

        # Build basic SELECT
        sql = f"SELECT {columns} FROM {schema.table_name}"

        # Simple WHERE clause based on keywords
        conditions = []

        # Look for numeric comparisons against a price-like column
        price_cols = [
            c["name"] for c in schema.columns
            if "price" in c["name"].lower() or "cost" in c["name"].lower()
        ]
        numbers = re.findall(r'\d+', query_text)

        if price_cols and numbers:
            if "under" in query_lower or "less than" in query_lower or "<" in query_lower:
                conditions.append(f"{price_cols[0]} < {numbers[0]}")

            if "over" in query_lower or "more than" in query_lower or ">" in query_lower:
                conditions.append(f"{price_cols[0]} > {numbers[0]}")

        # Look for text columns to search
        text_cols = [c["name"] for c in schema.columns if c["type"].upper() in ["TEXT", "VARCHAR", "STRING"]]

        # Build LIKE conditions for text search
        search_terms = []
        # Remove common words
        stop_words = {"in", "the", "a", "an", "under", "over", "less", "more", "than", "products", "items"}
        words = [w for w in query_lower.split() if w not in stop_words and not w.isdigit()]

        for word in words:
            if len(word) > 2:  # Skip very short words
                search_terms.append(word)

        if search_terms and text_cols:
            like_conditions = []
            for col in text_cols:
                for term in search_terms:
                    # Double single quotes so a term cannot terminate the string
                    # literal; bound parameters are preferable where available
                    safe_term = term.replace("'", "''")
                    like_conditions.append(f"{col} LIKE '%{safe_term}%'")

            if like_conditions:
                conditions.append(f"({' OR '.join(like_conditions)})")

        # Add WHERE clause if conditions exist
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)

        # Add LIMIT
        sql += " LIMIT 100"  # Safety limit

        return SQLQuery(
            sql=sql,
            explanation="Simple keyword-based SQL generation"
        )

    def _execute_sql(self, sql_query: SQLQuery) -> List[Dict[str, Any]]:
        """
        Execute SQL query and return results.

        Args:
            sql_query: SQL query to execute

        Returns:
            List of result rows as dictionaries
        """
        cursor = self.db_connection.cursor()

        try:
            logger.info(f"Executing SQL: {sql_query.sql}")

            if sql_query.parameters:
                cursor.execute(sql_query.sql, sql_query.parameters)
            else:
                cursor.execute(sql_query.sql)

            # Get column names
            if cursor.description:
                columns = [desc[0] for desc in cursor.description]

                # Fetch results
                rows = cursor.fetchall()

                # Convert to dictionaries
                results = []
                for row in rows:
                    result_dict = dict(zip(columns, row))
                    results.append(result_dict)

                logger.info(f"SQL returned {len(results)} rows")
                return results
            else:
                return []

        except Exception as e:
            logger.error(f"SQL execution failed: {e}")
            raise
        finally:
            cursor.close()

    def _row_to_document(self, row: Dict[str, Any], index: int) -> Document:
        """
        Convert database row to Document object.

        Args:
            row: Database row as dictionary
            index: Row index for ID generation

        Returns:
            Document object
        """
        # Extract score if available
        score = self.default_score
        if self.score_column and self.score_column in row:
            score = float(row[self.score_column])

        # Build content from specified columns or all columns
        if self.content_columns:
            content_parts = []
            for col in self.content_columns:
                if col in row:
                    content_parts.append(f"{col}: {row[col]}")
            content = "\n".join(content_parts)
        else:
            # Use JSON representation
            content = json.dumps(row, indent=2, default=str)

        # Use primary key if available, otherwise use index
        doc_id = str(row.get(self.schema.primary_key, f"sql_result_{index}"))

        # Store all row data in metadata
        metadata = {
            "source": "sql_database",
            "table": self.schema.table_name,
            "row_data": row,
            "sql_score": score
        }

        return Document(
            id=doc_id,
            content=content,
            metadata=metadata
        )

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results
Retrieval from structured databases using SQL queries.
This retriever:
- Converts natural language queries to SQL (via LLM or rule-based)
- Executes SQL against a database
- Converts results to Document objects
- Returns SearchResult objects with relevance scores
Example:
import sqlite3

# Create database connection
conn = sqlite3.connect("products.db")

# Define schema
schema = SQLSchema(
    table_name="products",
    columns=[
        {"name": "id", "type": "INTEGER", "description": "Product ID"},
        {"name": "name", "type": "TEXT", "description": "Product name"},
        {"name": "price", "type": "REAL", "description": "Price in USD"},
        {"name": "category", "type": "TEXT", "description": "Product category"}
    ],
    primary_key="id",
    description="E-commerce product catalog"
)

# Create retriever
retriever = SQLRetriever(
    db_connection=conn,
    schema=schema,
    text_to_sql_fn=my_text_to_sql_function,
    db_type="sqlite"
)

# Retrieve
query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
results = retriever.retrieve(query)
    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score
Initialize SQL retriever.
Args:
    db_connection: Database connection object (e.g., sqlite3.Connection)
    schema: Database schema information
    text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
    db_type: Database type ("sqlite", "postgresql", "mysql")
    content_columns: Columns to use for document content (None = all columns)
    score_column: Column to use for relevance score (None = use default_score)
    default_score: Default score when no score_column specified
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results
Perform SQL-based retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments
Returns: List of SearchResult objects from database query
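The listing above keeps rows in the order the database returned them: scores are attached but the results are never re-sorted, so any ordering has to come from the generated SQL (for example an `ORDER BY` clause). The sketch below illustrates that behaviour under stated assumptions: `RankedRow` is a hypothetical stand-in for `SearchResult`, and reading the score straight from `score_column` is a guess at what the unshown `_row_to_document` helper stores under `sql_score`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class RankedRow:  # hypothetical stand-in for SearchResult
    row: Dict[str, Any]
    score: float
    rank: int


def rank_rows(
    rows: List[Dict[str, Any]],
    top_k: int,
    score_column: Optional[str] = None,
    default_score: float = 1.0,
) -> List[RankedRow]:
    """Attach a score to each row, keep database order, truncate to top_k."""
    results = []
    for i, row in enumerate(rows):
        score = default_score
        # Fall back to the default when the column is unset or NULL.
        if score_column is not None and row.get(score_column) is not None:
            score = float(row[score_column])
        results.append(RankedRow(row=row, score=score, rank=i + 1))
    return results[:top_k]


rows = [
    {"id": 1, "rating": 0.2},
    {"id": 2, "rating": 0.9},
    {"id": 3, "rating": None},
]
ranked = rank_rows(rows, top_k=2, score_column="rating")
```

Here row 1 stays at rank 1 even though row 2 has the higher score, which is why a score-ordered result needs the ordering expressed in the SQL itself.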
class MultiIndexRetriever(BaseRetriever):
    """
    Retrieve from multiple indices/sources and merge results.

    This retriever:
    1. Queries each enabled retriever in turn
    2. Merges results from all sources
    3. Applies source-specific weights and boosts
    4. Re-ranks using RRF, weighted-average, or max-score fusion

    Use cases:
    - Multi-domain search (HR + Finance + IT)
    - Federated search across multiple databases
    - Combining different data sources (documents + SQL + graph)

    Example:
        ```python
        # Create retrievers for different sources
        hr_retriever = VectorRetriever(hr_store, embedder)
        finance_retriever = VectorRetriever(finance_store, embedder)
        it_retriever = VectorRetriever(it_store, embedder)

        # Create multi-index retriever
        retriever = MultiIndexRetriever(
            sources=[
                SourceConfig("HR", hr_retriever, weight=1.0),
                SourceConfig("Finance", finance_retriever, weight=1.5),
                SourceConfig("IT", it_retriever, weight=1.0)
            ],
            fusion_strategy="rrf"
        )

        # Retrieve across all sources
        query = RetrievalQuery(text="employee benefits policy", top_k=10)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        sources: List[SourceConfig],
        fusion_strategy: str = "rrf",
        rrf_k: int = 60,
        normalize_scores: bool = True
    ):
        """
        Initialize multi-index retriever.

        Args:
            sources: List of SourceConfig objects defining retrievers
            fusion_strategy: Strategy for merging results:
                - "rrf": Reciprocal Rank Fusion
                - "weighted_average": Weighted average of normalized scores
                - "max_score": Take maximum score across sources
            rrf_k: Constant for RRF formula (default 60)
            normalize_scores: Normalize scores to [0,1] before fusion
        """
        self.sources = sources
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k
        self.normalize_scores = normalize_scores

        # Validate sources
        if not sources:
            raise ValueError("At least one source must be provided")

        # Normalize weights (mutates the SourceConfig objects in place)
        total_weight = sum(s.weight for s in sources if s.enabled)
        if total_weight > 0:
            for source in sources:
                source.weight = source.weight / total_weight

    def _normalize_scores_for_source(
        self,
        results: List[SearchResult]
    ) -> List[SearchResult]:
        """
        Normalize scores to [0, 1] range.

        Args:
            results: Search results

        Returns:
            Results with normalized scores
        """
        if not results:
            return results

        scores = [r.score for r in results]
        min_score = min(scores)
        max_score = max(scores)

        if max_score - min_score == 0:
            # All scores are the same
            for result in results:
                result.score = 1.0
        else:
            for result in results:
                result.score = (result.score - min_score) / (max_score - min_score)

        return results

    def _rrf_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using Reciprocal Rank Fusion.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their RRF scores
        doc_rrf_scores: Dict[str, float] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Calculate RRF scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            weight = source_config.weight
            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id

                # RRF score: (weight * boost) / (k + rank)
                rrf_score = (weight * boost) / (self.rrf_k + result.rank)

                if doc_id not in doc_rrf_scores:
                    doc_rrf_scores[doc_id] = rrf_score
                    doc_objects[doc_id] = result
                    # Add source metadata
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    doc_rrf_scores[doc_id] += rrf_score
                    # If document appears in multiple sources, mark it
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results
        merged_results = []
        for doc_id, rrf_score in doc_rrf_scores.items():
            result = doc_objects[doc_id]
            result.score = rrf_score
            merged_results.append(result)

        # Sort by RRF score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def _weighted_average_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using weighted average of scores.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their weighted scores
        doc_weighted_scores: Dict[str, float] = {}
        doc_score_counts: Dict[str, int] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Calculate weighted scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            # Normalize scores for this source
            if self.normalize_scores:
                results = self._normalize_scores_for_source(results)

            weight = source_config.weight
            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id
                weighted_score = result.score * weight * boost

                if doc_id not in doc_weighted_scores:
                    doc_weighted_scores[doc_id] = weighted_score
                    doc_score_counts[doc_id] = 1
                    doc_objects[doc_id] = result
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    doc_weighted_scores[doc_id] += weighted_score
                    doc_score_counts[doc_id] += 1
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results with averaged scores
        merged_results = []
        for doc_id, total_score in doc_weighted_scores.items():
            result = doc_objects[doc_id]
            # Average the weighted scores
            result.score = total_score / doc_score_counts[doc_id]
            merged_results.append(result)

        # Sort by score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def _max_score_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using maximum score across sources.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their max scores
        doc_max_scores: Dict[str, float] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Find max scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            # Normalize scores for this source
            if self.normalize_scores:
                results = self._normalize_scores_for_source(results)

            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id
                boosted_score = result.score * boost

                if doc_id not in doc_max_scores or boosted_score > doc_max_scores[doc_id]:
                    doc_max_scores[doc_id] = boosted_score
                    doc_objects[doc_id] = result
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    # Update source metadata if from multiple sources
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results
        merged_results = []
        for doc_id, max_score in doc_max_scores.items():
            result = doc_objects[doc_id]
            result.score = max_score
            merged_results.append(result)

        # Sort by score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform multi-source retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments passed to each retriever

        Returns:
            List of SearchResult objects merged from all sources
        """
        # Retrieve from each source
        source_results: Dict[str, List[SearchResult]] = {}

        for source in self.sources:
            if not source.enabled:
                logger.info(f"Skipping disabled source: {source.name}")
                continue

            try:
                logger.info(f"Retrieving from source: {source.name}")
                results = source.retriever.retrieve(query, **kwargs)
                source_results[source.name] = results
                logger.info(f"Source {source.name} returned {len(results)} results")

            except Exception as e:
                logger.warning(f"Retrieval from source {source.name} failed: {e}")
                source_results[source.name] = []

        # Check if we got any results
        total_results = sum(len(results) for results in source_results.values())
        if total_results == 0:
            logger.warning("No results from any source")
            return []

        # Merge results based on fusion strategy
        logger.info(f"Merging results using {self.fusion_strategy} strategy")

        if self.fusion_strategy == "rrf":
            merged_results = self._rrf_fusion(source_results)
        elif self.fusion_strategy == "weighted_average":
            merged_results = self._weighted_average_fusion(source_results)
        elif self.fusion_strategy == "max_score":
            merged_results = self._max_score_fusion(source_results)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")

        # Limit to top_k
        final_results = merged_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Multi-index retrieval complete: {len(final_results)} results "
            f"from {len(source_results)} sources"
        )

        return final_results
Retrieve from multiple indices/sources and merge results.
This retriever:
- Queries each enabled retriever in turn
- Merges results from all sources
- Applies source-specific weights and boosts
- Re-ranks using RRF, weighted-average, or max-score fusion
Use cases:
- Multi-domain search (HR + Finance + IT)
- Federated search across multiple databases
- Combining different data sources (documents + SQL + graph)
Example:
    # Create retrievers for different sources
    hr_retriever = VectorRetriever(hr_store, embedder)
    finance_retriever = VectorRetriever(finance_store, embedder)
    it_retriever = VectorRetriever(it_store, embedder)

    # Create multi-index retriever
    retriever = MultiIndexRetriever(
        sources=[
            SourceConfig("HR", hr_retriever, weight=1.0),
            SourceConfig("Finance", finance_retriever, weight=1.5),
            SourceConfig("IT", it_retriever, weight=1.0)
        ],
        fusion_strategy="rrf"
    )

    # Retrieve across all sources
    query = RetrievalQuery(text="employee benefits policy", top_k=10)
    results = retriever.retrieve(query)
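The constructor rescales the weights of the configured sources so that the enabled ones sum to 1. The short calculation below works through the weights used in the example (1.0, 1.5, 1.0); it uses plain dictionaries rather than the library's classes, so the variable names are illustrative only.

```python
# Weights as passed to SourceConfig in the example above.
weights = {"HR": 1.0, "Finance": 1.5, "IT": 1.0}

# Same rule as the constructor: divide each weight by the total.
total = sum(weights.values())
normalized = {name: weight / total for name, weight in weights.items()}

for name, weight in normalized.items():
    print(f"{name}: {weight:.4f}")
# HR: 0.2857
# Finance: 0.4286
# IT: 0.2857
```

Only the ratios between weights matter, so `weight=1.5` means Finance counts one and a half times as much as HR or IT, not an absolute multiplier of 1.5.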
    def __init__(
        self,
        sources: List[SourceConfig],
        fusion_strategy: str = "rrf",
        rrf_k: int = 60,
        normalize_scores: bool = True
    ):
        """
        Initialize multi-index retriever.

        Args:
            sources: List of SourceConfig objects defining retrievers
            fusion_strategy: Strategy for merging results:
                - "rrf": Reciprocal Rank Fusion
                - "weighted_average": Weighted average of normalized scores
                - "max_score": Take maximum score across sources
            rrf_k: Constant for RRF formula (default 60)
            normalize_scores: Normalize scores to [0,1] before fusion
        """
        self.sources = sources
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k
        self.normalize_scores = normalize_scores

        # Validate sources
        if not sources:
            raise ValueError("At least one source must be provided")

        # Normalize weights
        total_weight = sum(s.weight for s in sources if s.enabled)
        if total_weight > 0:
            for source in sources:
                source.weight = source.weight / total_weight
Initialize multi-index retriever.
Args:
    sources: List of SourceConfig objects defining retrievers
    fusion_strategy: Strategy for merging results:
        - "rrf": Reciprocal Rank Fusion
        - "weighted_average": Weighted average of normalized scores
        - "max_score": Take maximum score across sources
    rrf_k: Constant for RRF formula (default 60)
    normalize_scores: Normalize scores to [0,1] before fusion
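To make the `"rrf"` strategy and the `rrf_k` constant concrete, here is a minimal sketch of weighted Reciprocal Rank Fusion over plain document IDs. It follows the formula in the class source, where each hit contributes `weight / (k + rank)`, but leaves out `boost_factor` and metadata tracking for brevity. `weighted_rrf` is a hypothetical helper, not part of the library.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def weighted_rrf(
    rankings: Dict[str, List[str]],
    weights: Dict[str, float],
    k: int = 60,
) -> List[Tuple[str, float]]:
    """Fuse per-source rankings; each hit adds weight / (k + rank)."""
    scores: Dict[str, float] = defaultdict(float)
    for source, doc_ids in rankings.items():
        # Ranks are 1-based, matching SearchResult.rank in the listings.
        for rank, doc_id in enumerate(doc_ids, start=1):
            scores[doc_id] += weights[source] / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


rankings = {
    "HR": ["a", "b", "c"],
    "Finance": ["b", "d"],
}
weights = {"HR": 0.5, "Finance": 0.5}
fused = weighted_rrf(rankings, weights)
```

Document `b` comes out on top because it appears in both lists, even though it is only second in HR. A larger `k` flattens the difference between adjacent ranks; a smaller `k` rewards top positions more strongly.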
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform multi-source retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments passed to each retriever

        Returns:
            List of SearchResult objects merged from all sources
        """
        # Retrieve from each source
        source_results: Dict[str, List[SearchResult]] = {}

        for source in self.sources:
            if not source.enabled:
                logger.info(f"Skipping disabled source: {source.name}")
                continue

            try:
                logger.info(f"Retrieving from source: {source.name}")
                results = source.retriever.retrieve(query, **kwargs)
                source_results[source.name] = results
                logger.info(f"Source {source.name} returned {len(results)} results")

            except Exception as e:
                logger.warning(f"Retrieval from source {source.name} failed: {e}")
                source_results[source.name] = []

        # Check if we got any results
        total_results = sum(len(results) for results in source_results.values())
        if total_results == 0:
            logger.warning("No results from any source")
            return []

        # Merge results based on fusion strategy
        logger.info(f"Merging results using {self.fusion_strategy} strategy")

        if self.fusion_strategy == "rrf":
            merged_results = self._rrf_fusion(source_results)
        elif self.fusion_strategy == "weighted_average":
            merged_results = self._weighted_average_fusion(source_results)
        elif self.fusion_strategy == "max_score":
            merged_results = self._max_score_fusion(source_results)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")

        # Limit to top_k
        final_results = merged_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Multi-index retrieval complete: {len(final_results)} results "
            f"from {len(source_results)} sources"
        )

        return final_results
Perform multi-source retrieval.
Args:
    query: The retrieval query
    **kwargs: Additional arguments passed to each retriever
Returns: List of SearchResult objects merged from all sources
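One property of `retrieve` worth noting is failure isolation: an exception in one source is logged and treated as an empty result list, so the remaining sources still contribute. The sketch below reproduces just that loop with plain callables standing in for retrievers; `fan_out`, `healthy`, and `broken` are hypothetical names used for illustration.

```python
import logging
from typing import Callable, Dict, List

logger = logging.getLogger(__name__)


def fan_out(
    query: str,
    retrievers: Dict[str, Callable[[str], List[str]]],
) -> Dict[str, List[str]]:
    """Query each source in turn; a failing source yields an empty list."""
    results: Dict[str, List[str]] = {}
    for name, retrieve_fn in retrievers.items():
        try:
            results[name] = retrieve_fn(query)
        except Exception as exc:
            # Log and continue so one bad source cannot fail the whole query.
            logger.warning("Retrieval from source %s failed: %s", name, exc)
            results[name] = []
    return results


def healthy(query: str) -> List[str]:
    return ["doc-1", "doc-2"]


def broken(query: str) -> List[str]:
    raise ConnectionError("index unavailable")


collected = fan_out("employee benefits policy", {"HR": healthy, "Finance": broken})
```

The trade-off is that a misconfigured source fails quietly: callers only see fewer results, so the warning log is the place to look when a source seems to be missing from the output.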
@dataclass
class SQLSchema:
    """Represents a database schema."""
    table_name: str
    columns: List[Dict[str, str]]  # [{"name": "col1", "type": "int", "description": "..."}]
    primary_key: Optional[str] = None
    description: Optional[str] = None
Represents a database schema.
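A schema object like this is typically what a text-to-SQL function reads to learn which table and columns exist. As a rough sketch, assuming the column dictionaries use the `name`, `type`, and optional `description` keys shown in the source comment, a schema could be rendered into a text description as follows. The `SQLSchema` class here is a local stand-in and `describe_schema` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SQLSchema:  # stand-in mirroring the documented fields
    table_name: str
    columns: List[Dict[str, str]]
    primary_key: Optional[str] = None
    description: Optional[str] = None


def describe_schema(schema: SQLSchema) -> str:
    """Render a schema as plain text, e.g. for a text-to-SQL prompt."""
    lines = [f"Table: {schema.table_name}"]
    if schema.description:
        lines.append(f"Description: {schema.description}")
    for col in schema.columns:
        line = f"- {col['name']} ({col['type']})"
        if col.get("description"):
            line += f": {col['description']}"
        if col["name"] == schema.primary_key:
            line += " [primary key]"
        lines.append(line)
    return "\n".join(lines)


schema = SQLSchema(
    table_name="employees",
    columns=[
        {"name": "id", "type": "int", "description": "Employee ID"},
        {"name": "dept", "type": "text"},
    ],
    primary_key="id",
    description="Staff directory",
)
summary = describe_schema(schema)
```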
@dataclass
class SQLQuery:
    """Represents a generated SQL query."""
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None
Represents a generated SQL query.
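Keeping `sql` and `parameters` separate lets user-supplied values be bound by the database driver instead of being spliced into the statement. The sketch below shows one way such a query could be executed with the standard-library `sqlite3` module and named placeholders. The `SQLQuery` class is a local stand-in, and how the library's own `_execute_sql` handles `parameters` is not shown in this section, so treat the execution step as an assumption.

```python
import sqlite3
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class SQLQuery:  # stand-in mirroring the documented fields
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None


# In-memory database with a few illustrative rows.
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)")
conn.executemany(
    "INSERT INTO employees (name, dept) VALUES (?, ?)",
    [("Ada", "IT"), ("Grace", "Finance"), ("Linus", "IT")],
)

query = SQLQuery(
    sql="SELECT name FROM employees WHERE dept = :dept ORDER BY name",
    parameters={"dept": "IT"},
    explanation="Employees in a given department, alphabetical",
)

# sqlite3 binds :dept from the dictionary; the value never touches the SQL text.
rows = [dict(row) for row in conn.execute(query.sql, query.parameters)]
conn.close()
```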
@dataclass
class SourceConfig:
    """Configuration for a retrieval source."""
    name: str
    retriever: BaseRetriever
    weight: float = 1.0
    boost_factor: float = 1.0
    enabled: bool = True
Configuration for a retrieval source.
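The `weight` and `boost_factor` fields both scale a source's contribution. In the weighted-average strategy shown in the class source, each score is min-max normalized per source and then multiplied by `weight * boost_factor`. The sketch below repeats that arithmetic on a bare list of scores; `min_max` is a hypothetical helper and the numbers are made up for illustration.

```python
from typing import List


def min_max(scores: List[float]) -> List[float]:
    """Rescale scores to [0, 1]; identical scores all map to 1.0."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)
    return [(score - lo) / (hi - lo) for score in scores]


raw_scores = [12.0, 9.0, 6.0]  # e.g. BM25 scores from one source
weight, boost_factor = 0.5, 2.0

normalized = min_max(raw_scores)
contributions = [score * weight * boost_factor for score in normalized]
```

Two consequences follow from min-max scaling: the lowest-scoring hit of every source contributes 0, and a source that returns a single hit always contributes its full `weight * boost_factor`, regardless of how relevant that hit was.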