gmf_forge_ai_data.retrieval

Retrieval strategies for the GMF Forge AI Data Layer.

This module provides various retrieval strategies for finding relevant documents from vector stores. All retrievers implement the BaseRetriever interface.

Available Retrievers:

Basic Retrievers (wrappers around vector store search):

  • VectorRetriever: Pure vector similarity search
  • KeywordRetriever: Pure keyword/BM25 search
  • HybridRetriever: Combined vector + keyword search

Advanced Retrievers (production patterns):

  • MMRRetriever: Maximal Marginal Relevance for diverse results
  • ParentDocumentRetriever: Search chunks, return parent documents
  • EnsembleRetriever: Combine multiple retrievers with score fusion
  • HierarchicalRetriever: Two-stage retrieval (coarse → fine)
  • GraphRetriever: Graph-based retrieval with entity relationships
  • SQLRetriever: Structured data retrieval via SQL
  • MultiIndexRetriever: Retrieve from multiple indices/sources

Core Classes:

  • BaseRetriever: Abstract interface for all retrievers
  • RetrievalQuery: Container for query parameters

from .base_retriever import BaseRetriever, RetrievalQuery
from .basic_retrievers import VectorRetriever, KeywordRetriever, HybridRetriever
from .mmr_retriever import MMRRetriever
from .parent_document_retriever import ParentDocumentRetriever
from .ensemble_retriever import EnsembleRetriever
from .hierarchical_retriever import HierarchicalRetriever
from .graph_retriever import GraphRetriever
from .sql_retriever import SQLRetriever, SQLSchema, SQLQuery
from .multi_index_retriever import MultiIndexRetriever, SourceConfig

__all__ = [
    # Base classes
    "BaseRetriever",
    "RetrievalQuery",

    # Basic retrievers
    "VectorRetriever",
    "KeywordRetriever",
    "HybridRetriever",

    # Advanced retrievers
    "MMRRetriever",
    "ParentDocumentRetriever",
    "EnsembleRetriever",
    "HierarchicalRetriever",
    "GraphRetriever",
    "SQLRetriever",
    "MultiIndexRetriever",

    # Helper classes
    "SQLSchema",
    "SQLQuery",
    "SourceConfig",
]

__version__ = "1.0.0"
class BaseRetriever(abc.ABC):
class BaseRetriever(ABC):
    """
    Abstract base class for all retriever implementations.

    Retrievers provide various strategies for finding relevant documents
    from a vector store or other data source.

    All implementations must provide:
    - retrieve(): Main retrieval method
    """

    @abstractmethod
    def retrieve(
        self,
        query: RetrievalQuery
    ) -> List[SearchResult]:
        """
        Retrieve relevant documents based on the query.

        Args:
            query: RetrievalQuery containing query parameters

        Returns:
            List of SearchResult objects, ordered by relevance

        Raises:
            ValueError: If required query parameters are missing
        """
        pass

    def retrieve_text(
        self,
        text: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for text-based retrieval.

        Args:
            text: Query text
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            text=text,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)

    def retrieve_embedding(
        self,
        embedding: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Convenience method for embedding-based retrieval.

        Args:
            embedding: Query embedding vector
            top_k: Number of results to retrieve
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects
        """
        query = RetrievalQuery(
            embedding=embedding,
            top_k=top_k,
            filters=filters
        )
        return self.retrieve(query)

Abstract base class for all retriever implementations.

Retrievers provide various strategies for finding relevant documents from a vector store or other data source.

All implementations must provide:

  • retrieve(): Main retrieval method
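
As an illustration of that contract, the following is a minimal sketch of a custom retriever. It assumes nothing about the real package beyond the shapes documented here: `RetrievalQuery`, `SearchResult` and `BaseRetriever` are redefined as local stand-ins so the snippet runs on its own, and `SubstringRetriever` is a hypothetical class that is not part of gmf_forge_ai_data.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


# Stand-ins mirroring the documented shapes. In real code these would be
# imported from gmf_forge_ai_data instead of being redefined.
@dataclass
class RetrievalQuery:
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class SearchResult:
    document: Any
    score: float
    rank: int


class BaseRetriever(ABC):
    @abstractmethod
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        ...

    def retrieve_text(self, text, top_k=5, filters=None):
        return self.retrieve(RetrievalQuery(text=text, top_k=top_k, filters=filters))


class SubstringRetriever(BaseRetriever):
    """Hypothetical retriever: ranks documents by how often the query text occurs."""

    def __init__(self, documents: List[str]):
        self.documents = documents

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        # Follow the documented contract: reject queries missing a required field.
        if query.text is None:
            raise ValueError("SubstringRetriever requires query.text")
        needle = query.text.lower()
        scored = [(doc.lower().count(needle), doc) for doc in self.documents]
        # Keep only documents that match, most occurrences first.
        hits = sorted((s for s in scored if s[0] > 0), key=lambda s: s[0], reverse=True)
        return [
            SearchResult(document=doc, score=float(count), rank=rank)
            for rank, (count, doc) in enumerate(hits[: query.top_k])
        ]


retriever = SubstringRetriever(["cats and more cats", "dogs only", "one cat"])
results = retriever.retrieve_text("cat", top_k=2)
```

Only `retrieve()` is overridden; `retrieve_text()` comes from the base class and simply builds the query object before delegating.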
@abstractmethod
def retrieve( self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:

Retrieve relevant documents based on the query.

Args: query: RetrievalQuery containing query parameters

Returns: List of SearchResult objects, ordered by relevance

Raises: ValueError: If required query parameters are missing

def retrieve_text( self, text: str, top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[gmf_forge_ai_data.SearchResult]:

Convenience method for text-based retrieval.

Args:

  • text: Query text
  • top_k: Number of results to retrieve
  • filters: Optional metadata filters

Returns: List of SearchResult objects

def retrieve_embedding( self, embedding: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[gmf_forge_ai_data.SearchResult]:

Convenience method for embedding-based retrieval.

Args:

  • embedding: Query embedding vector
  • top_k: Number of results to retrieve
  • filters: Optional metadata filters

Returns: List of SearchResult objects

@dataclass
class RetrievalQuery:
@dataclass
class RetrievalQuery:
    """
    Represents a retrieval query with optional query text and embedding.

    Attributes:
        text: Query text (required for keyword/hybrid search)
        embedding: Query embedding vector (required for vector search)
        top_k: Number of results to retrieve
        filters: Metadata filters to apply
        metadata: Additional query metadata
    """
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None

Represents a retrieval query with optional query text and embedding.

Attributes:

  • text: Query text (required for keyword/hybrid search)
  • embedding: Query embedding vector (required for vector search)
  • top_k: Number of results to retrieve
  • filters: Metadata filters to apply
  • metadata: Additional query metadata
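
Because both `text` and `embedding` default to None, which one is needed depends on the retriever that receives the query. The sketch below restates the attribute notes as a small lookup. The dataclass is a local stand-in for the documented one, and `REQUIRED_FIELDS` and `missing_fields` are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class RetrievalQuery:  # stand-in mirroring the documented fields
    text: Optional[str] = None
    embedding: Optional[List[float]] = None
    top_k: int = 5
    filters: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None


# Fields each search mode needs, taken from the attribute notes.
REQUIRED_FIELDS = {
    "vector": ("embedding",),
    "keyword": ("text",),
    "hybrid": ("text", "embedding"),
}


def missing_fields(query: RetrievalQuery, search_type: str) -> List[str]:
    """Hypothetical helper: names of required fields that are still None."""
    return [
        name for name in REQUIRED_FIELDS[search_type]
        if getattr(query, name) is None
    ]


text_only = RetrievalQuery(text="machine learning", top_k=3)
keyword_gaps = missing_fields(text_only, "keyword")
hybrid_gaps = missing_fields(text_only, "hybrid")
```

A text-only query is complete for keyword search but would still need an embedding before it could be passed to a vector or hybrid retriever.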

RetrievalQuery( text: Optional[str] = None, embedding: Optional[List[float]] = None, top_k: int = 5, filters: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None)
text: Optional[str] = None
embedding: Optional[List[float]] = None
top_k: int = 5
filters: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None
class VectorRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class VectorRetriever(BaseRetriever):
    """
    Vector similarity retriever using cosine similarity.

    Performs pure vector search using embeddings. Requires query embeddings to be
    provided externally (e.g., using an embeddings provider).

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        # Setup
        embedder = AzureOpenAIEmbeddings(...)
        retriever = VectorRetriever(vector_store)

        # Retrieve
        query_text = "What is machine learning?"
        query_embedding = embedder.embed_text(query_text)

        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize vector retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using vector similarity search.

        Args:
            query: RetrievalQuery with embedding and parameters

        Returns:
            List of SearchResult objects ordered by similarity

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("VectorRetriever requires query.embedding")

        return self.vector_store.search(
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="vector"
        )

Vector similarity retriever using cosine similarity.

Performs pure vector search using embeddings. Requires query embeddings to be provided externally (e.g., using an embeddings provider).
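
For readers unfamiliar with the scoring itself, here is a minimal pure-Python sketch of cosine-similarity ranking. It is not the vector store's implementation, which this page does not show; `cosine_similarity` and `top_k_by_cosine` are hypothetical helper names, and the two-dimensional vectors are toy data.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat a zero vector as similar to nothing
    return dot / (norm_a * norm_b)


def top_k_by_cosine(
    query: List[float], docs: Dict[str, List[float]], top_k: int
) -> List[Tuple[str, float]]:
    """Rank document ids by similarity to the query, best first."""
    scored = [(doc_id, cosine_similarity(query, vec)) for doc_id, vec in docs.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


docs = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
ranked = top_k_by_cosine([1.0, 0.0], docs, top_k=2)
```

Here document `a` points in the same direction as the query and ranks first, `c` is at 45 degrees and ranks second, and the orthogonal `b` is cut off by `top_k`.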

Example:

from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

# Setup
embedder = AzureOpenAIEmbeddings(...)
retriever = VectorRetriever(vector_store)

# Retrieve
query_text = "What is machine learning?"
query_embedding = embedder.embed_text(query_text)

results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
VectorRetriever( vector_store: gmf_forge_ai_data.BaseVectorStore)

Initialize vector retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve( self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:

Retrieve documents using vector similarity search.

Args: query: RetrievalQuery with embedding and parameters

Returns: List of SearchResult objects ordered by similarity

Raises: ValueError: If query.embedding is None

class KeywordRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class KeywordRetriever(BaseRetriever):
    """
    Keyword/BM25 retriever using text matching.

    Performs traditional keyword-based search without using embeddings.
    Uses BM25 for Azure AI Search, Jaccard similarity for in-memory.

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery

        retriever = KeywordRetriever(vector_store)

        results = retriever.retrieve_text(
            text="machine learning algorithms",
            top_k=5
        )
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize keyword retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using keyword search.

        Args:
            query: RetrievalQuery with text and parameters

        Returns:
            List of SearchResult objects ordered by keyword relevance

        Raises:
            ValueError: If query.text is None
        """
        if query.text is None:
            raise ValueError("KeywordRetriever requires query.text")

        return self.vector_store.search(
            query=query.text,
            top_k=query.top_k,
            filters=query.filters,
            search_type="keyword"
        )

Keyword/BM25 retriever using text matching.

Performs traditional keyword-based search without using embeddings. Uses BM25 for Azure AI Search, Jaccard similarity for in-memory.
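
To make the in-memory case concrete, the sketch below computes Jaccard similarity over word sets. How the in-memory store actually tokenizes text is not documented here, so the lowercase whitespace split is an assumption, and `tokenize` and `jaccard` are hypothetical helper names.

```python
from typing import Set


def tokenize(text: str) -> Set[str]:
    """Assumed tokenizer: lowercase, split on whitespace, de-duplicate."""
    return set(text.lower().split())


def jaccard(query: str, document: str) -> float:
    """Size of the shared vocabulary divided by the size of the combined one."""
    q, d = tokenize(query), tokenize(document)
    if not q or not d:
        return 0.0
    return len(q & d) / len(q | d)


overlap = jaccard("machine learning algorithms", "Machine learning basics")
no_overlap = jaccard("machine learning", "cooking recipes")
```

The first pair shares two of four distinct words, giving 0.5; the second pair shares none. Unlike BM25, this score ignores term frequency and document length, so rankings can differ between the two backends.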

Example:

from gmf_forge_ai_data.retrieval import KeywordRetriever, RetrievalQuery

retriever = KeywordRetriever(vector_store)

results = retriever.retrieve_text(
    text="machine learning algorithms",
    top_k=5
)
KeywordRetriever( vector_store: gmf_forge_ai_data.BaseVectorStore)

Initialize keyword retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve( self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:

Retrieve documents using keyword search.

Args: query: RetrievalQuery with text and parameters

Returns: List of SearchResult objects ordered by keyword relevance

Raises: ValueError: If query.text is None

class HybridRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class HybridRetriever(BaseRetriever):
    """
    Hybrid retriever combining vector and keyword search.

    Combines vector similarity with keyword matching for comprehensive retrieval.
    Scores are normalized and combined (typically 50/50 weighting).

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)
        retriever = HybridRetriever(vector_store)

        query_text = "machine learning"
        query_embedding = embedder.embed_text(query_text)

        # Hybrid search using both text and embedding
        query = RetrievalQuery(
            text=query_text,
            embedding=query_embedding,
            top_k=5
        )
        results = retriever.retrieve(query)
        ```
    """

    def __init__(self, vector_store: BaseVectorStore):
        """
        Initialize hybrid retriever.

        Args:
            vector_store: Vector store to search
        """
        self.vector_store = vector_store

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using hybrid search (vector + keyword).

        Args:
            query: RetrievalQuery with both text and embedding

        Returns:
            List of SearchResult objects ordered by combined relevance

        Raises:
            ValueError: If query.text or query.embedding is None
        """
        if query.text is None:
            raise ValueError("HybridRetriever requires query.text")
        if query.embedding is None:
            raise ValueError("HybridRetriever requires query.embedding")

        return self.vector_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=query.top_k,
            filters=query.filters,
            search_type="hybrid"
        )

Hybrid retriever combining vector and keyword search.

Combines vector similarity with keyword matching for comprehensive retrieval. Scores are normalized and combined (typically 50/50 weighting).
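
The combination step happens inside the vector store's search, and this page does not say which normalization it uses. As one plausible reading, the sketch below applies min-max normalization to each score set and then takes a weighted sum. `min_max`, `fuse` and `vector_weight` are hypothetical names, and the scores are made-up toy values.

```python
from typing import Dict, List, Tuple


def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to [0, 1]; a constant score set maps to all 1.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {doc_id: 1.0 for doc_id in scores}
    return {doc_id: (s - lo) / (hi - lo) for doc_id, s in scores.items()}


def fuse(
    vector_scores: Dict[str, float],
    keyword_scores: Dict[str, float],
    vector_weight: float = 0.5,
) -> List[Tuple[str, float]]:
    """Weighted sum of normalized scores; a missing score counts as 0."""
    v, k = min_max(vector_scores), min_max(keyword_scores)
    fused = {
        doc_id: vector_weight * v.get(doc_id, 0.0)
        + (1.0 - vector_weight) * k.get(doc_id, 0.0)
        for doc_id in set(v) | set(k)
    }
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


vector_scores = {"a": 1.0, "b": 0.5, "c": 0.0}     # similarity-style scores
keyword_scores = {"b": 12.0, "c": 4.0, "d": 8.0}   # BM25-style scores
ranking = fuse(vector_scores, keyword_scores)
```

Normalizing first matters because keyword scores such as BM25 are unbounded while similarity scores are not; without it the keyword side would dominate the sum. In this toy data, `b` wins because it scores reasonably on both sides.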

Example:

from gmf_forge_ai_data.retrieval import HybridRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)
retriever = HybridRetriever(vector_store)

query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)

# Hybrid search using both text and embedding
query = RetrievalQuery(
    text=query_text,
    embedding=query_embedding,
    top_k=5
)
results = retriever.retrieve(query)
HybridRetriever( vector_store: gmf_forge_ai_data.BaseVectorStore)

Initialize hybrid retriever.

Args: vector_store: Vector store to search

vector_store
def retrieve( self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:

Retrieve documents using hybrid search (vector + keyword).

Args: query: RetrievalQuery with both text and embedding

Returns: List of SearchResult objects ordered by combined relevance

Raises: ValueError: If query.text or query.embedding is None

class MMRRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class MMRRetriever(BaseRetriever):
    """
    Maximal Marginal Relevance (MMR) retriever for diverse results.

    MMR reranks initial retrieval results to balance relevance with diversity:
    - High lambda (λ → 1.0): Prioritize relevance
    - Low lambda (λ → 0.0): Prioritize diversity
    - λ = 0.5: Balanced (default)

    Algorithm:
    1. Retrieve initial candidate documents (fetch_k results)
    2. Select most relevant document first
    3. For each remaining selection:
       - Score = λ * relevance - (1-λ) * max_similarity_to_selected
       - Choose document with highest score
    4. Return top_k diverse results

    Benefits:
    - Reduces redundancy in results
    - Improves coverage of different aspects
    - Useful for exploration and summarization

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)
        retriever = MMRRetriever(
            vector_store=vector_store,
            lambda_param=0.5,  # Balanced relevance/diversity
            fetch_k=20         # Retrieve 20 candidates, return top_k diverse
        )

        query_text = "machine learning algorithms"
        query_embedding = embedder.embed_text(query_text)

        # Returns 5 diverse results from 20 candidates
        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        ```

    References:
        Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based
        reranking for reordering documents and producing summaries.
    """

    def __init__(
        self,
        vector_store: BaseVectorStore,
        lambda_param: float = 0.5,
        fetch_k: int = 20
    ):
        """
        Initialize MMR retriever.

        Args:
            vector_store: Vector store to search
            lambda_param: Balance between relevance (1.0) and diversity (0.0)
            fetch_k: Number of initial candidates to fetch (should be >= top_k)

        Raises:
            ValueError: If lambda_param not in [0, 1] or fetch_k < 1
        """
        if not 0 <= lambda_param <= 1:
            raise ValueError(f"lambda_param must be in [0, 1], got {lambda_param}")
        if fetch_k < 1:
            raise ValueError(f"fetch_k must be >= 1, got {fetch_k}")

        self.vector_store = vector_store
        self.lambda_param = lambda_param
        self.fetch_k = fetch_k

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve diverse documents using MMR reranking.

        Args:
            query: RetrievalQuery with embedding (required)

        Returns:
            List of SearchResult objects with diverse, relevant results

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")

        # Fetch initial candidates (more than top_k)
        fetch_k = max(self.fetch_k, query.top_k)
        candidates = self.vector_store.search(
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type="vector"
        )

        if len(candidates) == 0:
            return []

        if len(candidates) <= query.top_k:
            # Not enough candidates for MMR, return as-is
            return candidates[:query.top_k]

        # Extract embeddings and relevance scores
        query_embedding = np.array(query.embedding, dtype=np.float32)
        candidate_embeddings = []
        candidate_scores = []

        for result in candidates:
            if result.document.embedding is None:
                raise ValueError(
                    f"Document {result.document.id} has no embedding. "
                    "MMR requires all documents to have embeddings."
                )
            candidate_embeddings.append(result.document.embedding)
            candidate_scores.append(result.score)

        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
        candidate_scores = np.array(candidate_scores, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
        candidate_norms = candidate_embeddings / (
            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
        )

        # MMR selection
        selected_indices = []
        selected_embeddings = []

        for _ in range(min(query.top_k, len(candidates))):
            if len(selected_indices) == 0:
                # First selection: most relevant
                best_idx = int(np.argmax(candidate_scores))
            else:
                # Subsequent selections: balance relevance and diversity
                mmr_scores = []

                for idx in range(len(candidates)):
                    if idx in selected_indices:
                        mmr_scores.append(-np.inf)
                        continue

                    # Relevance score (already normalized 0-1 from vector store)
                    relevance = candidate_scores[idx]

                    # Diversity score: max similarity to selected documents
                    doc_embedding = candidate_norms[idx]
                    similarities = [
                        np.dot(doc_embedding, selected_embeddings[i])
                        for i in range(len(selected_embeddings))
                    ]
                    max_similarity = max(similarities)

                    # MMR score: λ * relevance - (1-λ) * max_similarity
                    mmr_score = (
                        self.lambda_param * relevance -
                        (1 - self.lambda_param) * max_similarity
                    )
                    mmr_scores.append(mmr_score)

                best_idx = int(np.argmax(mmr_scores))

            selected_indices.append(best_idx)
            selected_embeddings.append(candidate_norms[best_idx])

        # Build results with updated ranks
        mmr_results = []
        for rank, idx in enumerate(selected_indices):
            result = candidates[idx]
            mmr_results.append(SearchResult(
                document=result.document,
                score=result.score,  # Keep original relevance score
                rank=rank
            ))

        return mmr_results

Maximal Marginal Relevance (MMR) retriever for diverse results.

MMR reranks initial retrieval results to balance relevance with diversity:

  • High lambda (λ → 1.0): Prioritize relevance
  • Low lambda (λ → 0.0): Prioritize diversity
  • λ = 0.5: Balanced (default)

Algorithm:

  1. Retrieve initial candidate documents (fetch_k results)
  2. Select most relevant document first
  3. For each remaining selection:
    • Score = λ * relevance - (1-λ) * max_similarity_to_selected
    • Choose document with highest score
  4. Return top_k diverse results
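
The selection loop in steps 2 to 4 can be sketched without NumPy as follows. This is an illustrative reimplementation of the listed steps, not the library's code: `cosine` and `mmr_select` are hypothetical names, the relevance scores are assumed to be already scaled to 0-1, and the three candidates are toy data.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr_select(
    relevance: Sequence[float],
    embeddings: Sequence[Sequence[float]],
    top_k: int,
    lambda_param: float = 0.5,
) -> List[int]:
    """Return candidate indices in the order MMR would pick them."""
    selected: List[int] = []
    remaining = list(range(len(relevance)))
    while remaining and len(selected) < top_k:
        if not selected:
            # Step 2: the first pick is simply the most relevant candidate.
            best = max(remaining, key=lambda i: relevance[i])
        else:
            # Step 3: penalize similarity to anything already selected.
            def mmr_score(i: int) -> float:
                max_sim = max(cosine(embeddings[i], embeddings[j]) for j in selected)
                return lambda_param * relevance[i] - (1 - lambda_param) * max_sim

            best = max(remaining, key=mmr_score)
        selected.append(best)
        remaining.remove(best)
    return selected


# Candidate 1 is a near-duplicate of candidate 0; candidate 2 is unrelated.
embeddings = [[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]]
relevance = [0.9, 0.85, 0.3]

balanced = mmr_select(relevance, embeddings, top_k=2, lambda_param=0.5)
relevance_only = mmr_select(relevance, embeddings, top_k=2, lambda_param=1.0)
```

With λ = 0.5 the near-duplicate is skipped in favor of the less relevant but dissimilar candidate 2; with λ = 1.0 the diversity penalty vanishes and the result matches plain relevance ordering.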

Benefits:

  • Reduces redundancy in results
  • Improves coverage of different aspects
  • Useful for exploration and summarization

Example:

from gmf_forge_ai_data.retrieval import MMRRetriever, RetrievalQuery
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)
retriever = MMRRetriever(
    vector_store=vector_store,
    lambda_param=0.5,  # Balanced relevance/diversity
    fetch_k=20         # Retrieve 20 candidates, return top_k diverse
)

query_text = "machine learning algorithms"
query_embedding = embedder.embed_text(query_text)

# Returns 5 diverse results from 20 candidates
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)

References: Carbonell, J., & Goldstein, J. (1998). The use of MMR, diversity-based reranking for reordering documents and producing summaries.

MMRRetriever( vector_store: gmf_forge_ai_data.BaseVectorStore, lambda_param: float = 0.5, fetch_k: int = 20)

Initialize MMR retriever.

Args:

  • vector_store: Vector store to search
  • lambda_param: Balance between relevance (1.0) and diversity (0.0)
  • fetch_k: Number of initial candidates to fetch (should be >= top_k)

Raises: ValueError: If lambda_param not in [0, 1] or fetch_k < 1

vector_store
lambda_param
fetch_k
def retrieve(self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve diverse documents using MMR reranking.

        Args:
            query: RetrievalQuery with embedding (required)

        Returns:
            List of SearchResult objects with diverse, relevant results

        Raises:
            ValueError: If query.embedding is None
        """
        if query.embedding is None:
            raise ValueError("MMRRetriever requires query.embedding for diversity calculation")

        # Fetch initial candidates (more than top_k)
        fetch_k = max(self.fetch_k, query.top_k)
        candidates = self.vector_store.search(
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type="vector"
        )

        if len(candidates) == 0:
            return []

        if len(candidates) <= query.top_k:
            # Not enough candidates for MMR, return as-is
            return candidates[:query.top_k]

        # Extract embeddings and relevance scores
        query_embedding = np.array(query.embedding, dtype=np.float32)
        candidate_embeddings = []
        candidate_scores = []

        for result in candidates:
            if result.document.embedding is None:
                raise ValueError(
                    f"Document {result.document.id} has no embedding. "
                    "MMR requires all documents to have embeddings."
                )
            candidate_embeddings.append(result.document.embedding)
            candidate_scores.append(result.score)

        candidate_embeddings = np.array(candidate_embeddings, dtype=np.float32)
        candidate_scores = np.array(candidate_scores, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
        candidate_norms = candidate_embeddings / (
            np.linalg.norm(candidate_embeddings, axis=1, keepdims=True) + 1e-10
        )

        # MMR selection
        selected_indices = []
        selected_embeddings = []

        for _ in range(min(query.top_k, len(candidates))):
            if len(selected_indices) == 0:
                # First selection: most relevant
                best_idx = int(np.argmax(candidate_scores))
            else:
                # Subsequent selections: balance relevance and diversity
                mmr_scores = []

                for idx in range(len(candidates)):
                    if idx in selected_indices:
                        mmr_scores.append(-np.inf)
                        continue

                    # Relevance score (already normalized 0-1 from vector store)
                    relevance = candidate_scores[idx]

                    # Diversity score: max similarity to selected documents
                    doc_embedding = candidate_norms[idx]
                    similarities = [
                        np.dot(doc_embedding, selected_embeddings[i])
                        for i in range(len(selected_embeddings))
                    ]
                    max_similarity = max(similarities)

                    # MMR score: λ * relevance - (1-λ) * max_similarity
                    mmr_score = (
                        self.lambda_param * relevance -
                        (1 - self.lambda_param) * max_similarity
                    )
                    mmr_scores.append(mmr_score)

                best_idx = int(np.argmax(mmr_scores))

            selected_indices.append(best_idx)
            selected_embeddings.append(candidate_norms[best_idx])

        # Build results with updated ranks
        mmr_results = []
        for rank, idx in enumerate(selected_indices):
            result = candidates[idx]
            mmr_results.append(SearchResult(
                document=result.document,
                score=result.score,  # Keep original relevance score
                rank=rank
            ))

        return mmr_results

Retrieve diverse documents using MMR reranking.

Args: query: RetrievalQuery with embedding (required)

Returns: List of SearchResult objects with diverse, relevant results

Raises: ValueError: If query.embedding is None

class ParentDocumentRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class ParentDocumentRetriever(BaseRetriever):
    """
    Retriever that searches child chunks but returns parent documents.

    This pattern is useful when:
    - Documents are chunked into small pieces for precise embedding
    - But you want to return full parent documents for better context
    - Multiple chunks from the same parent might match the query

    Architecture:
    - Child store: Contains small chunks with embeddings (searchable)
    - Parent store: Contains full parent documents (for retrieval)
    - Mapping: Children have metadata["parent_id"] pointing to parent

    Workflow:
    1. Search child store for relevant chunks
    2. Extract parent_ids from child metadata
    3. Retrieve parent documents from parent store
    4. Deduplicate and rank by best child score

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
        from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)

        # Setup stores
        child_store = InMemoryVectorStore()
        parent_store = InMemoryVectorStore()

        # Create parent document
        parent = Document(
            id="doc_1",
            content="Full document with multiple sections...",
            embedding=embedder.embed_text("Full document...")
        )
        parent_store.add_documents([parent])

        # Create child chunks
        child1 = Document(
            id="doc_1_chunk_0",
            content="Introduction section...",
            embedding=embedder.embed_text("Introduction..."),
            metadata={"parent_id": "doc_1"}  # Link to parent
        )
        child2 = Document(
            id="doc_1_chunk_1",
            content="Methods section...",
            embedding=embedder.embed_text("Methods..."),
            metadata={"parent_id": "doc_1"}
        )
        child_store.add_documents([child1, child2])

        # Setup retriever
        retriever = ParentDocumentRetriever(
            child_store=child_store,
            parent_store=parent_store,
            parent_id_key="parent_id"
        )

        # Search chunks, return parent
        query_embedding = embedder.embed_text("introduction")
        results = retriever.retrieve_embedding(
            embedding=query_embedding,
            top_k=5
        )
        # Returns full parent documents, not chunks
        ```

    Benefits:
    - Precise search (small chunks)
    - Rich context (full parents)
    - Automatic deduplication
    """

    def __init__(
        self,
        child_store: BaseVectorStore,
        parent_store: BaseVectorStore,
        parent_id_key: str = "parent_id",
        search_type: str = "vector"
    ):
        """
        Initialize parent document retriever.

        Args:
            child_store: Vector store containing child chunks with embeddings
            parent_store: Vector store containing parent documents
            parent_id_key: Metadata key in children pointing to parent ID
            search_type: Search type for child store ("vector", "keyword", "hybrid")

        Raises:
            ValueError: If search_type is invalid
        """
        if search_type not in ["vector", "keyword", "hybrid"]:
            raise ValueError(f"Invalid search_type: {search_type}")

        self.child_store = child_store
        self.parent_store = parent_store
        self.parent_id_key = parent_id_key
        self.search_type = search_type

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve parent documents by searching child chunks.

        Args:
            query: RetrievalQuery with appropriate parameters for search_type

        Returns:
            List of SearchResult with parent documents, deduplicated and ranked

        Raises:
            ValueError: If required query parameters are missing
        """
        # Validate query based on search type
        if self.search_type in ["vector", "hybrid"] and query.embedding is None:
            raise ValueError(f"{self.search_type} search requires query.embedding")
        if self.search_type in ["keyword", "hybrid"] and query.text is None:
            raise ValueError(f"{self.search_type} search requires query.text")

        # Search child store (fetch more candidates to account for deduplication)
        fetch_k = query.top_k * 3  # Heuristic: fetch 3x to handle duplicates
        child_results = self.child_store.search(
            query=query.text,
            query_embedding=query.embedding,
            top_k=fetch_k,
            filters=query.filters,
            search_type=self.search_type
        )

        if len(child_results) == 0:
            return []

        # Extract parent IDs and track best scores
        parent_scores: Dict[str, float] = {}  # parent_id -> best_score
        parent_order: OrderedDict[str, None] = OrderedDict()  # Track first occurrence

        for child_result in child_results:
            parent_id = child_result.document.metadata.get(self.parent_id_key)

            if parent_id is None:
                # Skip children without parent link
                continue

            # Track best score for each parent (in case multiple children match)
            if parent_id not in parent_scores:
                parent_scores[parent_id] = child_result.score
                parent_order[parent_id] = None  # Track insertion order
            else:
                # Update if this child has better score
                parent_scores[parent_id] = max(
                    parent_scores[parent_id],
                    child_result.score
                )

        if len(parent_scores) == 0:
            return []

        # Retrieve parent documents
        parent_results = []
        for rank, parent_id in enumerate(parent_order.keys()):
            if rank >= query.top_k:
                break

            parent_doc = self.parent_store.get_document(parent_id)
            if parent_doc is None:
                # Parent not found in store, skip
                continue

            parent_results.append(SearchResult(
                document=parent_doc,
                score=parent_scores[parent_id],  # Best child score
                rank=rank
            ))

        return parent_results

Retriever that searches child chunks but returns parent documents.

This pattern is useful when:

  • Documents are chunked into small pieces for precise embedding
  • But you want to return full parent documents for better context
  • Multiple chunks from the same parent might match the query

Architecture:

  • Child store: Contains small chunks with embeddings (searchable)
  • Parent store: Contains full parent documents (for retrieval)
  • Mapping: Children have metadata["parent_id"] pointing to parent

Workflow:

  1. Search child store for relevant chunks
  2. Extract parent_ids from child metadata
  3. Retrieve parent documents from parent store
  4. Deduplicate and rank by best child score
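The deduplication step of this workflow can be illustrated on plain tuples. This is a minimal sketch with no store involved: `rank_parents` is a hypothetical helper, and child hits are assumed to be `(parent_id, score)` pairs where a higher score is better.

```python
from typing import Dict, List, Optional, Tuple


def rank_parents(
    child_hits: List[Tuple[Optional[str], float]],
    top_k: int,
) -> List[Tuple[str, float]]:
    """Collapse (parent_id, score) child hits into at most top_k ranked parents."""
    best: Dict[str, float] = {}
    for parent_id, score in child_hits:
        if parent_id is None:
            continue  # chunk without a parent link is skipped
        if parent_id not in best or score > best[parent_id]:
            best[parent_id] = score  # keep the best child score per parent

    # Rank parents by their best child score, highest first.
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


hits = [
    ("doc_1", 0.91),
    ("doc_2", 0.88),
    ("doc_1", 0.80),  # second chunk of doc_1, does not add a new parent
    (None, 0.75),     # orphan chunk
    ("doc_3", 0.60),
]
print(rank_parents(hits, top_k=2))  # [('doc_1', 0.91), ('doc_2', 0.88)]
```

Because several chunks can collapse into one parent, fewer parents than child hits come back, which is why fetching more child candidates than `top_k` is a sensible default.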

Example:

from gmf_forge_ai_data.retrieval import ParentDocumentRetriever
from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Setup stores
child_store = InMemoryVectorStore()
parent_store = InMemoryVectorStore()

# Create parent document
parent = Document(
    id="doc_1",
    content="Full document with multiple sections...",
    embedding=embedder.embed_text("Full document...")
)
parent_store.add_documents([parent])

# Create child chunks
child1 = Document(
    id="doc_1_chunk_0",
    content="Introduction section...",
    embedding=embedder.embed_text("Introduction..."),
    metadata={"parent_id": "doc_1"}  # Link to parent
)
child2 = Document(
    id="doc_1_chunk_1",
    content="Methods section...",
    embedding=embedder.embed_text("Methods..."),
    metadata={"parent_id": "doc_1"}
)
child_store.add_documents([child1, child2])

# Setup retriever
retriever = ParentDocumentRetriever(
    child_store=child_store,
    parent_store=parent_store,
    parent_id_key="parent_id"
)

# Search chunks, return parent
query_embedding = embedder.embed_text("introduction")
results = retriever.retrieve_embedding(
    embedding=query_embedding,
    top_k=5
)
# Returns full parent documents, not chunks

Benefits:

  • Precise search (small chunks)
  • Rich context (full parents)
  • Automatic deduplication
ParentDocumentRetriever(child_store: gmf_forge_ai_data.BaseVectorStore, parent_store: gmf_forge_ai_data.BaseVectorStore, parent_id_key: str = 'parent_id', search_type: str = 'vector')

Initialize parent document retriever.

Args:

  • child_store: Vector store containing child chunks with embeddings
  • parent_store: Vector store containing parent documents
  • parent_id_key: Metadata key in children pointing to parent ID
  • search_type: Search type for child store ("vector", "keyword", "hybrid")

Raises: ValueError: If search_type is invalid

child_store
parent_store
parent_id_key
search_type
def retrieve(self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:

Retrieve parent documents by searching child chunks.

Args: query: RetrievalQuery with appropriate parameters for search_type

Returns: List of SearchResult with parent documents, deduplicated and ranked

Raises: ValueError: If required query parameters are missing
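Which query fields count as "required" depends on `search_type`. The rule can be sketched as a standalone check; `missing_query_fields` is a hypothetical helper written for illustration and is not part of the package.

```python
from typing import List, Optional, Sequence


def missing_query_fields(
    search_type: str,
    text: Optional[str],
    embedding: Optional[Sequence[float]],
) -> List[str]:
    """Return the names of query fields that search_type needs but are absent."""
    missing = []
    # Vector and hybrid search both need an embedding.
    if search_type in ("vector", "hybrid") and embedding is None:
        missing.append("embedding")
    # Keyword and hybrid search both need the raw query text.
    if search_type in ("keyword", "hybrid") and text is None:
        missing.append("text")
    return missing


print(missing_query_fields("vector", None, [0.1, 0.2]))   # []
print(missing_query_fields("keyword", None, [0.1, 0.2]))  # ['text']
print(missing_query_fields("hybrid", None, None))         # ['embedding', 'text']
```

Supplying both `text` and `embedding` on every query satisfies all three search types.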

class EnsembleRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class EnsembleRetriever(BaseRetriever):
    """
    Ensemble retriever combining multiple retrieval strategies.

    Combines results from multiple retrievers using score fusion techniques:
    - Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
    - Weighted Average: Score-based fusion with configurable weights
    - Max Score: Keeps each document's best normalized score from any single retriever

    Benefits:
    - Improved recall (combine different strategies)
    - Robustness (less sensitive to individual retriever failures)
    - Flexibility (combine vector, keyword, hybrid, MMR, etc.)

    Example:
        ```python
        from gmf_forge_ai_data.retrieval import (
            EnsembleRetriever,
            VectorRetriever,
            KeywordRetriever,
            MMRRetriever,
            RetrievalQuery
        )
        from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

        embedder = AzureOpenAIEmbeddings(...)

        # Create multiple retrievers
        vector_retriever = VectorRetriever(vector_store)
        keyword_retriever = KeywordRetriever(vector_store)
        mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

        # Combine with RRF fusion
        ensemble = EnsembleRetriever(
            retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
            weights=[0.5, 0.3, 0.2],  # Optional weights
            fusion_strategy="rrf"  # Reciprocal Rank Fusion
        )

        # Query requires parameters for all retrievers
        query_text = "machine learning"
        query_embedding = embedder.embed_text(query_text)

        query = RetrievalQuery(
            text=query_text,  # For keyword retriever
            embedding=query_embedding,  # For vector/MMR retrievers
            top_k=5
        )

        results = ensemble.retrieve(query)
        # Returns fused results from all retrievers
        ```

    Fusion Strategies:

    1. **rrf** (Reciprocal Rank Fusion):
       - score(doc) = Σ weights[i] / (k + rank_i)
       - k = 60 (default)
       - Robust to different score scales
       - Rank-based, not score-based

    2. **weighted_avg** (Weighted Average):
       - Normalize scores to [0, 1]
       - score(doc) = Σ weights[i] * normalized_score_i
       - Score-based fusion

    3. **max_score** (Maximum Score):
       - score(doc) = max(normalized_scores)
       - A high score from any single retriever is enough; no consensus is required

    References:
        Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009).
        Reciprocal rank fusion outperforms Condorcet and individual
        rank learning methods. ACM SIGIR.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
        weights: List[float] = None,
        fusion_strategy: str = "rrf",
        rrf_k: int = 60
    ):
        """
        Initialize ensemble retriever.

        Args:
            retrievers: List of retrievers to combine
            weights: Optional weights for each retriever (default: equal weights)
            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
            rrf_k: k parameter for RRF (default: 60)

        Raises:
            ValueError: If retrievers list is empty or params are invalid
        """
        if not retrievers:
            raise ValueError("Must provide at least one retriever")

        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
            raise ValueError(
                f"Invalid fusion_strategy: {fusion_strategy}. "
                "Must be 'rrf', 'weighted_avg', or 'max_score'"
            )

        if weights is None:
            # Equal weights
            weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            if len(weights) != len(retrievers):
                raise ValueError(
                    f"Number of weights ({len(weights)}) must match "
                    f"number of retrievers ({len(retrievers)})"
                )
            # Normalize weights to sum to 1.0
            total = sum(weights)
            weights = [w / total for w in weights]

        self.retrievers = retrievers
        self.weights = weights
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k

    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using ensemble fusion.

        Args:
            query: RetrievalQuery with parameters for all retrievers

        Returns:
            List of SearchResult with fused scores and ranks

        Note:
            All retrievers are called with the same query. Ensure the query
            contains all required parameters (text, embedding) for your retrievers.
        """
        # Retrieve from all retrievers
        all_results: List[List[SearchResult]] = []

        for retriever in self.retrievers:
            try:
                results = retriever.retrieve(query)
                all_results.append(results)
            except ValueError as e:
                # Retriever might not support this query type, skip it
                all_results.append([])

        if not any(all_results):
            return []

        # Apply fusion strategy
        if self.fusion_strategy == "rrf":
            fused_results = self._reciprocal_rank_fusion(all_results)
        elif self.fusion_strategy == "weighted_avg":
            fused_results = self._weighted_average_fusion(all_results)
        else:  # max_score
            fused_results = self._max_score_fusion(all_results)

        # Sort by fused score and take top_k
        fused_results.sort(key=lambda x: x.score, reverse=True)
        top_results = fused_results[:query.top_k]

        # Update ranks
        for rank, result in enumerate(top_results):
            result.rank = rank

        return top_results

    def _reciprocal_rank_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using Reciprocal Rank Fusion (RRF).

        RRF score for document d:
            score(d) = Σ weight_i / (k + rank_i(d))

        Where rank_i(d) is the rank of d in retriever i's results.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            weight = self.weights[retriever_idx]

            for result in results:
                doc_id = result.document.id

                # RRF score contribution
                rrf_score = weight / (self.rrf_k + result.rank)
                doc_scores[doc_id] += rrf_score

                # Keep document object (use first occurrence)
                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0  # Will be set later
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _weighted_average_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using weighted average of normalized scores.

        Scores are normalized to [0, 1] within each retriever, then averaged.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            weight = self.weights[retriever_idx]

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id
                doc_scores[doc_id] += weight * norm_score

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

    def _max_score_fusion(
        self,
        all_results: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Fuse results using maximum normalized score across retrievers.

        Conservative strategy: document needs high score from at least one retriever.
        """
        doc_scores: Dict[str, float] = defaultdict(float)
        doc_objects: Dict[str, Document] = {}

        for retriever_idx, results in enumerate(all_results):
            if not results:
                continue

            # Normalize scores to [0, 1]
            scores = np.array([r.score for r in results])
            min_score = scores.min()
            max_score = scores.max()

            if max_score > min_score:
                normalized_scores = (scores - min_score) / (max_score - min_score)
            else:
                normalized_scores = np.ones_like(scores)

            for result, norm_score in zip(results, normalized_scores):
                doc_id = result.document.id

                # Take maximum normalized score
                doc_scores[doc_id] = max(doc_scores[doc_id], norm_score)

                if doc_id not in doc_objects:
                    doc_objects[doc_id] = result.document

        # Build SearchResult objects
        fused_results = [
            SearchResult(
                document=doc_objects[doc_id],
                score=score,
                rank=0
            )
            for doc_id, score in doc_scores.items()
        ]

        return fused_results

Ensemble retriever combining multiple retrieval strategies.

Combines results from multiple retrievers using score fusion techniques:

  • Reciprocal Rank Fusion (RRF): Robust, rank-based fusion
  • Weighted Average: Score-based fusion with configurable weights
  • Max Score: Keeps each document's best normalized score from any single retriever

Benefits:

  • Improved recall (combine different strategies)
  • Robustness (less sensitive to individual retriever failures)
  • Flexibility (combine vector, keyword, hybrid, MMR, etc.)

Example:

from gmf_forge_ai_data.retrieval import (
    EnsembleRetriever,
    VectorRetriever,
    KeywordRetriever,
    MMRRetriever,
    RetrievalQuery
)
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(...)

# Create multiple retrievers (vector_store is an existing vector store instance)
vector_retriever = VectorRetriever(vector_store)
keyword_retriever = KeywordRetriever(vector_store)
mmr_retriever = MMRRetriever(vector_store, lambda_param=0.7)

# Combine with RRF fusion
ensemble = EnsembleRetriever(
    retrievers=[vector_retriever, keyword_retriever, mmr_retriever],
    weights=[0.5, 0.3, 0.2],  # Optional weights
    fusion_strategy="rrf"  # Reciprocal Rank Fusion
)

# Query requires parameters for all retrievers
query_text = "machine learning"
query_embedding = embedder.embed_text(query_text)

query = RetrievalQuery(
    text=query_text,  # For keyword retriever
    embedding=query_embedding,  # For vector/MMR retrievers
    top_k=5
)

results = ensemble.retrieve(query)
# Returns fused results from all retrievers

Fusion Strategies:

  1. rrf (Reciprocal Rank Fusion):

    • score(doc) = Σ weights[i] / (k + rank_i)
    • k = 60 (default)
    • Robust to different score scales
    • Rank-based, not score-based
  2. weighted_avg (Weighted Average):

    • Normalize scores to [0, 1]
    • score(doc) = Σ weights[i] * normalized_score_i
    • Score-based fusion
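  3. max_score (Maximum Score):

    • Normalize scores to [0, 1]
    • score(doc) = max(normalized_scores)
    • Keeps each document's best score from any single retriever, so no agreement between retrievers is needed
    • Retriever weights do not enter the score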
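
The rrf formula above can be sketched in a few lines. This is an illustration only, assuming each retriever's output is reduced to an ordered list of document IDs and that ranks start at 1 (as in the original RRF paper; the library may count from 0). `reciprocal_rank_fusion` is a hypothetical standalone function, not the library's private method.

```python
from typing import Dict, List


def reciprocal_rank_fusion(
    rankings: List[List[str]],
    weights: List[float],
    k: int = 60,
) -> Dict[str, float]:
    """score(doc) = sum over retrievers of weights[i] / (k + rank_i)."""
    fused: Dict[str, float] = {}
    for ranking, weight in zip(rankings, weights):
        # Assumption: ranks are 1-based
        for rank, doc_id in enumerate(ranking, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + weight / (k + rank)
    return fused


rankings = [["a", "b", "c"], ["b", "c", "a"]]
fused = reciprocal_rank_fusion(rankings, weights=[0.5, 0.5])
ordered = sorted(fused, key=fused.get, reverse=True)
```

Only positions matter here: the raw scores of the individual retrievers never appear, which is why the strategy is insensitive to differing score scales.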

References: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). Reciprocal rank fusion outperforms Condorcet and individual rank learning methods. ACM SIGIR.

EnsembleRetriever(retrievers: List[BaseRetriever], weights: List[float] = None, fusion_strategy: str = 'rrf', rrf_k: int = 60)
    def __init__(
        self,
        retrievers: List[BaseRetriever],
        weights: List[float] = None,
        fusion_strategy: str = "rrf",
        rrf_k: int = 60
    ):
        """
        Initialize ensemble retriever.

        Args:
            retrievers: List of retrievers to combine
            weights: Optional weights for each retriever (default: equal weights)
            fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
            rrf_k: k parameter for RRF (default: 60)

        Raises:
            ValueError: If retrievers list is empty or params are invalid
        """
        if not retrievers:
            raise ValueError("Must provide at least one retriever")

        if fusion_strategy not in ["rrf", "weighted_avg", "max_score"]:
            raise ValueError(
                f"Invalid fusion_strategy: {fusion_strategy}. "
                "Must be 'rrf', 'weighted_avg', or 'max_score'"
            )

        if weights is None:
            # Equal weights
            weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            if len(weights) != len(retrievers):
                raise ValueError(
                    f"Number of weights ({len(weights)}) must match "
                    f"number of retrievers ({len(retrievers)})"
                )
            # Normalize weights to sum to 1.0
            total = sum(weights)
            weights = [w / total for w in weights]

        self.retrievers = retrievers
        self.weights = weights
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k

Initialize ensemble retriever.

Args:

  • retrievers: List of retrievers to combine
  • weights: Optional weights for each retriever (default: equal weights)
  • fusion_strategy: Fusion method - "rrf", "weighted_avg", or "max_score"
  • rrf_k: k parameter for RRF (default: 60)

Raises:

  • ValueError: If the retrievers list is empty, fusion_strategy is not one of the supported values, or the number of weights does not match the number of retrievers
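
The weight handling described above can be reproduced in isolation. This is a minimal sketch, assuming a hypothetical helper named `normalize_weights`; the check for a non-positive total is an addition for illustration and is not part of the constructor shown in this reference, which divides by the total unconditionally.

```python
from typing import List, Optional


def normalize_weights(
    n_retrievers: int,
    weights: Optional[List[float]] = None,
) -> List[float]:
    """Return weights that sum to 1.0, defaulting to equal weights."""
    if n_retrievers < 1:
        raise ValueError("Must provide at least one retriever")
    if weights is None:
        return [1.0 / n_retrievers] * n_retrievers
    if len(weights) != n_retrievers:
        raise ValueError(
            f"Number of weights ({len(weights)}) must match "
            f"number of retrievers ({n_retrievers})"
        )
    total = sum(weights)
    if total <= 0:
        # Illustrative guard (assumption): avoids dividing by zero
        raise ValueError("Weights must sum to a positive value")
    return [w / total for w in weights]


equal = normalize_weights(4)
custom = normalize_weights(3, [2.0, 1.0, 1.0])
```

Weights therefore only need to express relative importance: `[2.0, 1.0, 1.0]` and `[0.5, 0.25, 0.25]` configure the same ensemble.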

retrievers
weights
fusion_strategy
rrf_k
def retrieve(self, query: RetrievalQuery) -> List[gmf_forge_ai_data.SearchResult]:
    def retrieve(self, query: RetrievalQuery) -> List[SearchResult]:
        """
        Retrieve documents using ensemble fusion.

        Args:
            query: RetrievalQuery with parameters for all retrievers

        Returns:
            List of SearchResult with fused scores and ranks

        Note:
            All retrievers are called with the same query. Ensure the query
            contains all required parameters (text, embedding) for your retrievers.
        """
        # Retrieve from all retrievers
        all_results: List[List[SearchResult]] = []

        for retriever in self.retrievers:
            try:
                results = retriever.retrieve(query)
                all_results.append(results)
            except ValueError as e:
                # Retriever might not support this query type, skip it
                all_results.append([])

        if not any(all_results):
            return []

        # Apply fusion strategy
        if self.fusion_strategy == "rrf":
            fused_results = self._reciprocal_rank_fusion(all_results)
        elif self.fusion_strategy == "weighted_avg":
            fused_results = self._weighted_average_fusion(all_results)
        else:  # max_score
            fused_results = self._max_score_fusion(all_results)

        # Sort by fused score and take top_k
        fused_results.sort(key=lambda x: x.score, reverse=True)
        top_results = fused_results[:query.top_k]

        # Update ranks
        for rank, result in enumerate(top_results):
            result.rank = rank

        return top_results

Retrieve documents using ensemble fusion.

Args: query: RetrievalQuery with parameters for all retrievers

Returns: List of SearchResult with fused scores and ranks

Note: All retrievers are called with the same query. Ensure the query contains every parameter your retrievers need (text, embedding): a retriever that raises ValueError is skipped and contributes no results.
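
The practical effect of that note is easy to miss, so here is a small sketch of the collection step using stand-in retrievers. Everything in it is hypothetical: `collect_results`, `keyword_stub`, and `vector_stub` are plain functions written for this illustration, and the query is a dict rather than a `RetrievalQuery`.

```python
from typing import Callable, Dict, List

Query = Dict[str, object]
Retriever = Callable[[Query], List[str]]


def keyword_stub(query: Query) -> List[str]:
    """Stand-in for a text-based retriever."""
    if not query.get("text"):
        raise ValueError("text is required")
    return ["doc1", "doc2"]


def vector_stub(query: Query) -> List[str]:
    """Stand-in for an embedding-based retriever."""
    if query.get("embedding") is None:
        raise ValueError("embedding is required")
    return ["doc2", "doc3"]


def collect_results(retrievers: List[Retriever], query: Query) -> List[List[str]]:
    """Call every retriever; one that raises ValueError yields an empty list."""
    all_results: List[List[str]] = []
    for retrieve in retrievers:
        try:
            all_results.append(retrieve(query))
        except ValueError:
            all_results.append([])
    return all_results


# The embedding is missing, so only the keyword stub contributes
partial = collect_results([keyword_stub, vector_stub], {"text": "machine learning"})
complete = collect_results(
    [keyword_stub, vector_stub],
    {"text": "machine learning", "embedding": [0.1, 0.2]},
)
```

No error surfaces in the first call, so a query that omits the embedding quietly degrades the ensemble to keyword-only results.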

class HierarchicalRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class HierarchicalRetriever(BaseRetriever):
    """
    Two-stage hierarchical retrieval for efficient search in large collections.

    Stage 1: Retrieve document summaries to identify relevant documents
    Stage 2: Retrieve detailed chunks from top documents only

    This reduces computational cost by focusing detailed retrieval on relevant documents.

    Example:
        ```python
        # Create summary and chunk retrievers
        summary_retriever = VectorRetriever(summary_store, embedder)
        chunk_retriever = VectorRetriever(chunk_store, embedder)

        # Create hierarchical retriever
        retriever = HierarchicalRetriever(
            summary_retriever=summary_retriever,
            chunk_retriever=chunk_retriever,
            stage1_top_k=10,  # Get top 10 documents
            stage2_top_k=5,   # Get 5 chunks per document
            document_id_field="document_id"
        )

        # Retrieve
        query = RetrievalQuery(text="machine learning", top_k=20)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        summary_retriever: BaseRetriever,
        chunk_retriever: BaseRetriever,
        stage1_top_k: int = 10,
        stage2_top_k: int = 5,
        document_id_field: str = "document_id",
        combine_scores: bool = True,
        stage1_weight: float = 0.3,
        stage2_weight: float = 0.7
    ):
        """
        Initialize hierarchical retriever.

        Args:
            summary_retriever: Retriever for document summaries (stage 1)
            chunk_retriever: Retriever for detailed chunks (stage 2)
            stage1_top_k: Number of documents to retrieve in stage 1
            stage2_top_k: Number of chunks to retrieve per document in stage 2
            document_id_field: Metadata field linking chunks to documents
            combine_scores: If True, combine stage1 and stage2 scores
            stage1_weight: Weight for stage 1 score (document relevance)
            stage2_weight: Weight for stage 2 score (chunk relevance)
        """
        self.summary_retriever = summary_retriever
        self.chunk_retriever = chunk_retriever
        self.stage1_top_k = stage1_top_k
        self.stage2_top_k = stage2_top_k
        self.document_id_field = document_id_field
        self.combine_scores = combine_scores
        self.stage1_weight = stage1_weight
        self.stage2_weight = stage2_weight

        # Normalize weights
        total_weight = stage1_weight + stage2_weight
        if total_weight > 0:
            self.stage1_weight = stage1_weight / total_weight
            self.stage2_weight = stage2_weight / total_weight

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform two-stage hierarchical retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from stage 2, optionally with combined scores
        """
        # Stage 1: Retrieve document summaries
        logger.info(f"Stage 1: Retrieving top {self.stage1_top_k} document summaries")

        stage1_query = RetrievalQuery(
            text=query.text,
            embedding=query.embedding,
            top_k=self.stage1_top_k,
            filters=query.filters,
            metadata=query.metadata
        )

        summary_results = self.summary_retriever.retrieve(stage1_query, **kwargs)

        if not summary_results:
            logger.warning("Stage 1 returned no results")
            return []

        # Extract document IDs from summaries
        document_ids = []
        document_scores = {}

        for result in summary_results:
            doc_id = result.document.metadata.get(self.document_id_field)
            if doc_id:
                document_ids.append(doc_id)
                document_scores[doc_id] = result.score
            else:
                # If no document_id field, use the document id itself
                document_ids.append(result.document.id)
                document_scores[result.document.id] = result.score

        logger.info(f"Stage 1 identified {len(document_ids)} relevant documents")

        # Stage 2: Retrieve detailed chunks from top documents
        logger.info(f"Stage 2: Retrieving up to {self.stage2_top_k} chunks per document")

        all_chunks: List[SearchResult] = []

        for doc_id in document_ids:
            # Create filter for this document
            doc_filter = {self.document_id_field: doc_id}

            # Merge with existing filters if any
            if query.filters:
                doc_filter.update(query.filters)

            # Retrieve chunks for this document
            stage2_query = RetrievalQuery(
                text=query.text,
                embedding=query.embedding,
                top_k=self.stage2_top_k,
                filters=doc_filter,
                metadata=query.metadata
            )

            try:
                chunk_results = self.chunk_retriever.retrieve(stage2_query, **kwargs)

                if self.combine_scores and doc_id in document_scores:
                    # Combine stage 1 and stage 2 scores
                    for result in chunk_results:
                        combined_score = (
                            self.stage1_weight * document_scores[doc_id] +
                            self.stage2_weight * result.score
                        )
                        # Overwrite the chunk result's score in place
                        result.score = combined_score

                all_chunks.extend(chunk_results)

            except ValueError as e:
                logger.warning(f"Failed to retrieve chunks for document {doc_id}: {e}")
                continue

        if not all_chunks:
            logger.warning("Stage 2 returned no chunks")
            return []

        # Sort by score and limit to top_k
        all_chunks.sort(key=lambda x: x.score, reverse=True)
        final_results = all_chunks[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Hierarchical retrieval complete: {len(final_results)} chunks from "
            f"{len(document_ids)} documents"
        )

        return final_results

Two-stage hierarchical retrieval for efficient search in large collections.

  • Stage 1: Retrieve document summaries to identify relevant documents
  • Stage 2: Retrieve detailed chunks from top documents only

This reduces computational cost by focusing detailed retrieval on relevant documents.

Example:

# Create summary and chunk retrievers
summary_retriever = VectorRetriever(summary_store, embedder)
chunk_retriever = VectorRetriever(chunk_store, embedder)

# Create hierarchical retriever
retriever = HierarchicalRetriever(
    summary_retriever=summary_retriever,
    chunk_retriever=chunk_retriever,
    stage1_top_k=10,  # Get top 10 documents
    stage2_top_k=5,   # Get 5 chunks per document
    document_id_field="document_id"
)

# Retrieve
query = RetrievalQuery(text="machine learning", top_k=20)
results = retriever.retrieve(query)
HierarchicalRetriever(summary_retriever: BaseRetriever, chunk_retriever: BaseRetriever, stage1_top_k: int = 10, stage2_top_k: int = 5, document_id_field: str = 'document_id', combine_scores: bool = True, stage1_weight: float = 0.3, stage2_weight: float = 0.7)

Initialize hierarchical retriever.

Args:

  • summary_retriever: Retriever for document summaries (stage 1)
  • chunk_retriever: Retriever for detailed chunks (stage 2)
  • stage1_top_k: Number of documents to retrieve in stage 1
  • stage2_top_k: Number of chunks to retrieve per document in stage 2
  • document_id_field: Metadata field linking chunks to documents
  • combine_scores: If True, combine stage1 and stage2 scores
  • stage1_weight: Weight for stage 1 score (document relevance)
  • stage2_weight: Weight for stage 2 score (chunk relevance)
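
To make the interaction of combine_scores, stage1_weight, and stage2_weight concrete, the following sketch reproduces the arithmetic on plain floats. `combine_stage_scores` is a hypothetical helper written for this illustration; the input scores are made-up values.

```python
def combine_stage_scores(
    doc_score: float,
    chunk_score: float,
    stage1_weight: float = 0.3,
    stage2_weight: float = 0.7,
) -> float:
    """Blend a document-level score with a chunk-level score."""
    total = stage1_weight + stage2_weight
    if total > 0:
        # Weights are rescaled so that they sum to 1
        stage1_weight /= total
        stage2_weight /= total
    return stage1_weight * doc_score + stage2_weight * chunk_score


# Summary scored 0.8 in stage 1, chunk scored 0.6 in stage 2
default_blend = combine_stage_scores(0.8, 0.6)
# Unnormalized weights with the same ratio give the same result
same_ratio = combine_stage_scores(0.8, 0.6, stage1_weight=3, stage2_weight=7)
```

With the defaults, a chunk's own relevance carries most of the final score, and the parent document's stage 1 score acts as a smaller tie-breaking term.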

summary_retriever
chunk_retriever
stage1_top_k
stage2_top_k
document_id_field
combine_scores
stage1_weight
stage2_weight
def retrieve(self, query: RetrievalQuery, **kwargs) -> List[gmf_forge_ai_data.SearchResult]:

Perform two-stage hierarchical retrieval.

Args:

  • query: The retrieval query
  • **kwargs: Additional arguments, passed through to both underlying retrievers

Returns: List of SearchResult objects from stage 2, optionally with combined scores
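
The two-stage flow can be illustrated end to end on toy data. This is a sketch under simplifying assumptions: documents live in in-memory structures, relevance is a word-overlap ratio instead of vector similarity, and score combination is left out. `SUMMARIES`, `CHUNKS`, `overlap`, and `two_stage` are hypothetical names.

```python
from typing import Dict, List, Tuple

# Stage 1 corpus: one summary per document
SUMMARIES: Dict[str, str] = {
    "doc_ml": "machine learning overview",
    "doc_db": "database indexing guide",
}
# Stage 2 corpus: (chunk_id, parent_document_id, text)
CHUNKS: List[Tuple[str, str, str]] = [
    ("c1", "doc_ml", "supervised machine learning basics"),
    ("c2", "doc_ml", "history of the field"),
    ("c3", "doc_db", "machine friendly index layouts"),
]


def overlap(query: str, text: str) -> float:
    """Fraction of query words that occur in the text."""
    query_words = set(query.lower().split())
    return len(query_words & set(text.lower().split())) / len(query_words)


def two_stage(
    query: str,
    stage1_top_k: int = 1,
    stage2_top_k: int = 2,
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    # Stage 1: rank summaries and keep the best documents
    ranked_docs = sorted(
        SUMMARIES, key=lambda d: overlap(query, SUMMARIES[d]), reverse=True
    )
    selected = ranked_docs[:stage1_top_k]

    # Stage 2: score chunks, but only those belonging to selected documents
    hits: List[Tuple[str, float]] = []
    for doc_id in selected:
        doc_hits = [
            (chunk_id, overlap(query, text))
            for chunk_id, parent, text in CHUNKS
            if parent == doc_id
        ]
        doc_hits.sort(key=lambda hit: hit[1], reverse=True)
        hits.extend(doc_hits[:stage2_top_k])

    hits.sort(key=lambda hit: hit[1], reverse=True)
    return hits[:top_k]


results = two_stage("machine learning")
```

Chunk `c3` shares a word with the query but is never scored, because its parent document did not survive stage 1. That is the source of both the cost saving and the recall risk of this strategy.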

class GraphRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class GraphRetriever(BaseRetriever):
    """
    Graph-based retrieval using entity relationships.

    This retriever:
    1. Extracts entities from the query
    2. Finds matching entities in the knowledge graph
    3. Traverses the graph to find related entities
    4. Retrieves documents associated with relevant entities

    Example:
        ```python
        # Create knowledge graph
        graph = nx.DiGraph()
        graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
        graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

        # Create entity-document mapping
        entity_docs = {
            "Python": ["doc1", "doc2"],
            "Machine Learning": ["doc3", "doc4"],
            "Neural Networks": ["doc5"]
        }

        # Create retriever
        retriever = GraphRetriever(
            vector_store=store,
            knowledge_graph=graph,
            entity_document_mapping=entity_docs,
            embedder=embedder,
            max_hops=2
        )

        # Retrieve
        query = RetrievalQuery(text="Python for deep learning", top_k=5)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        vector_store: BaseVectorStore,
        knowledge_graph: Optional['nx.DiGraph'] = None,
        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
        embedder: Optional[Any] = None,
        max_hops: int = 2,
        min_relation_weight: float = 0.5,
        combine_vector_scores: bool = True,
        graph_weight: float = 0.4,
        vector_weight: float = 0.6
    ):
        """
        Initialize graph retriever.

        Args:
            vector_store: Vector store for document retrieval
            knowledge_graph: NetworkX DiGraph with entity relationships
            entity_document_mapping: Dict mapping entity IDs to document IDs
            embedder: Embedding provider for query encoding
            max_hops: Maximum number of hops in graph traversal
            min_relation_weight: Minimum weight for relationship edges
            combine_vector_scores: If True, combine graph and vector scores
            graph_weight: Weight for graph-based relevance
            vector_weight: Weight for vector similarity
        """
        if not NETWORKX_AVAILABLE:
            raise ImportError(
                "NetworkX is required for GraphRetriever. "
                "Install with: pip install networkx"
            )

        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph or nx.DiGraph()
        self.entity_document_mapping = entity_document_mapping or {}
        self.embedder = embedder
        self.max_hops = max_hops
        self.min_relation_weight = min_relation_weight
        self.combine_vector_scores = combine_vector_scores
        self.graph_weight = graph_weight
        self.vector_weight = vector_weight

        # Normalize weights
        total_weight = graph_weight + vector_weight
        if total_weight > 0:
            self.graph_weight = graph_weight / total_weight
            self.vector_weight = vector_weight / total_weight

    def extract_entities(self, query_text: str) -> List[str]:
        """
        Extract entities from query text.

        Simple implementation: looks for entity names in the knowledge graph.
        For production, use NER models (spaCy, Hugging Face, etc.).

        Args:
            query_text: Query text

        Returns:
            List of entity IDs found in the query
        """
        query_lower = query_text.lower()
        entities = []

        # Simple matching: check if entity names appear in query
        for entity_id in self.knowledge_graph.nodes():
            entity_name = str(entity_id).lower()
            if entity_name in query_lower:
                entities.append(entity_id)

        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
        return entities

    def traverse_graph(
        self,
        seed_entities: List[str],
        max_hops: int
    ) -> Dict[str, float]:
        """
        Traverse knowledge graph from seed entities.

        Args:
            seed_entities: Starting entities
            max_hops: Maximum number of hops

        Returns:
            Dict mapping entity IDs to relevance scores (0-1)
        """
        entity_scores: Dict[str, float] = {}

        # Initialize seed entities with score 1.0
        for entity in seed_entities:
            if entity in self.knowledge_graph:
                entity_scores[entity] = 1.0

        if not entity_scores:
            return entity_scores

        # BFS traversal
        visited: Set[str] = set()
        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)

        while current_level:
            next_level = []

            for entity_id, score, hop in current_level:
                if entity_id in visited or hop >= max_hops:
                    continue

                visited.add(entity_id)

                # Get neighbors
                if entity_id not in self.knowledge_graph:
                    continue

                for neighbor in self.knowledge_graph.neighbors(entity_id):
                    # Get edge weight
                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0

                    # Skip weak relationships
                    if edge_weight < self.min_relation_weight:
                        continue

                    # Calculate neighbor score (decay with hops)
                    decay_factor = 0.7 ** (hop + 1)
                    neighbor_score = score * edge_weight * decay_factor

                    # Update score if better
                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
                        entity_scores[neighbor] = neighbor_score
                        next_level.append((neighbor, neighbor_score, hop + 1))

            current_level = next_level

        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
        return entity_scores

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform graph-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects ranked by combined graph + vector scores
        """
        if not query.text:
            raise ValueError("GraphRetriever requires query.text")

        # Step 1: Extract entities from query
        seed_entities = self.extract_entities(query.text)

        if not seed_entities:
            logger.warning("No entities found in query, falling back to vector search")
            # Fallback to pure vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 2: Traverse graph to find related entities
        entity_scores = self.traverse_graph(seed_entities, self.max_hops)

        # Step 3: Collect documents associated with relevant entities
        doc_graph_scores: Dict[str, float] = {}

        for entity_id, entity_score in entity_scores.items():
            doc_ids = self.entity_document_mapping.get(entity_id, [])
            for doc_id in doc_ids:
                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
                    doc_graph_scores[doc_id] = entity_score

        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")

        if not doc_graph_scores:
            logger.warning("No documents found via graph, falling back to vector search")
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 4: Get vector scores if combining
        if self.combine_vector_scores:
            # Get embeddings for vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            # Retrieve more documents for scoring
            vector_results = self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k * 3,  # Get more for better coverage
                filters=query.filters
            )

            # Create document ID to vector score mapping
            doc_vector_scores: Dict[str, float] = {}
            for result in vector_results:
                doc_vector_scores[result.document.id] = result.score

            # Combine scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                # Get document from vector store
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                # Get vector score (0 if not found)
                vector_score = doc_vector_scores.get(doc_id, 0.0)

                # Combine scores
                combined_score = (
                    self.graph_weight * graph_score +
                    self.vector_weight * vector_score
                )

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=combined_score,
                        rank=0  # Will be set later
                    )
                )
        else:
            # Use only graph scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=graph_score,
                        rank=0
                    )
                )

        # Sort by score and limit to top_k
        combined_results.sort(key=lambda x: x.score, reverse=True)
        final_results = combined_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(f"Graph retrieval complete: {len(final_results)} results")

        return final_results

Graph-based retrieval using entity relationships.

This retriever:

  1. Extracts entities from the query
  2. Finds matching entities in the knowledge graph
  3. Traverses the graph to find related entities
  4. Retrieves documents associated with relevant entities

Example:

import networkx as nx

# Create knowledge graph
graph = nx.DiGraph()
graph.add_edge("Python", "Machine Learning", relation="used_in", weight=0.9)
graph.add_edge("Machine Learning", "Neural Networks", relation="includes", weight=0.8)

# Create entity-document mapping
entity_docs = {
    "Python": ["doc1", "doc2"],
    "Machine Learning": ["doc3", "doc4"],
    "Neural Networks": ["doc5"]
}

# Create retriever
retriever = GraphRetriever(
    vector_store=store,
    knowledge_graph=graph,
    entity_document_mapping=entity_docs,
    embedder=embedder,
    max_hops=2
)

# Retrieve
query = RetrievalQuery(text="Python for deep learning", top_k=5)
results = retriever.retrieve(query)
GraphRetriever( vector_store: gmf_forge_ai_data.BaseVectorStore, knowledge_graph: Optional[networkx.classes.digraph.DiGraph] = None, entity_document_mapping: Optional[Dict[str, List[str]]] = None, embedder: Optional[Any] = None, max_hops: int = 2, min_relation_weight: float = 0.5, combine_vector_scores: bool = True, graph_weight: float = 0.4, vector_weight: float = 0.6)
    def __init__(
        self,
        vector_store: BaseVectorStore,
        knowledge_graph: Optional['nx.DiGraph'] = None,
        entity_document_mapping: Optional[Dict[str, List[str]]] = None,
        embedder: Optional[Any] = None,
        max_hops: int = 2,
        min_relation_weight: float = 0.5,
        combine_vector_scores: bool = True,
        graph_weight: float = 0.4,
        vector_weight: float = 0.6
    ):
        """
        Initialize graph retriever.

        Args:
            vector_store: Vector store for document retrieval
            knowledge_graph: NetworkX DiGraph with entity relationships
            entity_document_mapping: Dict mapping entity IDs to document IDs
            embedder: Embedding provider for query encoding
            max_hops: Maximum number of hops in graph traversal
            min_relation_weight: Minimum weight for relationship edges
            combine_vector_scores: If True, combine graph and vector scores
            graph_weight: Weight for graph-based relevance
            vector_weight: Weight for vector similarity
        """
        if not NETWORKX_AVAILABLE:
            raise ImportError(
                "NetworkX is required for GraphRetriever. "
                "Install with: pip install networkx"
            )

        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph or nx.DiGraph()
        self.entity_document_mapping = entity_document_mapping or {}
        self.embedder = embedder
        self.max_hops = max_hops
        self.min_relation_weight = min_relation_weight
        self.combine_vector_scores = combine_vector_scores
        self.graph_weight = graph_weight
        self.vector_weight = vector_weight

        # Normalize weights
        total_weight = graph_weight + vector_weight
        if total_weight > 0:
            self.graph_weight = graph_weight / total_weight
            self.vector_weight = vector_weight / total_weight

Initialize graph retriever.

Args:

  • vector_store: Vector store for document retrieval
  • knowledge_graph: NetworkX DiGraph with entity relationships
  • entity_document_mapping: Dict mapping entity IDs to document IDs
  • embedder: Embedding provider for query encoding
  • max_hops: Maximum number of hops in graph traversal
  • min_relation_weight: Minimum weight for relationship edges
  • combine_vector_scores: If True, combine graph and vector scores
  • graph_weight: Weight for graph-based relevance
  • vector_weight: Weight for vector similarity
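
The constructor rescales `graph_weight` and `vector_weight` so that they sum to 1 whenever their total is positive. The snippet below is a minimal sketch of that step in isolation; `normalize_weights` is a hypothetical helper written for illustration and is not part of the library.

```python
from typing import Tuple


def normalize_weights(graph_weight: float, vector_weight: float) -> Tuple[float, float]:
    """Rescale two non-negative weights so they sum to 1 (hypothetical helper)."""
    total = graph_weight + vector_weight
    if total > 0:
        return graph_weight / total, vector_weight / total
    # A zero total is left untouched, mirroring the constructor's guard.
    return graph_weight, vector_weight


# Weights only need to express a ratio; 2:6 becomes 0.25:0.75.
normalized = normalize_weights(2.0, 6.0)
print(normalized)
```

Because of this rescaling, passing `graph_weight=2, vector_weight=6` should behave the same as passing `0.25` and `0.75`.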

vector_store
knowledge_graph
entity_document_mapping
embedder
max_hops
min_relation_weight
combine_vector_scores
graph_weight
vector_weight
def extract_entities(self, query_text: str) -> List[str]:
    def extract_entities(self, query_text: str) -> List[str]:
        """
        Extract entities from query text.

        Simple implementation: looks for entity names in the knowledge graph.
        For production, use NER models (spaCy, Hugging Face, etc.).

        Args:
            query_text: Query text

        Returns:
            List of entity IDs found in the query
        """
        query_lower = query_text.lower()
        entities = []

        # Simple matching: check if entity names appear in query
        for entity_id in self.knowledge_graph.nodes():
            entity_name = str(entity_id).lower()
            if entity_name in query_lower:
                entities.append(entity_id)

        logger.info(f"Extracted {len(entities)} entities from query: {entities}")
        return entities

Extract entities from query text.

Simple implementation: looks for entity names in the knowledge graph. For production, use NER models (spaCy, Hugging Face, etc.).

Args:

  • query_text: Query text

Returns: List of entity IDs found in the query
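
The built-in matcher is a plain substring test, so a short entity name such as "R" would also match inside words like "for". Short of a full NER model, one possible middle ground is whole-word matching with the standard library. The sketch below assumes that approach; `match_entities` is a hypothetical function, not a library API.

```python
import re
from typing import Iterable, List


def match_entities(query_text: str, entity_names: Iterable[str]) -> List[str]:
    """Return the entity names that occur as whole words or phrases in the query."""
    found = []
    for name in entity_names:
        # Lookarounds require a non-word character (or the string edge)
        # on both sides of the entity name.
        pattern = r"(?<!\w)" + re.escape(str(name)) + r"(?!\w)"
        if re.search(pattern, query_text, flags=re.IGNORECASE):
            found.append(name)
    return found


nodes = ["Python", "R", "Machine Learning"]
query = "Python for machine learning research"

matched = match_entities(query, nodes)
print(matched)

# For comparison, the substring test also accepts "R" (it occurs inside "for").
substring_matched = [n for n in nodes if n.lower() in query.lower()]
print(substring_matched)
```

A subclass could override `extract_entities` with logic along these lines while keeping the rest of the retriever unchanged.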

def traverse_graph(self, seed_entities: List[str], max_hops: int) -> Dict[str, float]:
    def traverse_graph(
        self,
        seed_entities: List[str],
        max_hops: int
    ) -> Dict[str, float]:
        """
        Traverse knowledge graph from seed entities.

        Args:
            seed_entities: Starting entities
            max_hops: Maximum number of hops

        Returns:
            Dict mapping entity IDs to relevance scores (0-1)
        """
        entity_scores: Dict[str, float] = {}

        # Initialize seed entities with score 1.0
        for entity in seed_entities:
            if entity in self.knowledge_graph:
                entity_scores[entity] = 1.0

        if not entity_scores:
            return entity_scores

        # BFS traversal
        visited: Set[str] = set()
        current_level = [(e, 1.0, 0) for e in seed_entities]  # (entity, score, hop)

        while current_level:
            next_level = []

            for entity_id, score, hop in current_level:
                if entity_id in visited or hop >= max_hops:
                    continue

                visited.add(entity_id)

                # Get neighbors
                if entity_id not in self.knowledge_graph:
                    continue

                for neighbor in self.knowledge_graph.neighbors(entity_id):
                    # Get edge weight
                    edge_data = self.knowledge_graph.get_edge_data(entity_id, neighbor)
                    edge_weight = edge_data.get('weight', 1.0) if edge_data else 1.0

                    # Skip weak relationships
                    if edge_weight < self.min_relation_weight:
                        continue

                    # Calculate neighbor score (decay with hops)
                    decay_factor = 0.7 ** (hop + 1)
                    neighbor_score = score * edge_weight * decay_factor

                    # Update score if better
                    if neighbor not in entity_scores or neighbor_score > entity_scores[neighbor]:
                        entity_scores[neighbor] = neighbor_score
                        next_level.append((neighbor, neighbor_score, hop + 1))

            current_level = next_level

        logger.info(f"Graph traversal found {len(entity_scores)} relevant entities")
        return entity_scores

Traverse knowledge graph from seed entities.

Args:

  • seed_entities: Starting entities
  • max_hops: Maximum number of hops

Returns: Dict mapping entity IDs to relevance scores (0-1)
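
To make the scoring concrete, here is a minimal sketch of the same breadth-first scoring rule on the graph from the class example. It uses a plain dict of dicts instead of a NetworkX graph so that it runs without extra dependencies; the `traverse` function and the extra "Snakes" edge are illustrative assumptions, not library code.

```python
from typing import Dict, List, Set, Tuple

# Adjacency map standing in for nx.DiGraph: source -> {target: edge weight}
Graph = Dict[str, Dict[str, float]]


def traverse(graph: Graph, seeds: List[str], max_hops: int,
             min_weight: float = 0.5, decay: float = 0.7) -> Dict[str, float]:
    """Score entities reachable from the seeds, decaying with each hop."""
    scores: Dict[str, float] = {s: 1.0 for s in seeds if s in graph}
    visited: Set[str] = set()
    level: List[Tuple[str, float, int]] = [(s, 1.0, 0) for s in scores]

    while level:
        next_level = []
        for node, score, hop in level:
            if node in visited or hop >= max_hops:
                continue
            visited.add(node)
            for neighbor, weight in graph.get(node, {}).items():
                if weight < min_weight:
                    continue  # weak relationship, skipped
                candidate = score * weight * decay ** (hop + 1)
                if neighbor not in scores or candidate > scores[neighbor]:
                    scores[neighbor] = candidate
                    next_level.append((neighbor, candidate, hop + 1))
        level = next_level
    return scores


graph: Graph = {
    "Python": {"Machine Learning": 0.9, "Snakes": 0.2},
    "Machine Learning": {"Neural Networks": 0.8},
    "Neural Networks": {},
}

scores = traverse(graph, ["Python"], max_hops=2)
for entity, score in scores.items():
    print(f"{entity}: {score:.5f}")
```

With these inputs, "Machine Learning" scores 1.0 × 0.9 × 0.7 = 0.63 and "Neural Networks" scores 0.63 × 0.8 × 0.7² ≈ 0.247. The parent's score already carries its own decay, so the penalty compounds with depth. The 0.2-weight edge falls below `min_relation_weight` and is ignored.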

def retrieve( self, query: RetrievalQuery, **kwargs) -> List[gmf_forge_ai_data.SearchResult]:
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform graph-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects ranked by combined graph + vector scores
        """
        if not query.text:
            raise ValueError("GraphRetriever requires query.text")

        # Step 1: Extract entities from query
        seed_entities = self.extract_entities(query.text)

        if not seed_entities:
            logger.warning("No entities found in query, falling back to vector search")
            # Fallback to pure vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 2: Traverse graph to find related entities
        entity_scores = self.traverse_graph(seed_entities, self.max_hops)

        # Step 3: Collect documents associated with relevant entities
        doc_graph_scores: Dict[str, float] = {}

        for entity_id, entity_score in entity_scores.items():
            doc_ids = self.entity_document_mapping.get(entity_id, [])
            for doc_id in doc_ids:
                if doc_id not in doc_graph_scores or entity_score > doc_graph_scores[doc_id]:
                    doc_graph_scores[doc_id] = entity_score

        logger.info(f"Found {len(doc_graph_scores)} documents via graph traversal")

        if not doc_graph_scores:
            logger.warning("No documents found via graph, falling back to vector search")
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            return self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k,
                filters=query.filters
            )

        # Step 4: Get vector scores if combining
        if self.combine_vector_scores:
            # Get embeddings for vector search
            if self.embedder and not query.embedding:
                query.embedding = self.embedder.embed_text(query.text)

            # Retrieve more documents for scoring
            vector_results = self.vector_store.search(
                query_embedding=query.embedding,
                top_k=query.top_k * 3,  # Get more for better coverage
                filters=query.filters
            )

            # Create document ID to vector score mapping
            doc_vector_scores: Dict[str, float] = {}
            for result in vector_results:
                doc_vector_scores[result.document.id] = result.score

            # Combine scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                # Get document from vector store
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                # Get vector score (0 if not found)
                vector_score = doc_vector_scores.get(doc_id, 0.0)

                # Combine scores
                combined_score = (
                    self.graph_weight * graph_score +
                    self.vector_weight * vector_score
                )

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=combined_score,
                        rank=0  # Will be set later
                    )
                )
        else:
            # Use only graph scores
            combined_results: List[SearchResult] = []

            for doc_id, graph_score in doc_graph_scores.items():
                document = self.vector_store.get_document(doc_id)
                if not document:
                    continue

                combined_results.append(
                    SearchResult(
                        document=document,
                        score=graph_score,
                        rank=0
                    )
                )

        # Sort by score and limit to top_k
        combined_results.sort(key=lambda x: x.score, reverse=True)
        final_results = combined_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(f"Graph retrieval complete: {len(final_results)} results")

        return final_results

Perform graph-based retrieval.

Args:

  • query: The retrieval query
  • **kwargs: Additional arguments

Returns: List of SearchResult objects ranked by combined graph + vector scores
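
The score fusion in `retrieve` can be sketched on its own, using plain dictionaries in place of `SearchResult` objects. The `fuse_scores` helper and the sample scores below are assumptions made for illustration. Only documents reached through the graph are candidates; the vector search contributes a score to them but does not add new documents.

```python
from typing import Dict, List, Tuple


def fuse_scores(graph_scores: Dict[str, float], vector_scores: Dict[str, float],
                graph_weight: float = 0.4, vector_weight: float = 0.6,
                top_k: int = 5) -> List[Tuple[str, float]]:
    """Weighted sum of graph and vector scores over graph-derived candidates."""
    fused = [
        # A candidate missing from the vector results gets a vector score of 0.
        (doc_id, graph_weight * g + vector_weight * vector_scores.get(doc_id, 0.0))
        for doc_id, g in graph_scores.items()
    ]
    fused.sort(key=lambda pair: pair[1], reverse=True)
    return fused[:top_k]


graph_scores = {"doc1": 1.0, "doc3": 0.63, "doc5": 0.25}
vector_scores = {"doc3": 0.9, "doc5": 0.8, "doc9": 0.95}  # doc9: vector-only hit

ranking = fuse_scores(graph_scores, vector_scores)
for rank, (doc_id, score) in enumerate(ranking, start=1):
    print(rank, doc_id, round(score, 3))
```

Here "doc3" ranks first, followed by "doc5" and "doc1". "doc9" never appears, even though it has the highest vector score, because no entity maps to it.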

class SQLRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class SQLRetriever(BaseRetriever):
    """
    Retrieval from structured databases using SQL queries.

    This retriever:
    1. Converts natural language queries to SQL (via LLM or rule-based)
    2. Executes SQL against a database
    3. Converts results to Document objects
    4. Returns SearchResult objects with relevance scores

    Example:
        ```python
        import sqlite3

        # Create database connection
        conn = sqlite3.connect("products.db")

        # Define schema
        schema = SQLSchema(
            table_name="products",
            columns=[
                {"name": "id", "type": "INTEGER", "description": "Product ID"},
                {"name": "name", "type": "TEXT", "description": "Product name"},
                {"name": "price", "type": "REAL", "description": "Price in USD"},
                {"name": "category", "type": "TEXT", "description": "Product category"}
            ],
            primary_key="id",
            description="E-commerce product catalog"
        )

        # Create retriever
        retriever = SQLRetriever(
            db_connection=conn,
            schema=schema,
            text_to_sql_fn=my_text_to_sql_function,
            db_type="sqlite"
        )

        # Retrieve
        query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score

    def _simple_text_to_sql(self, query_text: str, schema: SQLSchema) -> SQLQuery:
        """
        Simple rule-based text-to-SQL conversion.

        For production, use LLM-based conversion or specialized libraries.

        Args:
            query_text: Natural language query
            schema: Database schema

        Returns:
            SQLQuery object
        """
        import re

        query_lower = query_text.lower()

        # Build SELECT clause
        if self.content_columns:
            columns = ", ".join(self.content_columns)
        else:
            columns = "*"

        # Add score column if specified
        if self.score_column:
            columns = f"{columns}, {self.score_column}"

        # Build basic SELECT
        sql = f"SELECT {columns} FROM {schema.table_name}"

        # Simple WHERE clause based on keywords
        conditions = []

        # Numeric comparisons use the first price-like column and the first
        # number in the query. The regex yields digits only, so the value
        # cannot carry SQL syntax.
        price_cols = [
            c["name"] for c in schema.columns
            if "price" in c["name"].lower() or "cost" in c["name"].lower()
        ]
        numbers = re.findall(r'\d+', query_text)

        if price_cols and numbers:
            if "under" in query_lower or "less than" in query_lower or "<" in query_lower:
                conditions.append(f"{price_cols[0]} < {numbers[0]}")
            if "over" in query_lower or "more than" in query_lower or ">" in query_lower:
                conditions.append(f"{price_cols[0]} > {numbers[0]}")

        # Look for text columns to search
        text_cols = [c["name"] for c in schema.columns if c["type"].upper() in ["TEXT", "VARCHAR", "STRING"]]

        # Build LIKE conditions for text search
        stop_words = {"in", "the", "a", "an", "under", "over", "less", "more", "than", "products", "items"}
        search_terms = []

        for word in query_lower.split():
            # Keep only word characters so user input (for example a single
            # quote) cannot break out of the quoted LIKE pattern. This also
            # turns "$100" into "100", which is then skipped as a number.
            term = re.sub(r'\W', '', word)
            # Skip stop words, pure numbers, and very short words
            if len(term) > 2 and term not in stop_words and not term.isdigit():
                search_terms.append(term)

        if search_terms and text_cols:
            like_conditions = []
            for col in text_cols:
                for term in search_terms:
                    like_conditions.append(f"{col} LIKE '%{term}%'")

            if like_conditions:
                conditions.append(f"({' OR '.join(like_conditions)})")

        # Add WHERE clause if conditions exist
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)

        # Add LIMIT
        sql += " LIMIT 100"  # Safety limit

        return SQLQuery(
            sql=sql,
            explanation="Simple keyword-based SQL generation"
        )

    def _execute_sql(self, sql_query: SQLQuery) -> List[Dict[str, Any]]:
        """
        Execute SQL query and return results.

        Args:
            sql_query: SQL query to execute

        Returns:
            List of result rows as dictionaries
        """
        cursor = self.db_connection.cursor()

        try:
            logger.info(f"Executing SQL: {sql_query.sql}")

            if sql_query.parameters:
                cursor.execute(sql_query.sql, sql_query.parameters)
            else:
                cursor.execute(sql_query.sql)

            # Get column names
            if cursor.description:
                columns = [desc[0] for desc in cursor.description]

                # Fetch results
                rows = cursor.fetchall()

                # Convert to dictionaries
                results = []
                for row in rows:
                    result_dict = dict(zip(columns, row))
                    results.append(result_dict)

                logger.info(f"SQL returned {len(results)} rows")
                return results
            else:
                return []

        except Exception as e:
            logger.error(f"SQL execution failed: {e}")
            raise
        finally:
            cursor.close()

    def _row_to_document(self, row: Dict[str, Any], index: int) -> Document:
        """
        Convert database row to Document object.

        Args:
            row: Database row as dictionary
            index: Row index for ID generation

        Returns:
            Document object
        """
        # Extract score if available
        score = self.default_score
        if self.score_column and self.score_column in row:
            score = float(row[self.score_column])

        # Build content from specified columns or all columns
        if self.content_columns:
            content_parts = []
            for col in self.content_columns:
                if col in row:
                    content_parts.append(f"{col}: {row[col]}")
            content = "\n".join(content_parts)
        else:
            # Use JSON representation
            content = json.dumps(row, indent=2, default=str)

        # Use primary key if available, otherwise use index
        doc_id = str(row.get(self.schema.primary_key, f"sql_result_{index}"))

        # Store all row data in metadata
        metadata = {
            "source": "sql_database",
            "table": self.schema.table_name,
            "row_data": row,
            "sql_score": score
        }

        return Document(
            id=doc_id,
            content=content,
            metadata=metadata
        )

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results

Retrieval from structured databases using SQL queries.

This retriever:

  1. Converts natural language queries to SQL (via LLM or rule-based)
  2. Executes SQL against a database
  3. Converts results to Document objects
  4. Returns SearchResult objects with relevance scores

Example:

import sqlite3

# Create database connection
conn = sqlite3.connect("products.db")

# Define schema
schema = SQLSchema(
    table_name="products",
    columns=[
        {"name": "id", "type": "INTEGER", "description": "Product ID"},
        {"name": "name", "type": "TEXT", "description": "Product name"},
        {"name": "price", "type": "REAL", "description": "Price in USD"},
        {"name": "category", "type": "TEXT", "description": "Product category"}
    ],
    primary_key="id",
    description="E-commerce product catalog"
)

# Create retriever
retriever = SQLRetriever(
    db_connection=conn,
    schema=schema,
    text_to_sql_fn=my_text_to_sql_function,
    db_type="sqlite"
)

# Retrieve
query = RetrievalQuery(text="products under $100 in electronics", top_k=10)
results = retriever.retrieve(query)
SQLRetriever( db_connection: Any, schema: SQLSchema, text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None, db_type: str = 'sqlite', content_columns: Optional[List[str]] = None, score_column: Optional[str] = None, default_score: float = 1.0)
    def __init__(
        self,
        db_connection: Any,
        schema: SQLSchema,
        text_to_sql_fn: Optional[Callable[[str, SQLSchema], SQLQuery]] = None,
        db_type: str = "sqlite",
        content_columns: Optional[List[str]] = None,
        score_column: Optional[str] = None,
        default_score: float = 1.0
    ):
        """
        Initialize SQL retriever.

        Args:
            db_connection: Database connection object (e.g., sqlite3.Connection)
            schema: Database schema information
            text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
            db_type: Database type ("sqlite", "postgresql", "mysql")
            content_columns: Columns to use for document content (None = all columns)
            score_column: Column to use for relevance score (None = use default_score)
            default_score: Default score when no score_column specified
        """
        self.db_connection = db_connection
        self.schema = schema
        self.text_to_sql_fn = text_to_sql_fn or self._simple_text_to_sql
        self.db_type = db_type
        self.content_columns = content_columns
        self.score_column = score_column
        self.default_score = default_score

Initialize SQL retriever.

Args:

  • db_connection: Database connection object (e.g., sqlite3.Connection)
  • schema: Database schema information
  • text_to_sql_fn: Function to convert text to SQL (None = use simple rules)
  • db_type: Database type ("sqlite", "postgresql", "mysql")
  • content_columns: Columns to use for document content (None = all columns)
  • score_column: Column to use for relevance score (None = use default_score)
  • default_score: Default score when no score_column specified
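
The class example passes a `my_text_to_sql_function` without showing one. Below is a minimal sketch of what such a callable might look like, using a bound parameter rather than string interpolation for the user-supplied value. The `SQLSchema` and `SQLQuery` dataclasses are simplified stand-ins so the snippet runs on its own; their fields are assumptions inferred from how the retriever uses them (`sql`, `parameters`, `explanation`). The `price_cap_to_sql` function and its hard-coded `price` column are hypothetical.

```python
import re
import sqlite3
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SQLSchema:  # stand-in for the library class (assumed fields)
    table_name: str
    columns: List[Dict[str, str]]
    primary_key: Optional[str] = None
    description: str = ""


@dataclass
class SQLQuery:  # stand-in for the library class (assumed fields)
    sql: str
    parameters: List[Any] = field(default_factory=list)
    explanation: str = ""


def price_cap_to_sql(query_text: str, schema: SQLSchema) -> SQLQuery:
    """Turn 'products under $100' into a parameterized price filter."""
    sql = f"SELECT * FROM {schema.table_name}"
    params: List[Any] = []
    match = re.search(r"\d+(?:\.\d+)?", query_text)
    if match:
        sql += " WHERE price < ?"  # "?" is the sqlite3 placeholder style
        params.append(float(match.group()))
    return SQLQuery(sql=sql + " LIMIT 100", parameters=params,
                    explanation="Price cap extracted from the query text")


schema = SQLSchema(
    table_name="products",
    columns=[{"name": "id", "type": "INTEGER"},
             {"name": "name", "type": "TEXT"},
             {"name": "price", "type": "REAL"}],
    primary_key="id",
)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, price REAL)")
conn.executemany("INSERT INTO products (name, price) VALUES (?, ?)",
                 [("cable", 9.5), ("monitor", 240.0)])

generated = price_cap_to_sql("products under $100", schema)
rows = conn.execute(generated.sql, generated.parameters).fetchall()
print(generated.sql, generated.parameters, rows)
```

Placeholder syntax is driver-specific (`?` for `sqlite3`, commonly `%s` for PostgreSQL and MySQL drivers), so a real function would likely need to pick the style that matches `db_type`.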

db_connection
schema
text_to_sql_fn
db_type
content_columns
score_column
default_score
def retrieve( self, query: RetrievalQuery, **kwargs) -> List[gmf_forge_ai_data.SearchResult]:
    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform SQL-based retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments

        Returns:
            List of SearchResult objects from database query
        """
        if not query.text:
            raise ValueError("SQLRetriever requires query.text")

        # Convert text to SQL
        sql_query = self.text_to_sql_fn(query.text, self.schema)

        logger.info(f"Generated SQL: {sql_query.sql}")
        if sql_query.explanation:
            logger.info(f"Explanation: {sql_query.explanation}")

        # Execute SQL
        rows = self._execute_sql(sql_query)

        if not rows:
            logger.warning("SQL query returned no results")
            return []

        # Convert rows to Documents
        documents = [
            self._row_to_document(row, i)
            for i, row in enumerate(rows)
        ]

        # Create SearchResult objects
        results = [
            SearchResult(
                document=doc,
                score=doc.metadata.get("sql_score", self.default_score),
                rank=i + 1
            )
            for i, doc in enumerate(documents)
        ]

        # Limit to top_k
        results = results[:query.top_k]

        # Update ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        logger.info(f"SQL retrieval complete: {len(results)} results")

        return results

Perform SQL-based retrieval.

Args:
    query: The retrieval query
    **kwargs: Additional arguments

Returns: List of SearchResult objects from database query
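
The flow above (execute the generated SQL, turn each row into a result, truncate to top_k) can be sketched with the standard library alone. This is a minimal illustration, not the library's implementation: `SketchResult` and `sketch_sql_retrieve` are hypothetical names, the row formatting is an assumption (the real `_execute_sql` and `_row_to_document` are not shown in this section), and the `:dept` placeholder style is the one sqlite3 accepts.

```python
import sqlite3
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class SketchResult:
    """Hypothetical stand-in for SearchResult (content, score, rank only)."""
    content: str
    score: float
    rank: int


def sketch_sql_retrieve(
    conn: sqlite3.Connection,
    sql: str,
    params: Dict[str, Any],
    top_k: int,
    default_score: float = 1.0,
) -> List[SketchResult]:
    """Run a parameterized query and wrap the first top_k rows as results."""
    cursor = conn.execute(sql, params)
    columns = [desc[0] for desc in cursor.description]
    rows = cursor.fetchall()
    results = []
    for i, row in enumerate(rows[:top_k]):
        # Assumed row format: "col=value" pairs joined by commas.
        content = ", ".join(f"{c}={v}" for c, v in zip(columns, row))
        results.append(SketchResult(content=content, score=default_score, rank=i + 1))
    return results


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employees (name TEXT, dept TEXT)")
conn.executemany(
    "INSERT INTO employees VALUES (?, ?)",
    [("Ana", "HR"), ("Bo", "IT"), ("Cy", "HR")],
)
sketch_results = sketch_sql_retrieve(
    conn,
    "SELECT name, dept FROM employees WHERE dept = :dept ORDER BY name",
    {"dept": "HR"},
    top_k=1,
)
```

Because rows carry no relevance score of their own here, every result gets `default_score`, which mirrors the documented fallback when no score column is configured.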

class MultiIndexRetriever(gmf_forge_ai_data.retrieval.BaseRetriever):
class MultiIndexRetriever(BaseRetriever):
    """
    Retrieve from multiple indices/sources and merge results.

    This retriever:
    1. Queries each enabled retriever in turn (the current implementation is sequential)
    2. Merges results from all sources
    3. Applies source-specific weights and boosts
    4. Re-ranks using RRF, weighted-average, or max-score fusion

    Use cases:
    - Multi-domain search (HR + Finance + IT)
    - Federated search across multiple databases
    - Combining different data sources (documents + SQL + graph)

    Example:
        ```python
        # Create retrievers for different sources
        hr_retriever = VectorRetriever(hr_store, embedder)
        finance_retriever = VectorRetriever(finance_store, embedder)
        it_retriever = VectorRetriever(it_store, embedder)

        # Create multi-index retriever
        retriever = MultiIndexRetriever(
            sources=[
                SourceConfig("HR", hr_retriever, weight=1.0),
                SourceConfig("Finance", finance_retriever, weight=1.5),
                SourceConfig("IT", it_retriever, weight=1.0)
            ],
            fusion_strategy="rrf"
        )

        # Retrieve across all sources
        query = RetrievalQuery(text="employee benefits policy", top_k=10)
        results = retriever.retrieve(query)
        ```
    """

    def __init__(
        self,
        sources: List[SourceConfig],
        fusion_strategy: str = "rrf",
        rrf_k: int = 60,
        normalize_scores: bool = True
    ):
        """
        Initialize multi-index retriever.

        Args:
            sources: List of SourceConfig objects defining retrievers
            fusion_strategy: Strategy for merging results:
                - "rrf": Reciprocal Rank Fusion
                - "weighted_average": Weighted average of normalized scores
                - "max_score": Take maximum score across sources
            rrf_k: Constant for RRF formula (default 60)
            normalize_scores: Normalize scores to [0,1] before fusion
        """
        self.sources = sources
        self.fusion_strategy = fusion_strategy
        self.rrf_k = rrf_k
        self.normalize_scores = normalize_scores

        # Validate sources
        if not sources:
            raise ValueError("At least one source must be provided")

        # Normalize weights
        total_weight = sum(s.weight for s in sources if s.enabled)
        if total_weight > 0:
            for source in sources:
                source.weight = source.weight / total_weight

    def _normalize_scores_for_source(
        self,
        results: List[SearchResult]
    ) -> List[SearchResult]:
        """
        Normalize scores to [0, 1] range.

        Args:
            results: Search results

        Returns:
            Results with normalized scores
        """
        if not results:
            return results

        scores = [r.score for r in results]
        min_score = min(scores)
        max_score = max(scores)

        if max_score - min_score == 0:
            # All scores are the same
            for result in results:
                result.score = 1.0
        else:
            for result in results:
                result.score = (result.score - min_score) / (max_score - min_score)

        return results

    def _rrf_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using Reciprocal Rank Fusion.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their RRF scores
        doc_rrf_scores: Dict[str, float] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Calculate RRF scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            weight = source_config.weight
            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id

                # RRF score: (weight * boost) / (k + rank)
                rrf_score = (weight * boost) / (self.rrf_k + result.rank)

                if doc_id not in doc_rrf_scores:
                    doc_rrf_scores[doc_id] = rrf_score
                    doc_objects[doc_id] = result
                    # Add source metadata
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    doc_rrf_scores[doc_id] += rrf_score
                    # If document appears in multiple sources, mark it
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results
        merged_results = []
        for doc_id, rrf_score in doc_rrf_scores.items():
            result = doc_objects[doc_id]
            result.score = rrf_score
            merged_results.append(result)

        # Sort by RRF score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def _weighted_average_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using weighted average of scores.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their weighted scores
        doc_weighted_scores: Dict[str, float] = {}
        doc_score_counts: Dict[str, int] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Calculate weighted scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            # Normalize scores for this source
            if self.normalize_scores:
                results = self._normalize_scores_for_source(results)

            weight = source_config.weight
            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id
                weighted_score = result.score * weight * boost

                if doc_id not in doc_weighted_scores:
                    doc_weighted_scores[doc_id] = weighted_score
                    doc_score_counts[doc_id] = 1
                    doc_objects[doc_id] = result
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    doc_weighted_scores[doc_id] += weighted_score
                    doc_score_counts[doc_id] += 1
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results with averaged scores
        merged_results = []
        for doc_id, total_score in doc_weighted_scores.items():
            result = doc_objects[doc_id]
            # Average the weighted scores
            result.score = total_score / doc_score_counts[doc_id]
            merged_results.append(result)

        # Sort by score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def _max_score_fusion(
        self,
        source_results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge results using maximum score across sources.

        Args:
            source_results: Dict mapping source names to their results

        Returns:
            Merged and re-ranked results
        """
        # Track all unique documents and their max scores
        doc_max_scores: Dict[str, float] = {}
        doc_objects: Dict[str, SearchResult] = {}

        # Find max scores
        for source_name, results in source_results.items():
            # Get source config
            source_config = next((s for s in self.sources if s.name == source_name), None)
            if not source_config or not source_config.enabled:
                continue

            # Normalize scores for this source
            if self.normalize_scores:
                results = self._normalize_scores_for_source(results)

            boost = source_config.boost_factor

            for result in results:
                doc_id = result.document.id
                boosted_score = result.score * boost

                if doc_id not in doc_max_scores or boosted_score > doc_max_scores[doc_id]:
                    doc_max_scores[doc_id] = boosted_score
                    doc_objects[doc_id] = result
                    result.document.metadata["retrieval_source"] = source_name
                else:
                    # Update source metadata if from multiple sources
                    existing_source = doc_objects[doc_id].document.metadata.get("retrieval_source", "")
                    if source_name not in existing_source:
                        doc_objects[doc_id].document.metadata["retrieval_source"] = f"{existing_source}, {source_name}"

        # Create final results
        merged_results = []
        for doc_id, max_score in doc_max_scores.items():
            result = doc_objects[doc_id]
            result.score = max_score
            merged_results.append(result)

        # Sort by score
        merged_results.sort(key=lambda x: x.score, reverse=True)

        return merged_results

    def retrieve(
        self,
        query: RetrievalQuery,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform multi-source retrieval.

        Args:
            query: The retrieval query
            **kwargs: Additional arguments passed to each retriever

        Returns:
            List of SearchResult objects merged from all sources
        """
        # Retrieve from each source
        source_results: Dict[str, List[SearchResult]] = {}

        for source in self.sources:
            if not source.enabled:
                logger.info(f"Skipping disabled source: {source.name}")
                continue

            try:
                logger.info(f"Retrieving from source: {source.name}")
                results = source.retriever.retrieve(query, **kwargs)
                source_results[source.name] = results
                logger.info(f"Source {source.name} returned {len(results)} results")

            except Exception as e:
                logger.warning(f"Retrieval from source {source.name} failed: {e}")
                source_results[source.name] = []

        # Check if we got any results
        total_results = sum(len(results) for results in source_results.values())
        if total_results == 0:
            logger.warning("No results from any source")
            return []

        # Merge results based on fusion strategy
        logger.info(f"Merging results using {self.fusion_strategy} strategy")

        if self.fusion_strategy == "rrf":
            merged_results = self._rrf_fusion(source_results)
        elif self.fusion_strategy == "weighted_average":
            merged_results = self._weighted_average_fusion(source_results)
        elif self.fusion_strategy == "max_score":
            merged_results = self._max_score_fusion(source_results)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.fusion_strategy}")

        # Limit to top_k
        final_results = merged_results[:query.top_k]

        # Update ranks
        for i, result in enumerate(final_results):
            result.rank = i + 1

        logger.info(
            f"Multi-index retrieval complete: {len(final_results)} results "
            f"from {len(source_results)} sources"
        )

        return final_results

Retrieve from multiple indices/sources and merge results.

This retriever:

  1. Queries each enabled retriever in turn (the current implementation is sequential)
  2. Merges results from all sources
  3. Applies source-specific weights and boosts
  4. Re-ranks using RRF, weighted-average, or max-score fusion

Use cases:

  • Multi-domain search (HR + Finance + IT)
  • Federated search across multiple databases
  • Combining different data sources (documents + SQL + graph)

Example:

# Create retrievers for different sources
hr_retriever = VectorRetriever(hr_store, embedder)
finance_retriever = VectorRetriever(finance_store, embedder)
it_retriever = VectorRetriever(it_store, embedder)

# Create multi-index retriever
retriever = MultiIndexRetriever(
    sources=[
        SourceConfig("HR", hr_retriever, weight=1.0),
        SourceConfig("Finance", finance_retriever, weight=1.5),
        SourceConfig("IT", it_retriever, weight=1.0)
    ],
    fusion_strategy="rrf"
)

# Retrieve across all sources
query = RetrievalQuery(text="employee benefits policy", top_k=10)
results = retriever.retrieve(query)
MultiIndexRetriever( sources: List[SourceConfig], fusion_strategy: str = 'rrf', rrf_k: int = 60, normalize_scores: bool = True)

Initialize multi-index retriever.

Args:
    sources: List of SourceConfig objects defining retrievers
    fusion_strategy: Strategy for merging results:
        - "rrf": Reciprocal Rank Fusion
        - "weighted_average": Weighted average of normalized scores
        - "max_score": Take maximum score across sources
    rrf_k: Constant for RRF formula (default 60)
    normalize_scores: Normalize scores to [0,1] before fusion
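
To make the "rrf" strategy concrete: each source contributes weight / (rrf_k + rank) for every document it returns, and contributions for the same document are summed across sources. The sketch below works on plain document ids rather than SearchResult objects and leaves out boost factors; `rrf_fuse` is a hypothetical name used for illustration only.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def rrf_fuse(
    ranked_ids_by_source: Dict[str, List[str]],
    weights: Dict[str, float],
    rrf_k: int = 60,
) -> List[Tuple[str, float]]:
    """Weighted Reciprocal Rank Fusion over per-source ranked id lists."""
    scores: Dict[str, float] = defaultdict(float)
    for source, doc_ids in ranked_ids_by_source.items():
        for rank, doc_id in enumerate(doc_ids, start=1):
            # Rank 1 is the best hit; larger rrf_k flattens rank differences.
            scores[doc_id] += weights[source] / (rrf_k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


fused = rrf_fuse(
    {"HR": ["doc_a", "doc_b"], "Finance": ["doc_b", "doc_c"]},
    weights={"HR": 0.5, "Finance": 0.5},
)
```

In this example `doc_b` ends up first because it appears in both lists and collects a contribution from each source, even though it is never ranked first in both.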

sources
fusion_strategy
rrf_k
normalize_scores
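
The normalize_scores option corresponds to per-source min-max scaling, with identical scores all mapped to 1.0. A minimal sketch on bare floats follows; `min_max_normalize` is a hypothetical helper, and unlike the method in the class it returns a new list instead of updating results in place.

```python
from typing import List


def min_max_normalize(scores: List[float]) -> List[float]:
    """Scale scores to [0, 1]; identical scores all become 1.0."""
    if not scores:
        return []
    lowest, highest = min(scores), max(scores)
    if highest == lowest:
        # No spread to scale by, so treat every hit as equally strong.
        return [1.0 for _ in scores]
    return [(s - lowest) / (highest - lowest) for s in scores]


normalized = min_max_normalize([2.0, 4.0, 6.0])
```

Scaling each source separately makes scores from different retrievers comparable, at the cost of always assigning 0.0 to the weakest hit of each source.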
def retrieve( self, query: RetrievalQuery, **kwargs) -> List[gmf_forge_ai_data.SearchResult]:

Perform multi-source retrieval.

Args:
    query: The retrieval query
    **kwargs: Additional arguments passed to each retriever

Returns: List of SearchResult objects merged from all sources
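
One behaviour worth illustrating is that a failing source does not abort the whole retrieval: the error is logged and that source contributes an empty list. The sketch below models sources as plain callables, which is an assumption made to keep it self-contained; `collect_from_sources` and `broken_source` are hypothetical names.

```python
import logging
from typing import Callable, Dict, List

logger = logging.getLogger("multi_index_sketch")


def collect_from_sources(
    sources: Dict[str, Callable[[str], List[str]]],
    text: str,
) -> Dict[str, List[str]]:
    """Query every source; a source that raises yields an empty result list."""
    collected: Dict[str, List[str]] = {}
    for name, retrieve_fn in sources.items():
        try:
            collected[name] = retrieve_fn(text)
        except Exception as exc:
            logger.warning("Retrieval from source %s failed: %s", name, exc)
            collected[name] = []
    return collected


def broken_source(text: str) -> List[str]:
    raise RuntimeError("index offline")


collected = collect_from_sources(
    {"HR": lambda text: ["doc_a"], "IT": broken_source},
    "employee benefits policy",
)
```

The trade-off is that callers cannot tell a failed source from one that simply had no matches unless they inspect the logs.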

@dataclass
class SQLSchema:
@dataclass
class SQLSchema:
    """Represents a database schema."""
    table_name: str
    columns: List[Dict[str, str]]  # [{"name": "col1", "type": "int", "description": "..."}]
    primary_key: Optional[str] = None
    description: Optional[str] = None

Represents a database schema.

SQLSchema( table_name: str, columns: List[Dict[str, str]], primary_key: Optional[str] = None, description: Optional[str] = None)
table_name: str
columns: List[Dict[str, str]]
primary_key: Optional[str] = None
description: Optional[str] = None
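
As an illustration of the columns format, the sketch below builds a schema and renders it as text, for example to hand to a text-to-SQL step. The dataclass is a local copy of the fields listed above so the snippet runs on its own; `describe_schema` and the table contents are hypothetical and not part of the module.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SQLSchema:
    """Local copy of the documented fields, for a self-contained sketch."""
    table_name: str
    columns: List[Dict[str, str]]
    primary_key: Optional[str] = None
    description: Optional[str] = None


def describe_schema(schema: SQLSchema) -> str:
    """Hypothetical helper: render a schema as human-readable text."""
    header = f"Table {schema.table_name}"
    if schema.description:
        header += f" ({schema.description})"
    lines = [header]
    for col in schema.columns:
        line = f"- {col['name']} ({col['type']})"
        if col.get("description"):
            line += f": {col['description']}"
        if col["name"] == schema.primary_key:
            line += " [primary key]"
        lines.append(line)
    return "\n".join(lines)


employees = SQLSchema(
    table_name="employees",
    columns=[
        {"name": "id", "type": "int", "description": "Employee id"},
        {"name": "dept", "type": "text", "description": "Department name"},
    ],
    primary_key="id",
    description="Staff directory",
)
schema_text = describe_schema(employees)
```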
@dataclass
class SQLQuery:
@dataclass
class SQLQuery:
    """Represents a generated SQL query."""
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None

Represents a generated SQL query.

SQLQuery( sql: str, parameters: Dict[str, Any] = <factory>, explanation: Optional[str] = None)
sql: str
parameters: Dict[str, Any]
explanation: Optional[str] = None
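
A text-to-SQL function is called with the query text and the schema and returns a SQLQuery. The sketch below shows one possible rule-based converter that keeps user text in parameters rather than in the SQL string. The dataclasses are local copies of the documented fields; `keyword_text_to_sql`, the first-column rule, and the `:pattern` placeholder style (the one sqlite3 accepts) are assumptions for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SQLSchema:
    """Local copy of the documented fields, for a self-contained sketch."""
    table_name: str
    columns: List[Dict[str, str]]
    primary_key: Optional[str] = None
    description: Optional[str] = None


@dataclass
class SQLQuery:
    """Local copy of the documented fields, for a self-contained sketch."""
    sql: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    explanation: Optional[str] = None


def keyword_text_to_sql(text: str, schema: SQLSchema) -> SQLQuery:
    """Hypothetical converter: substring match on the schema's first column."""
    # Identifiers come from the trusted schema; user text is only bound
    # as a parameter and never interpolated into the SQL string.
    column = schema.columns[0]["name"]
    return SQLQuery(
        sql=f"SELECT * FROM {schema.table_name} WHERE {column} LIKE :pattern",
        parameters={"pattern": f"%{text}%"},
        explanation=f"Substring match on {column}",
    )


sketch_schema = SQLSchema("policies", [{"name": "title", "type": "text"}])
sketch_query = keyword_text_to_sql("benefits", sketch_schema)
```

Other database drivers use different placeholder styles, so a converter targeting "postgresql" or "mysql" would likely need to emit a different parameter syntax.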
@dataclass
class SourceConfig:
@dataclass
class SourceConfig:
    """Configuration for a retrieval source."""
    name: str
    retriever: BaseRetriever
    weight: float = 1.0
    boost_factor: float = 1.0
    enabled: bool = True

Configuration for a retrieval source.

SourceConfig( name: str, retriever: BaseRetriever, weight: float = 1.0, boost_factor: float = 1.0, enabled: bool = True)
name: str
retriever: BaseRetriever
weight: float = 1.0
boost_factor: float = 1.0
enabled: bool = True
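
MultiIndexRetriever rescales source weights at construction time by dividing by the total weight of the enabled sources. The sketch below reproduces that arithmetic on a stand-in dataclass without the retriever field. `SourceConfigSketch` and `with_normalized_weights` are hypothetical names, and the sketch makes two deliberate choices of its own: it returns new objects instead of updating the configs in place, and it leaves disabled sources untouched.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass
class SourceConfigSketch:
    """Stand-in for SourceConfig without the retriever field (assumption)."""
    name: str
    weight: float = 1.0
    boost_factor: float = 1.0
    enabled: bool = True


def with_normalized_weights(
    sources: List[SourceConfigSketch],
) -> List[SourceConfigSketch]:
    """Return copies whose enabled weights sum to 1.0."""
    total = sum(s.weight for s in sources if s.enabled)
    if total <= 0:
        return list(sources)
    return [
        replace(s, weight=s.weight / total) if s.enabled else s
        for s in sources
    ]


configs = with_normalized_weights([
    SourceConfigSketch("HR", weight=1.0),
    SourceConfigSketch("Finance", weight=1.5),
    SourceConfigSketch("IT", weight=1.0),
])
```

With the weights from the class example (1.0, 1.5, 1.0), Finance ends up with 1.5 / 3.5 of the total, so only the ratio between weights matters, not their absolute size.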