# azure_ai_search_vector_store_example.py
"""
Example: Using AzureAISearchVectorStore with custom document schemas.
This example demonstrates how to:
1. Define domain-specific documents by extending Document
2. Pass document_type to AzureAISearchVectorStore
3. Automatically get all fields indexed (universal + custom)
4. Filter on any field efficiently
All fields from your document class are automatically indexed in Azure Search.
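Required .env variables:
- AZURE_AI_SEARCH_ENDPOINT: Your Azure Search service endpoint
- AZURE_AI_SEARCH_API_KEY: Your Azure Search API key
  (not needed when AZURE_USE_MANAGED_IDENTITY=true)
- AZURE_OPENAI_ENDPOINT: Your Azure OpenAI endpoint (used to generate embeddings)
- AZURE_OPENAI_EMBEDDING_MODEL: Your Azure OpenAI embedding deployment name
- AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
  (not needed when AZURE_USE_MANAGED_IDENTITY=true)
Optional .env variables:
- SEARCH_SERVICE_NAME: Your Azure Search service name (default: default-search)
- AZURE_USE_MANAGED_IDENTITY: Set to 1/true/yes to use managed identity instead of API keys
- AZURE_OPENAI_EMBEDDING_MODEL_VERSION: Azure OpenAI API version (default: 2024-02-01)
- AZURE_AI_SEARCH_EMBEDDING_DIMENSION: Embedding dimension (default: 1536)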
"""
import os
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from gmf_forge_ai_data.vector_stores import Document, AzureAISearchVectorStore
from gmf_forge_ai_data.indexing import AzureAISearchIndexBuilder
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.chunkers import RecursiveChunker
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
# Load environment variables from the .env file that sits next to this script.
env_path = Path(__file__).parent / '.env'
load_dotenv(env_path)
# Module-level logger shared by every function in this file.
logger = BasicLogger(__name__)
# Corporate SSL certificate path (for corporate networks with SSL inspection).
# WORKSPACE_ROOT is three directories above the one that contains this script;
# the bundle is only used when the file actually exists (see get_ssl_cert_path).
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# ============================================================================
# Configuration from .env
# ============================================================================
def get_azure_search_config() -> dict:
    """Load Azure Search configuration from environment variables.

    Authentication is selected by the ``AZURE_USE_MANAGED_IDENTITY`` env var:

    * **API key** (default): set ``AZURE_AI_SEARCH_API_KEY``.
    * **Managed identity**: set ``AZURE_USE_MANAGED_IDENTITY=true``.
      The token provider must request the **Azure AI Search** scope::

          https://search.azure.com/.default

    Note: Azure OpenAI / Cognitive Services uses a *different* scope
    (``https://cognitiveservices.azure.com/.default``) — each service
    requires its own token provider.

    Returns:
        A dict with the keys ``endpoint``, ``api_key``, ``token_provider``,
        ``service_name``, ``embedding_dimension`` and ``configured``.
        When the endpoint or the credentials are missing, placeholder
        values are returned and ``configured`` is ``False``.

    Raises:
        ValueError: If ``AZURE_AI_SEARCH_EMBEDDING_DIMENSION`` is set to a
            value that is not an integer.
    """
    endpoint = os.getenv('AZURE_AI_SEARCH_ENDPOINT')
    api_key = os.getenv('AZURE_AI_SEARCH_API_KEY')
    service_name = os.getenv('SEARCH_SERVICE_NAME')

    # A blank entry in .env (``AZURE_AI_SEARCH_EMBEDDING_DIMENSION=``) yields
    # an empty string rather than None, so the getenv default never applies.
    # Treat blank as "not set" instead of crashing inside int('').
    raw_dimension = (os.getenv('AZURE_AI_SEARCH_EMBEDDING_DIMENSION') or '').strip()
    try:
        embedding_dim = int(raw_dimension) if raw_dimension else 1536
    except ValueError as err:
        raise ValueError(
            "AZURE_AI_SEARCH_EMBEDDING_DIMENSION must be an integer, "
            f"got {raw_dimension!r}"
        ) from err

    use_managed_identity = os.getenv('AZURE_USE_MANAGED_IDENTITY', '').lower() in ('1', 'true', 'yes')

    if not endpoint or (not api_key and not use_managed_identity):
        logger.warning(
            "Azure Search credentials not found in .env file",
            missing=["AZURE_AI_SEARCH_ENDPOINT",
                     "AZURE_AI_SEARCH_API_KEY (or set AZURE_USE_MANAGED_IDENTITY=true)"],
            note="Using placeholder values for demonstration",
        )
        return {
            'endpoint': 'https://your-search.search.windows.net',
            'api_key': 'your-api-key',
            'token_provider': None,
            'service_name': 'your-search-service',
            'embedding_dimension': embedding_dim,
            'configured': False
        }

    token_provider = None
    if use_managed_identity:
        # Managed identity — scope: https://search.azure.com/.default
        # Imported lazily so azure-identity is only needed on this path.
        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            'https://search.azure.com/.default',
        )
        # Managed identity wins: drop any API key that is also set.
        api_key = None

    return {
        'endpoint': endpoint,
        'api_key': api_key,
        'token_provider': token_provider,
        'service_name': service_name or 'default-search',
        'embedding_dimension': embedding_dim,
        'configured': True
    }
def get_ssl_cert_path() -> Optional[str]:
    """Return the corporate CA bundle path, or ``None`` when it is absent.

    Returns:
        ``str(CORPORATE_CERT)`` if that path exists on disk, otherwise
        ``None`` (the annotation previously claimed ``str`` only).
    """
    return str(CORPORATE_CERT) if CORPORATE_CERT.exists() else None
def _ensure_index(config: dict, index_name: str, document_type) -> None:
    """Provision the Azure AI Search index using AzureAISearchIndexBuilder.

    Args:
        config: Settings dict as returned by ``get_azure_search_config``;
            ``endpoint`` and ``embedding_dimension`` are required keys,
            ``api_key`` and ``token_provider`` are optional.
        index_name: Name of the index to provision.
        document_type: Document class handed to the builder as
            ``document_type``.
    """
    builder_kwargs = {
        'endpoint': config['endpoint'],
        'api_key': config.get('api_key'),
        'token_provider': config.get('token_provider'),
        'index_name': index_name,
        'embedding_dimension': config['embedding_dimension'],
        'document_type': document_type,
        'ssl_cert_path': get_ssl_cert_path(),
    }
    index_builder = AzureAISearchIndexBuilder(**builder_kwargs)
    index_builder.create_index()
def get_embedder():
    """Build an ``AzureOpenAIEmbeddings`` client from environment variables.

    The ``AZURE_USE_MANAGED_IDENTITY`` env var picks the authentication mode:

    * **API key** (default): ``AZURE_OPENAI_API_KEY`` is passed as ``api_key``.
    * **Managed identity** (``AZURE_USE_MANAGED_IDENTITY=true``): a bearer
      token provider for the **Cognitive Services** scope is passed as
      ``token_provider``::

          https://cognitiveservices.azure.com/.default

    Azure AI Search needs a *different* scope
    (``https://search.azure.com/.default``), so the two services cannot
    share one token provider.

    Returns:
        The initialized embedder, or ``None`` when required settings are
        missing or construction raises (both cases are logged as warnings).
    """
    endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
    api_key = os.getenv('AZURE_OPENAI_API_KEY')
    deployment = os.getenv('AZURE_OPENAI_EMBEDDING_MODEL')
    api_version = os.getenv('AZURE_OPENAI_EMBEDDING_MODEL_VERSION', '2024-02-01')
    use_managed_identity = os.getenv('AZURE_USE_MANAGED_IDENTITY', '').lower() in ('1', 'true', 'yes')

    has_credentials = bool(api_key) or use_managed_identity
    if not (endpoint and deployment and has_credentials):
        logger.warning("Azure OpenAI credentials not found in .env file",
                       missing=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_EMBEDDING_MODEL",
                                "AZURE_OPENAI_API_KEY (or set AZURE_USE_MANAGED_IDENTITY=true)"])
        return None

    ssl_cert = get_ssl_cert_path()
    if ssl_cert:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    try:
        # Only the credential argument differs between the two auth modes.
        if use_managed_identity:
            # Managed identity — scope: https://cognitiveservices.azure.com/.default
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
            auth_kwargs = {
                'token_provider': get_bearer_token_provider(
                    DefaultAzureCredential(),
                    'https://cognitiveservices.azure.com/.default',
                ),
            }
        else:
            auth_kwargs = {'api_key': api_key}

        embedder = AzureOpenAIEmbeddings(
            endpoint=endpoint,
            deployment_name=deployment,
            api_version=api_version,
            ssl_cert_path=ssl_cert,
            **auth_kwargs,
        )
        logger.info("Initialized embeddings", deployment=deployment)
        return embedder
    except Exception as e:
        # Best-effort by design: the examples skip themselves without an embedder.
        logger.warning("Failed to initialize embedder", error=str(e))
        return None
# ============================================================================
# Example 1: Legal Document with Custom Fields
# ============================================================================
@dataclass
class LegalDocument(Document):
    """Legal document with case-specific fields.

    Extends ``Document`` with the custom fields below. Every field has a
    default, so instances can be built with keyword arguments only, as
    ``example_legal_documents`` does. Sample values quoted in the comments
    come from that example.
    """
    case_number: str = ""  # e.g. "2024-CV-12345"
    court: str = ""  # e.g. "Supreme Court"
    jurisdiction: str = ""  # e.g. "Federal" or "State"; used as a search filter in the example
    decision_date: Optional[datetime] = None  # e.g. datetime(2024, 1, 15)
    case_type: str = ""  # e.g. "civil" or "criminal"
    source: str = ""  # e.g. "court_records.pdf"
    page_number: Optional[int] = None  # presumably the page within `source` — confirm
def example_legal_documents():
    """Legal document vector store with automatic field indexing.

    Creates the ``legal_documents`` store and index for ``LegalDocument``,
    embeds and uploads three sample cases, runs a filtered vector search and
    fetches one document by id.

    When Azure Search is not configured, only the schema and the supported
    filter syntax are logged. When no embedder is available the function
    returns before the store or index is created.
    """
    logger.info("Starting example", example="Legal Documents")

    universal_fields = ["id", "content", "embedding", "timestamp"]
    custom_fields = ["case_number", "court", "jurisdiction", "decision_date", "case_type", "source", "page_number"]
    filter_syntax = {
        "equality": "{'jurisdiction': 'Federal'}",
        "date_range": "{'decision_date': ('>=', datetime(2024, 1, 1))}",
        "multiple": "{'jurisdiction': 'Federal', 'case_type': 'civil'}",
    }

    # Load config from .env
    config = get_azure_search_config()
    if not config['configured']:
        logger.info("Skipping Azure Search operations - no credentials", note="Showing schema and usage examples only")
        # Deliver what the message above promises before bailing out.
        logger.info("LegalDocument schema",
                    universal_fields=universal_fields,
                    custom_fields=custom_fields)
        logger.info("Supported filter syntax", **filter_syntax)
        return

    # Initialize embedder
    embedder = get_embedder()
    if not embedder:
        logger.info("Skipping document indexing - no embedder available")
        return

    # Create store with custom document type
    # ALL fields from LegalDocument are automatically indexed
    store = AzureAISearchVectorStore(
        endpoint=config['endpoint'],
        index_name="legal_documents",
        api_key=config.get('api_key'),
        token_provider=config.get('token_provider'),
        embedding_dimension=config['embedding_dimension'],
        document_type=LegalDocument,  # <-- Pass your document type
    )
    _ensure_index(config, "legal_documents", LegalDocument)
    logger.info("Azure Search index created",
                universal_fields=universal_fields,
                custom_fields=custom_fields,
                note="All fields are indexed and filterable",
                )

    # Create sample legal documents
    docs = [
        LegalDocument(
            id="case_001",
            content="In the matter of Smith v. Jones, the court finds...",
            case_number="2024-CV-12345",
            court="Supreme Court",
            jurisdiction="Federal",
            decision_date=datetime(2024, 1, 15),
            case_type="civil",
            source="court_records.pdf",
            page_number=42,
            timestamp=datetime.now()
        ),
        LegalDocument(
            id="case_002",
            content="United States v. Johnson, regarding antitrust violations...",
            case_number="2024-CR-67890",
            court="District Court",
            jurisdiction="Federal",
            decision_date=datetime(2024, 2, 20),
            case_type="criminal",
            source="court_records.pdf",
            page_number=105,
            timestamp=datetime.now()
        ),
        LegalDocument(
            id="case_003",
            content="State of California v. Tech Corp, patent infringement...",
            case_number="2024-CV-11111",
            court="State Supreme Court",
            jurisdiction="State",
            decision_date=datetime(2024, 3, 10),
            case_type="civil",
            source="state_records.pdf",
            page_number=1,
            timestamp=datetime.now()
        )
    ]

    # Generate embeddings for all documents
    logger.info("Generating embeddings")
    for doc in docs:
        doc.embedding = embedder.embed_text(doc.content)
    logger.info("Generated embeddings", count=len(docs))

    # Add documents — embeddings already generated above
    store.add_documents(docs)
    logger.info("Added documents to index", count=len(docs))
    # Same wait as example_financial_documents: without it the search below
    # can run before the freshly uploaded documents are searchable.
    time.sleep(2)  # allow Azure Search to index before searching

    # ── Vector search with filter ─────────────────────────────────────────
    query = "patent infringement technology"
    logger.info("Vector search with filter", query=query, filter="jurisdiction=Federal")
    query_embedding = embedder.embed_text(query)
    results = store.search(
        query_embedding=query_embedding,
        top_k=3,
        filters={"jurisdiction": "Federal"},
        search_type="vector",
    )
    for r in results:
        # Distinct name: `doc` is already bound by the embedding loop above.
        hit: LegalDocument = r.document
        logger.info("Result", rank=r.rank, id=hit.id, score=round(r.score, 4),
                    case_number=hit.case_number, court=hit.court, jurisdiction=hit.jurisdiction)

    # ── get_document ──────────────────────────────────────────────────────
    fetched = store.get_document("case_001")
    if fetched:
        logger.info("get_document", id=fetched.id, case_number=fetched.case_number,
                    court=fetched.court, case_type=fetched.case_type)

    logger.info("Supported filter syntax", **filter_syntax)
# ============================================================================
# Example 2: Product Catalog
# ============================================================================
@dataclass
class ProductDocument(Document):
    """Product document for e-commerce.

    Extends ``Document`` with catalog fields. Every field has a default.
    Sample values quoted in the comments come from
    ``example_product_catalog``.
    """
    sku: str = ""  # e.g. "WH-1000XM5"
    category: str = ""  # e.g. "Electronics"
    price: float = 0.0  # e.g. 399.99; currency is not specified in this file
    in_stock: bool = True  # note the default is True, not False
    brand: str = ""  # e.g. "Sony"
    rating: float = 0.0  # e.g. 4.8; scale is not specified in this file
    review_count: int = 0  # e.g. 1250
def example_product_catalog():
    """Product catalog with automatic field indexing.

    Creates the ``products`` store and index for ``ProductDocument``, embeds
    and uploads two sample products, runs a keyword search and logs the
    document count. Returns early when Azure Search is not configured or no
    embedder is available.
    """
    logger.info("Starting example", example="Product Catalog")

    # Load config from .env
    config = get_azure_search_config()
    if not config['configured']:
        return

    # Initialize embedder
    embedder = get_embedder()
    if not embedder:
        return

    # Create store - all ProductDocument fields automatically indexed
    store = AzureAISearchVectorStore(
        endpoint=config['endpoint'],
        index_name="products",
        api_key=config.get('api_key'),
        token_provider=config.get('token_provider'),
        embedding_dimension=config['embedding_dimension'],
        document_type=ProductDocument,  # <-- Your custom type
    )
    _ensure_index(config, "products", ProductDocument)
    logger.info("Azure Search index created",
                universal_fields=["id", "content", "embedding", "timestamp"],
                custom_fields=["sku", "category", "price", "in_stock", "brand", "rating", "review_count"],
                )

    # Sample products
    products = [
        ProductDocument(
            id="prod_001",
            content="Premium wireless headphones with noise cancellation...",
            sku="WH-1000XM5",
            category="Electronics",
            price=399.99,
            in_stock=True,
            brand="Sony",
            rating=4.8,
            review_count=1250,
            timestamp=datetime.now()
        ),
        ProductDocument(
            id="prod_002",
            content="Professional DSLR camera with 24MP sensor...",
            sku="EOS-R5",
            category="Electronics",
            price=3899.00,
            in_stock=False,
            brand="Canon",
            rating=4.9,
            review_count=890,
            timestamp=datetime.now()
        )
    ]

    # Generate embeddings for products
    logger.info("Generating embeddings")
    for product in products:
        product.embedding = embedder.embed_text(product.content)
    logger.info("Generated embeddings", count=len(products))

    # Add documents — embeddings already generated above
    store.add_documents(products)
    logger.info("Added products to index", count=len(products))
    # Same wait as example_financial_documents: without it the search and the
    # count below can run before the uploaded products are visible.
    time.sleep(2)  # allow Azure Search to index before searching

    # ── Keyword search ────────────────────────────────────────────────────
    query = "wireless headphones noise cancellation"
    logger.info("Keyword search", query=query, search_type="keyword")
    results = store.search(query=query, top_k=3, search_type="keyword")
    for r in results:
        prod: ProductDocument = r.document
        logger.info("Result", rank=r.rank, id=prod.id, score=round(r.score, 4),
                    sku=prod.sku, brand=prod.brand, price=prod.price)

    # ── count ─────────────────────────────────────────────────────────────
    logger.info("Document count", total=store.count())

    logger.info("Supported filter syntax",
                numeric_range="{'price': ('<=', 500)}",
                boolean="{'in_stock': True}",
                combined="{'category': 'Electronics', 'in_stock': True, 'rating': ('>=', 4.5)}")
# ============================================================================
# Example 3: Financial Documents
# ============================================================================
@dataclass
class FinancialDocument(Document):
    """Financial report document.

    Extends ``Document`` with filing metadata. Every field has a default.
    Sample values quoted in the comments come from
    ``example_financial_documents``.

    Note that the ``document_type`` field here holds the filing form; it is
    unrelated to the ``document_type=`` argument passed to the vector store
    and index builder, which receives the class itself.
    """
    document_type: str = ""  # 10-K, 10-Q, 8-K, etc.
    fiscal_year: int = 0  # e.g. 2024
    quarter: str = ""  # Q1, Q2, Q3, Q4, FY
    company_ticker: str = ""  # e.g. "MSFT"
    sector: str = ""  # e.g. "Technology"; used as a search filter in the example
    report_date: Optional[datetime] = None  # e.g. datetime(2024, 1, 30)
    source: str = ""  # e.g. "sec_filings.pdf"
    page_number: Optional[int] = None  # presumably the page within `source` — confirm
def example_financial_documents():
    """Financial documents with automatic field indexing.

    Provisions the ``financial_reports`` store and index for
    ``FinancialDocument``, uploads two embedded sample filings, runs a
    filtered hybrid search, then exercises get / update / count / delete.
    Returns early when Azure Search is not configured or no embedder is
    available.
    """
    logger.info("Starting example", example="Financial Documents")

    # Settings come from .env; without them there is nothing to demonstrate.
    settings = get_azure_search_config()
    if not settings['configured']:
        return

    # The embedder is needed for both the documents and the query.
    embedding_client = get_embedder()
    if not embedding_client:
        return

    index_name = "financial_reports"
    report_store = AzureAISearchVectorStore(
        endpoint=settings['endpoint'],
        index_name=index_name,
        api_key=settings.get('api_key'),
        token_provider=settings.get('token_provider'),
        embedding_dimension=settings['embedding_dimension'],
        document_type=FinancialDocument,
    )
    _ensure_index(settings, index_name, FinancialDocument)
    logger.info(
        "Azure Search index created",
        universal_fields=["id", "content", "embedding", "timestamp"],
        custom_fields=["document_type", "fiscal_year", "quarter", "company_ticker", "sector", "report_date", "source", "page_number"],
        note="All fields are indexed and filterable",
    )

    # Two sample filings: one quarterly, one annual.
    quarterly_report = FinancialDocument(
        id="fin_001",
        content="Microsoft Q1 2024 earnings report shows strong cloud growth...",
        document_type="10-Q",
        fiscal_year=2024,
        quarter="Q1",
        company_ticker="MSFT",
        sector="Technology",
        report_date=datetime(2024, 1, 30),
        source="sec_filings.pdf",
        page_number=1,
        timestamp=datetime.now(),
    )
    annual_report = FinancialDocument(
        id="fin_002",
        content="Apple annual 10-K filing for fiscal year 2024...",
        document_type="10-K",
        fiscal_year=2024,
        quarter="FY",
        company_ticker="AAPL",
        sector="Technology",
        report_date=datetime(2024, 10, 28),
        source="sec_filings.pdf",
        page_number=15,
        timestamp=datetime.now(),
    )
    reports = [quarterly_report, annual_report]

    # Attach an embedding to every filing before upload.
    logger.info("Generating embeddings")
    for report in reports:
        report.embedding = embedding_client.embed_text(report.content)
    logger.info("Generated embeddings", count=len(reports))

    # Upload; the store receives documents that already carry embeddings.
    report_store.add_documents(reports)
    logger.info("Added financial documents to index", count=len(reports))
    time.sleep(2)  # give Azure Search time to index before querying

    # ── Hybrid search with filter ─────────────────────────────────────────
    search_text = "quarterly earnings cloud growth"
    logger.info(
        "Hybrid search with filter",
        query=search_text,
        filter="sector=Technology",
        search_type="hybrid",
    )
    search_vector = embedding_client.embed_text(search_text)
    hits = report_store.search(
        query=search_text,
        query_embedding=search_vector,
        top_k=3,
        filters={"sector": "Technology"},
        search_type="hybrid",
    )
    for hit in hits:
        matched: FinancialDocument = hit.document
        logger.info(
            "Result",
            rank=hit.rank,
            id=matched.id,
            score=round(hit.score, 4),
            document_type=matched.document_type,
            company_ticker=matched.company_ticker,
        )

    # ── CRUD operations ───────────────────────────────────────────────────
    logger.info("CRUD demonstration")

    stored = report_store.get_document("fin_001")
    if stored:
        logger.info(
            "get_document",
            id=stored.id,
            company_ticker=stored.company_ticker,
            quarter=stored.quarter,
            sector=stored.sector,
        )

    review_metadata = {"reviewed": True, "reviewer": "compliance_team"}
    update_ok = report_store.update_document("fin_001", metadata=review_metadata)
    logger.info("update_document", id="fin_001", success=update_ok)

    logger.info("count", total=report_store.count())

    removed = report_store.delete_documents(["fin_002"])
    time.sleep(2)  # index updates are eventually consistent; wait before counting
    logger.info("delete_documents", deleted=removed, remaining=report_store.count())
# ============================================================================
# Document type for AI/ML Knowledge Base
# ============================================================================
@dataclass
class AIMLDocument(Document):
    """AI/ML knowledge base document.

    Extends ``Document`` with two classification fields. Used as the
    ``document_type`` of the ``ai_ml_knowledge`` index built by
    ``example_ai_ml_knowledge_base``.
    """
    topic: str = ""  # Specific topic like "deep_learning", "CNNs", "transformers"
    category: str = ""  # Broader category like "ML", "NLP", "CV"
# ============================================================================
# Example 4: AI/ML Knowledge Base (for retrieval examples)
# ============================================================================
def example_ai_ml_knowledge_base():
    """Create AI/ML knowledge base for retrieval strategy demonstrations.

    Provisions the ``ai_ml_knowledge`` store and index for ``AIMLDocument``,
    then embeds and uploads eight multi-sentence paragraphs. The closing
    summary log (document count and topics per category) is derived from the
    data itself so it cannot drift from ``documents_data``. Returns early when
    Azure Search is not configured or no embedder is available.
    """
    logger.info("Starting example", example="AI/ML Knowledge Base")

    # Load config from .env
    config = get_azure_search_config()
    if not config['configured']:
        return

    # Initialize embedder
    embedder = get_embedder()
    if not embedder:
        return

    store = AzureAISearchVectorStore(
        endpoint=config['endpoint'],
        index_name="ai_ml_knowledge",
        api_key=config.get('api_key'),
        token_provider=config.get('token_provider'),
        embedding_dimension=config['embedding_dimension'],
        document_type=AIMLDocument,
    )
    _ensure_index(config, "ai_ml_knowledge", AIMLDocument)
    logger.info("Azure Search index created",
                universal_fields=["id", "content", "embedding", "timestamp"],
                custom_fields=["topic", "category"],
                )

    # Create AI/ML knowledge base documents.
    # Each document is a multi-sentence paragraph so that context processing
    # examples (reranking, compression, windowing) have realistic content to
    # work with. Sentences within a chunk intentionally vary in relevance to
    # the shared query "How do neural networks learn and what role does
    # backpropagation play?" — this lets ContextCompressor demonstrate
    # meaningful sentence-level extraction and token reduction.
    documents_data = [
        {
            "id": "doc_001",
            "content": (
                "Machine learning is a subset of artificial intelligence that enables systems "
                "to learn and improve automatically from experience without being explicitly programmed. "
                "Rather than relying on hand-crafted rules, ML models generalise patterns from training "
                "data to make predictions on new, unseen inputs. "
                "Supervised learning trains models on labelled examples by minimising a loss function "
                "that measures the gap between predictions and ground truth. "
                "This optimisation process relies on computing gradients — the direction and magnitude "
                "in which to adjust model parameters — making calculus a foundational tool in machine learning. "
                "Once training is complete, the model's parameters remain fixed and predictions are made "
                "through a single forward pass."
            ),
            "category": "ML",
            "topic": "introduction"
        },
        {
            "id": "doc_002",
            "content": (
                "Deep learning uses neural networks with multiple stacked layers to learn hierarchical "
                "representations of data directly from raw inputs. "
                "A neural network learns by adjusting its weights through backpropagation, an algorithm "
                "that computes the gradient of the loss with respect to every weight using the chain rule "
                "of calculus. "
                "During forward propagation, each layer transforms its input through a weighted sum "
                "followed by a non-linear activation function, producing the final prediction. "
                "Backward propagation reverses this process: the loss gradient flows from the output layer "
                "back through each layer, enabling the network to attribute responsibility for the "
                "prediction error to each individual weight. "
                "Gradient descent then uses these gradients to nudge every weight in the direction that "
                "reduces the overall loss, and this cycle repeats over many mini-batches until convergence."
            ),
            "category": "ML",
            "topic": "deep_learning"
        },
        {
            "id": "doc_003",
            "content": (
                "Natural language processing (NLP) allows computers to understand, interpret, and generate "
                "human language in a meaningful way. "
                "It combines rule-based approaches with statistical and neural techniques to tackle tasks "
                "such as text classification, machine translation, and question answering. "
                "Words are typically represented as dense vectors — word embeddings — that encode semantic "
                "relationships so that similar words appear close to each other in the vector space. "
                "Neural language models are pre-trained on massive text corpora and learn to predict masked "
                "or next tokens, developing rich internal representations of grammar, facts, and reasoning. "
                "Fine-tuning adapts these pre-trained weights to specific downstream tasks using small "
                "task-specific labelled datasets."
            ),
            "category": "NLP",
            "topic": "introduction"
        },
        {
            "id": "doc_004",
            "content": (
                "The transformer architecture replaced recurrent networks as the dominant model for NLP "
                "by processing entire sequences in parallel using self-attention. "
                "Self-attention computes a weighted combination of all token representations simultaneously, "
                "allowing each position to directly attend to any other without the vanishing-gradient "
                "bottleneck of sequential processing. "
                "Unlike recurrent networks that propagate gradients through time steps, transformers propagate "
                "gradients through attention and feed-forward sub-layers, making training of very deep models "
                "more stable and allowing better gradient flow during backpropagation. "
                "Pre-trained models such as BERT, GPT, and T5 distil knowledge from billions of text tokens "
                "into their weights through iterative gradient-based updates. "
                "Positional encodings inject sequence-order information into the attention computation, "
                "since the self-attention mechanism itself is inherently order-invariant."
            ),
            "category": "NLP",
            "topic": "transformers"
        },
        {
            "id": "doc_005",
            "content": (
                "Optimising a machine learning model means finding the weight values that minimise a chosen "
                "loss function on the training data. "
                "Gradient descent is the workhorse algorithm: it iteratively moves the parameters in the "
                "direction of the negative gradient, taking small steps controlled by a learning rate. "
                "In neural networks, these gradients are efficiently calculated end-to-end by backpropagation, "
                "which applies the chain rule to propagate error signals layer by layer from the output "
                "back to the earliest weights. "
                "Adaptive optimisers such as Adam and RMSProp maintain per-parameter learning rates that "
                "adjust based on the history of gradient magnitudes, often converging faster than vanilla SGD. "
                "Learning rate schedules including warm-up phases and cosine annealing further stabilise "
                "training of large models."
            ),
            "category": "ML",
            "topic": "optimisation"
        },
        {
            "id": "doc_006",
            "content": (
                "Computer vision is the field of AI that enables machines to extract meaning from images, "
                "videos, and other visual inputs. "
                "Convolutional neural networks (CNNs) became the standard for vision tasks after "
                "demonstrating breakthrough performance on the ImageNet benchmark in 2012. "
                "A CNN stacks convolutional layers — which detect local spatial patterns using learnable "
                "filters — with pooling layers that progressively reduce spatial resolution. "
                "CNNs are trained end-to-end with backpropagation: gradients flow from the classification "
                "loss back through pooling and convolutional layers, updating the filter weights so the "
                "network learns to recognise task-relevant visual features. "
                "Modern architectures like ResNet use residual connections to allow gradients to bypass "
                "individual layers, enabling stable training of networks hundreds of layers deep."
            ),
            "category": "CV",
            "topic": "CNNs"
        },
        {
            "id": "doc_007",
            "content": (
                "Activation functions introduce the non-linearities that allow neural networks to model "
                "complex, high-dimensional relationships beyond what a linear model can capture. "
                "Early networks used sigmoid and tanh activations, but both suffer from vanishing gradients: "
                "when inputs fall in the saturated tails the derivative approaches zero and backpropagation "
                "cannot effectively update the earlier layers of the network. "
                "The Rectified Linear Unit (ReLU) addresses this by having a constant gradient of 1 for "
                "positive inputs, greatly accelerating training in deep architectures. "
                "Variants such as Leaky ReLU, GELU, and Swish make subtle adjustments to improve gradient "
                "flow in specific model families. "
                "Choosing an appropriate activation function is therefore a practical consideration for "
                "ensuring that backpropagation can reliably train all layers of a deep network."
            ),
            "category": "ML",
            "topic": "activation_functions"
        },
        {
            "id": "doc_008",
            "content": (
                "Reinforcement learning (RL) trains an agent to take actions in an environment by "
                "maximising a cumulative reward signal over time, rather than learning from labelled "
                "input–output pairs. "
                "The agent interacts episodically with the environment, observing states, selecting actions "
                "according to its current policy, and receiving scalar reward or penalty feedback. "
                "Deep reinforcement learning uses neural networks as function approximators for the policy "
                "or value function; these networks are updated through gradient-based methods derived from "
                "the Bellman equation, with backpropagation computing the necessary weight gradients. "
                "Stabilisation techniques such as experience replay and target networks are critical because "
                "RL training signals are non-stationary and temporally correlated, making naive gradient "
                "updates unstable. "
                "Prominent algorithms including DQN, PPO, and SAC have achieved superhuman performance on "
                "games and continuous control benchmarks using these deep RL principles."
            ),
            "category": "ML",
            "topic": "reinforcement"
        },
    ]

    logger.info("Creating AI/ML knowledge base documents")
    documents = [
        AIMLDocument(
            id=entry["id"],
            content=entry["content"],
            category=entry["category"],
            topic=entry["topic"],
            timestamp=datetime.now(),
        )
        for entry in documents_data
    ]

    # Generate embeddings
    logger.info("Generating embeddings")
    for doc in documents:
        doc.embedding = embedder.embed_text(doc.content)
    logger.info("Generated embeddings", count=len(documents))

    # Add documents — embeddings already generated above
    store.add_documents(documents)
    logger.info("Added AI/ML documents to index", count=len(documents))

    # Group topics by category (in data order) so the summary below always
    # matches documents_data instead of repeating hand-maintained numbers.
    topics_by_category = {}
    for entry in documents_data:
        topics_by_category.setdefault(entry["category"], []).append(entry["topic"])
    ml_topics = topics_by_category.get("ML", [])
    nlp_topics = topics_by_category.get("NLP", [])
    cv_topics = topics_by_category.get("CV", [])

    logger.info("Knowledge base created",
                ml_docs=len(ml_topics), ml_topics=ml_topics,
                nlp_docs=len(nlp_topics), nlp_topics=nlp_topics,
                cv_docs=len(cv_topics), cv_topics=cv_topics,
                note="Each document is a multi-sentence paragraph with mixed relevance to neural network learning and backpropagation",
                )
    logger.info("Index used by downstream examples",
                retrieval_example="Vector, Keyword, Hybrid, MMR, Ensemble, Parent Document, Hierarchical, Graph retrievers",
                query_example="QueryExpansion, HyDE, SubQuestion, StepBack",
                context_example="RelevanceFilter, ContextDeduplicator, ContextReranker, ContextCompressor, ContextWindowManager",
                )
# ============================================================================
# Document type for the RAG Pipeline example
# ============================================================================
@dataclass
class KnowledgeBaseArticle(Document):
    """Knowledge base article with domain classification and versioning fields.

    Extends ``Document`` with three fields, all defaulted. Used as the
    ``document_type`` of the ``rag_pipeline_demo`` index in
    ``example_rag_pipeline``.
    """
    domain: str = ""  # e.g. "ai", "security", "infrastructure"
    language: str = "en"  # ISO 639-1 language code
    version: str = "1.0"  # document version for change tracking
# ============================================================================
# Example 5: End-to-End RAG Pipeline (Chunk → Embed → Index → Retrieve)
# ============================================================================
def example_rag_pipeline():
    """Run the end-to-end RAG demo: chunk → embed → index → retrieve.

    Three in-memory knowledge base articles are split with
    ``RecursiveChunker``, embedded one article at a time via
    ``embed_batch``, written to the ``rag_pipeline_demo`` index as
    ``KnowledgeBaseArticle`` documents, and then searched with one query.
    Each stage runs inside its own tracing span.

    Returns ``None`` early when the Azure Search configuration is not
    flagged as configured, or when no embedder could be created.
    """
    index_name = "rag_pipeline_demo"
    chunk_size = 400
    chunk_overlap = 50
    top_k = 3

    logger.info("Starting example", example="End-to-End RAG Pipeline")
    logger.info(
        "Scenario",
        description="Ingest multiple knowledge base articles through the full RAG pipeline",
        pipeline="articles → RecursiveChunker → embed_batch → add_documents(KnowledgeBaseArticle) → vector_search",
        note="Each article carries its own domain, language, and version — enabling filtered retrieval",
    )

    config = get_azure_search_config()
    if not config['configured']:
        return
    embedder = get_embedder()
    if not embedder:
        return

    store = AzureAISearchVectorStore(
        endpoint=config['endpoint'],
        index_name=index_name,
        api_key=config.get('api_key'),
        token_provider=config.get('token_provider'),
        embedding_dimension=config['embedding_dimension'],
        document_type=KnowledgeBaseArticle,
    )
    _ensure_index(config, index_name, KnowledgeBaseArticle)
    logger.info(
        "Azure Search index schema",
        document_type="KnowledgeBaseArticle",
        custom_fields=["domain", "language", "version"],
        note="Each chunk carries these fields — enables filtering by domain, language, or version",
    )

    # Source material: every article brings its own domain / language / version.
    articles = [
        {
            "domain": "ai",
            "language": "en",
            "version": "2.0",
            "text": (
                "Retrieval-augmented generation (RAG) is a technique that enhances large language "
                "model responses by grounding them in factual, up-to-date information retrieved "
                "from an external knowledge base. Rather than relying solely on the information "
                "encoded in model weights during training, a RAG system retrieves relevant passages "
                "at inference time and supplies them as context to the model, dramatically reducing "
                "hallucinations and enabling answers to reflect recent events. The ingestion pipeline "
                "is the foundation of every RAG system: raw documents are loaded by a connector, "
                "split into overlapping chunks by a chunker, converted into dense embedding vectors, "
                "and stored in a vector database alongside the original text and metadata."
            ),
        },
        {
            "domain": "security",
            "language": "en",
            "version": "1.3",
            "text": (
                "Zero-trust architecture is a security model that assumes no implicit trust for any "
                "user, device, or network segment. Every access request is verified, validated, and "
                "encrypted before being granted, regardless of whether the request originates from "
                "inside or outside the corporate network. Micro-segmentation divides the network into "
                "small zones, each requiring separate authentication. Continuous monitoring of user "
                "behaviour and device posture helps detect anomalies and respond to potential breaches "
                "in real time. Implementing zero-trust requires identity-aware proxies, multi-factor "
                "authentication, and robust logging across all access points."
            ),
        },
        {
            "domain": "infrastructure",
            "language": "en",
            "version": "1.0",
            "text": (
                "Infrastructure as Code (IaC) lets teams manage cloud resources through declarative "
                "configuration files instead of manual provisioning. Tools such as Terraform, Bicep, "
                "and Pulumi enable version-controlled, repeatable deployments across development, staging, "
                "and production environments. Drift detection compares the desired state in code against "
                "the actual state in the cloud provider, flagging any manual changes that may have been "
                "introduced. Modular IaC patterns encourage reusable modules that encapsulate best "
                "practices for networking, compute, and storage, reducing duplication and enforcing "
                "organisational standards at scale."
            ),
        },
    ]
    logger.info(
        "Source articles",
        count=len(articles),
        domains=[entry["domain"] for entry in articles],
    )

    query = "How does RAG reduce hallucinations?"
    tracer = get_tracer()
    with tracer.trace("rag_pipeline", input=query) as trace:
        # Stage 1: split every article into overlapping chunks.
        with trace.span("chunking") as span:
            logger.info(
                "Step 1 - Chunking",
                strategy="RecursiveChunker",
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            chunker = RecursiveChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
            chunked_articles: list = []
            for entry in articles:
                pieces = chunker.chunk(
                    entry["text"],
                    metadata={"source": f"kb_{entry['domain']}", "domain": entry["domain"]},
                )
                total_chars = sum(len(piece.text) for piece in pieces)
                logger.info(
                    "Chunked article",
                    domain=entry["domain"],
                    num_chunks=len(pieces),
                    # max(..., 1) keeps the average defined for an article with no chunks
                    avg_chars=total_chars // max(len(pieces), 1),
                )
                chunked_articles.append((entry, pieces))
            span.set_output({
                "documents": len(articles),
                "total_chunks": sum(len(pieces) for _, pieces in chunked_articles),
            })

        # Stage 2: embed each article's chunks in one batch and build typed documents.
        with trace.span("embedding") as span:
            all_docs: list = []
            for entry, pieces in chunked_articles:
                texts = [piece.text for piece in pieces]
                logger.info("Step 2 - Embedding", domain=entry["domain"], chunks=len(texts))
                vectors = embedder.embed_batch(texts)
                for position, (piece, vector) in enumerate(zip(pieces, vectors)):
                    typed_doc = KnowledgeBaseArticle(
                        id=f"{entry['domain']}_{piece.chunk_id}",
                        content=piece.text,
                        embedding=vector,
                        domain=entry["domain"],
                        language=entry["language"],
                        version=entry["version"],
                        metadata={
                            **piece.metadata,
                            "chunk_index": position,
                            "char_start": piece.start_pos,
                            "char_end": piece.end_pos,
                        },
                    )
                    all_docs.append(typed_doc)
            span.set_output({"chunks": len(all_docs)})

        # Stage 3: write the documents; embeddings are already attached.
        with trace.span("indexing") as span:
            logger.info(
                "Step 3 - Indexing",
                index=index_name,
                document_type="KnowledgeBaseArticle",
                total_chunks=len(all_docs),
            )
            store.add_documents(all_docs, generate_embeddings=False)
            logger.info("Chunks indexed", count=len(all_docs))
            span.set_output({"chunks_indexed": len(all_docs)})

        # Stage 4: embed the query and run a vector search.
        with trace.span("retrieval", input=query) as span:
            logger.info("Step 4 - Retrieval", query=query, top_k=top_k)
            query_embedding = embedder.embed_text(query)
            results = store.search(query_embedding=query_embedding, top_k=top_k)
            logger.info("Retrieved chunks", returned=len(results))
            for hit in results:
                found: KnowledgeBaseArticle = hit.document
                logger.info(
                    "Retrieved chunk",
                    rank=hit.rank,
                    id=found.id,
                    score=round(hit.score, 4),
                    domain=found.domain,
                    language=found.language,
                    version=found.version,
                    snippet=found.content[:100].replace("\n", " "),
                )
            top_score = round(results[0].score, 4) if results else 0
            span.set_output({"results": len(results), "top_score": top_score})

        trace.set_output({"chunks_indexed": len(all_docs), "query_results": len(results)})

    logger.info(
        "RAG pipeline complete",
        chunks_indexed=len(all_docs),
        query_results=len(results),
        tip="Filter by domain='security' or version='2.0' to narrow results at query time",
    )
# ============================================================================
# Key Benefits Summary
# ============================================================================
def print_benefits():
    """Emit one info-level log record per key benefit of the approach."""
    benefits = (
        ("Key Benefits: Simple API",
         "Pass document_type parameter - no need to manually define search fields"),
        ("Key Benefits: Automatic Indexing",
         "All document fields automatically indexed with appropriate data types"),
        ("Key Benefits: Complete Filtering",
         "Filter on ANY field - string equality, numeric ranges, date ranges, boolean, multiple conditions"),
        ("Key Benefits: Type Safety",
         "Documents returned as custom type with IntelliSense support"),
        ("Key Benefits: Production Ready",
         "Azure Search scalability, HNSW vector search, hybrid search (vector + BM25)"),
    )
    for headline, explanation in benefits:
        logger.info(headline, detail=explanation)
# ============================================================================
# Main
# ============================================================================
if __name__ == "__main__":
    # Script entry point: run every example in order, logging an empty
    # separator record after each one, then the summary and usage notes.
    logger.info(
        "Custom Azure Vector Store Examples",
        note="Using AzureAISearchVectorStore with custom document schemas - all fields automatically indexed",
    )
    for run_example in (
        example_legal_documents,
        example_product_catalog,
        example_financial_documents,
        example_ai_ml_knowledge_base,
        example_rag_pipeline,
    ):
        run_example()
        logger.info("")
    print_benefits()
    logger.info(
        "Usage in production",
        step_1="Define your domain-specific Document subclass",
        step_2="Pass document_type to AzureAISearchVectorStore",
        step_3="All fields automatically indexed",
        step_4="Filter on any field efficiently",
    )
azure_devops_wiki_example.py
"""
Example: Azure DevOps Wiki Connector + WikiPageChunker
Demonstrates loading wiki pages from an Azure DevOps project wiki and chunking
them with WikiPageChunker.chunk_markdown(), followed by the full ingest pipeline:
load → chunk → embed → index → retrieve
Examples:
1. AzureDevOpsWikiConnector — Load pages from an ADO project wiki
2. WikiPageChunker — Chunk pages at Markdown heading boundaries
3. Full Pipeline — load → chunk → embed → index → retrieve
How this connects to existing examples:
- Uses the same embedder, Azure Search config, and helper patterns
as connectors_example.py, retrieval_example.py, etc.
- Introduces AzureDevOpsWikiConnector (PAT auth, REST API, Markdown content)
and WikiPageChunker.chunk_markdown() (heading-aware split for # headings)
Flow demonstrated:
AzureDevOpsWikiConnector.load() (fetches live pages from ADO Wiki)
→ WikiPageChunker.chunk_markdown() (splits at # / ## / ### headings)
→ AzureOpenAIEmbeddings (embed each chunk)
→ AzureAISearchVectorStore (index for retrieval)
→ VectorRetriever.retrieve()
Prerequisites:
- ADO credentials in .env:
AZURE_DEVOPS_ORG, AZURE_DEVOPS_PROJECT, AZURE_DEVOPS_PAT
- For the full pipeline, Azure credentials:
AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_EMBEDDING_MODEL,
AZURE_AI_SEARCH_ENDPOINT, AZURE_AI_SEARCH_API_KEY
HOW TO CREATE A PERSONAL ACCESS TOKEN (PAT)
───────────────────────────────────────────
1. Sign in to https://dev.azure.com/{org}
2. Click your profile avatar (top right) → Personal access tokens
3. Click "New token"
4. Scopes: select "Wiki" → Read
5. Copy the token and set it as AZURE_DEVOPS_PAT in .env
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
from dotenv import load_dotenv
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
from gmf_forge_ai_data.connectors import AzureDevOpsWikiConnector
from gmf_forge_ai_data.chunkers import WikiPageChunker
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, Document, SearchResult
from gmf_forge_ai_data.indexing import AzureAISearchIndexBuilder
from gmf_forge_ai_data.retrieval import RetrievalQuery, VectorRetriever
# ── Environment ───────────────────────────────────────────────────────────────
# Load variables from the .env file that sits next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Workspace root is four directory levels above this file; the optional
# corporate CA bundle is looked up under <root>/certs.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# Module-level logger shared by every helper and example below.
logger = BasicLogger(__name__)
# ── Config helpers (same pattern as all other examples) ──────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the corporate CA bundle path as a string, or None if the file is absent."""
    if CORPORATE_CERT.exists():
        return str(CORPORATE_CERT)
    return None
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build an ``AzureOpenAIEmbeddings`` client from environment variables.

    Reads ``AZURE_OPENAI_ENDPOINT``, ``AZURE_OPENAI_API_KEY`` and
    ``AZURE_OPENAI_EMBEDDING_MODEL`` (all required) plus the optional
    ``AZURE_OPENAI_EMBEDDING_MODEL_VERSION`` (default ``"2024-02-01"``).

    Returns:
        The configured embedder, or ``None`` (after logging a warning) when
        any required variable is missing or empty.
    """
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    deployment = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL")
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    if not (endpoint and api_key and deployment):
        logger.warning("Azure OpenAI embedding credentials not found in .env")
        return None

    cert_path = get_ssl_cert_path()
    if cert_path:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    client = AzureOpenAIEmbeddings(
        endpoint=endpoint,
        api_key=api_key,
        deployment_name=deployment,
        api_version=api_version,
        ssl_cert_path=cert_path,
    )
    logger.info("Initialized embeddings", deployment=deployment)
    return client
def get_vector_store(index_name: str) -> Optional[AzureAISearchVectorStore]:
    """Create an ``AzureAISearchVectorStore`` for *index_name* from environment variables.

    Requires ``AZURE_AI_SEARCH_ENDPOINT`` and ``AZURE_AI_SEARCH_API_KEY``;
    ``AZURE_AI_SEARCH_EMBEDDING_DIMENSION`` is optional (default 1536).

    Returns:
        The store, or ``None`` (after logging a warning) when the endpoint
        or API key is missing or empty.

    Raises:
        ValueError: if the embedding dimension variable is set to a
            non-integer value (it is parsed before the credential check).
    """
    search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
    search_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
    dimension = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    if not (search_endpoint and search_key):
        logger.warning("Azure Search credentials not found in .env")
        return None

    vector_store = AzureAISearchVectorStore(
        endpoint=search_endpoint,
        index_name=index_name,
        api_key=search_key,
        embedding_dimension=dimension,
        document_type=Document,
    )
    logger.info("Connected to Azure Search index", index=index_name)
    return vector_store
def log_docs(docs: List[Document], label: str = "", max_content: int = 90):
    """Log one record per Document, optionally preceded by a labelled count.

    Args:
        docs: Documents to report.
        label: When non-empty, logged first together with ``len(docs)``.
        max_content: Number of leading content characters used for the
            preview; longer content gets a trailing ``"..."``.
    """
    if label:
        logger.info(label, count=len(docs))
    for item in docs:
        meta = item.metadata
        body = item.content
        # split()/join collapses newlines and runs of whitespace in the preview
        snippet = " ".join(body[:max_content].split())
        ellipsis = "..." if len(body) > max_content else ""
        logger.info(
            "document",
            path=meta.get("page_path", item.id),
            wiki=meta.get("wiki_name", ""),
            chars=len(body),
            preview=snippet + ellipsis,
        )
def log_results(results: List[SearchResult], label: str = ""):
    """Log one record per search result, optionally preceded by a labelled count.

    Args:
        results: Search results to report; each exposes ``rank``, ``score``
            and ``document``.
        label: When non-empty, logged first together with ``len(results)``.
    """
    preview_chars = 90
    if label:
        logger.info(label, count=len(results))
    for hit in results:
        matched = hit.document
        # split()/join collapses newlines and runs of whitespace in the preview
        snippet = " ".join(matched.content[:preview_chars].split())
        if len(matched.content) > preview_chars:
            snippet = snippet + "..."
        logger.info(
            "result",
            rank=hit.rank,
            score=round(hit.score, 4),
            path=matched.metadata.get("page_path", matched.id),
            section=matched.metadata.get("section_heading", ""),
            preview=snippet,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Azure DevOps Wiki Connector — load pages
# ══════════════════════════════════════════════════════════════════════════════
# Target wiki and section for this demo
# https://dev.azure.com/GMFinancial/GMFinancial/_wiki/wikis/GMFinancial.wiki/162/Team-Wikis
# NOTE: ADO REST API paths use the actual page title (spaces), not the URL slug (hyphens).
_DEMO_WIKI_ID = "GMFinancial.wiki"
# Pages passed to AzureDevOpsWikiConnector by Examples 1 and 3 (Example 2
# chunks whatever Example 1 loaded). page_paths_recursive=True means each entry
# is treated as a sub-tree root: the page itself + all its child pages are loaded.
# To load only the exact pages listed (no children), set page_paths_recursive=False.
_DEMO_PAGE_PATHS = [
    "/Team Wikis/Innovation and GenAI Lab",
    "/Team Wikis/GMF Bank",
]
# Forwarded as the connector's page_paths_recursive argument.
_DEMO_PAGE_PATHS_RECURSIVE = True
def example_ado_wiki_connector() -> Optional[List[Document]]:
    """Example 1: load the demo wiki pages through ``AzureDevOpsWikiConnector``.

    The connector is built from ``AZURE_DEVOPS_ORG``, ``AZURE_DEVOPS_PROJECT``
    and ``AZURE_DEVOPS_PAT`` and pointed at ``_DEMO_WIKI_ID`` /
    ``_DEMO_PAGE_PATHS`` with ``page_paths_recursive`` taken from
    ``_DEMO_PAGE_PATHS_RECURSIVE``.

    Returns:
        The loaded Documents (possibly an empty list), or ``None`` when the
        credentials are missing or ``connector.load()`` raises. Both failure
        cases are logged rather than propagated.
    """
    logger.info("Example 1: Azure DevOps Wiki Connector")
    logger.info(
        "Scenario: Load specific sections from GMFinancial.wiki. "
        "AzureDevOpsWikiConnector authenticates with a PAT and recursively "
        "fetches each entry in _DEMO_PAGE_PATHS and its child pages as Markdown Documents."
    )

    org = os.getenv("AZURE_DEVOPS_ORG")
    project = os.getenv("AZURE_DEVOPS_PROJECT")
    pat = os.getenv("AZURE_DEVOPS_PAT")
    if not (org and project and pat):
        logger.warning(
            "Azure DevOps credentials not found in .env — skipping live demo",
            required_vars=[
                "AZURE_DEVOPS_ORG (e.g. GMFinancial)",
                "AZURE_DEVOPS_PROJECT (e.g. GMFinancial)",
                "AZURE_DEVOPS_PAT (PAT with Wiki Read scope)",
            ],
            how_to_create_pat=(
                "dev.azure.com/{org} → Profile → Personal access tokens → "
                "New token → Scopes: Wiki → Read"
            ),
        )
        return None

    connector = AzureDevOpsWikiConnector(
        organization=org,
        project=project,
        pat=pat,
        wiki_ids=[_DEMO_WIKI_ID],
        page_paths=_DEMO_PAGE_PATHS,
        page_paths_recursive=_DEMO_PAGE_PATHS_RECURSIVE,
        ssl_cert_path=get_ssl_cert_path(),
    )
    logger.info(
        "Loading from Azure DevOps Wiki",
        org=org,
        project=project,
        wiki=_DEMO_WIKI_ID,
        pages=_DEMO_PAGE_PATHS,
        source_url=f"https://dev.azure.com/{org}/{project}/_wiki/wikis/{_DEMO_WIKI_ID}",
    )

    # Broad catch is deliberate: any load failure downgrades to a skipped demo.
    try:
        pages = connector.load()
    except Exception as exc:
        logger.error(
            "Azure DevOps Wiki load failed — skipping live demo",
            error=str(exc),
            common_causes=[
                "401 Unauthorized: AZURE_DEVOPS_PAT is invalid or expired",
                "403 Forbidden: PAT lacks Wiki Read scope",
                "404 Not Found: AZURE_DEVOPS_ORG, AZURE_DEVOPS_PROJECT, or wiki ID is wrong",
                "SSL error: set ssl_cert_path to your corporate CA bundle",
            ],
        )
        return None

    log_docs(pages, "Loaded pages")
    if not pages:
        logger.warning(
            "No pages loaded — check PAT permissions and wiki paths",
            pages=_DEMO_PAGE_PATHS,
        )
    else:
        logger.info(
            "Successfully loaded pages",
            total_pages=len(pages),
            pages=[page.metadata.get("page_path") for page in pages],
        )
        # Show full metadata for the first page
        logger.info("First page metadata", **pages[0].metadata)

    logger.info("AzureDevOpsWikiConnector complete — Documents are Markdown, ready for WikiPageChunker.")
    return pages
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: WikiPageChunker — chunk Markdown pages at heading boundaries
# ══════════════════════════════════════════════════════════════════════════════
def example_wiki_chunker(docs: List[Document]) -> List[Document]:
    """Example 2: split wiki pages into chunk Documents with ``WikiPageChunker``.

    Every page is passed to ``chunk_markdown()`` together with its metadata.
    Each resulting chunk becomes a new ``Document`` whose id is
    ``"<page id>_chunk_<i>"`` and whose metadata is the chunk's metadata plus
    ``chunk_index`` and ``chunk_count``.

    Args:
        docs: Pages to chunk (Markdown content).

    Returns:
        All chunk Documents, in page order then chunk order. The first three
        are also logged as a sample, using the ``section_heading`` and
        ``heading_level`` metadata keys when present.
    """
    logger.info("Example 2: WikiPageChunker (chunk_markdown)")
    logger.info(
        "Scenario: ADO Wiki pages are Markdown. "
        "WikiPageChunker.chunk_markdown() splits at # heading boundaries "
        "for semantically coherent, section-aligned chunks."
    )

    chunker = WikiPageChunker(max_chunk_size=1500, min_chunk_size=200)
    all_chunk_docs: List[Document] = []
    for page in docs:
        sections = chunker.chunk_markdown(page.content, metadata=page.metadata)
        section_total = len(sections)
        all_chunk_docs.extend(
            Document(
                id=f"{page.id}_chunk_{position}",
                content=section.text,
                metadata={
                    **section.metadata,
                    "chunk_index": position,
                    "chunk_count": section_total,
                },
            )
            for position, section in enumerate(sections)
        )

    # Guard the average against an empty page list.
    if docs:
        average = round(len(all_chunk_docs) / len(docs), 1)
    else:
        average = 0
    logger.info(
        "WikiPageChunker complete",
        pages=len(docs),
        chunks=len(all_chunk_docs),
        avg_chunks_per_page=average,
    )

    # Show the first 3 chunks to illustrate heading-aware splits
    sample_chars = 120
    logger.info("Sample chunks (first 3):")
    for sample in all_chunk_docs[:3]:
        snippet = " ".join(sample.content[:sample_chars].split())
        if len(sample.content) > sample_chars:
            snippet = snippet + "..."
        logger.info(
            "chunk",
            id=sample.id,
            heading=sample.metadata.get("section_heading", "(no heading)"),
            heading_level=sample.metadata.get("heading_level", 0),
            chars=len(sample.content),
            preview=snippet,
        )
    return all_chunk_docs
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Full Pipeline — load → chunk → embed → index → retrieve
# ══════════════════════════════════════════════════════════════════════════════
def example_full_pipeline(
    embedder: AzureOpenAIEmbeddings,
    store: AzureAISearchVectorStore,
):
    """
    Example 3: end-to-end ingest pipeline for Azure DevOps Wiki pages.

        AzureDevOpsWikiConnector.load()       → raw Markdown Documents
        → WikiPageChunker.chunk_markdown()    → heading-split chunks
        → embedder.embed_text()               → embedding set on each chunk
        → store.add_documents()               → indexed and searchable
        → VectorRetriever.retrieve()          → query results

    Args:
        embedder: Used to embed every chunk and the query.
        store: Vector store the chunks are written to and searched in.

    Returns:
        None. The function returns early (after logging) when the ADO
        credentials are missing, when loading the wiki raises, or when no
        pages were loaded. A load failure is handled the same way as in
        example_ado_wiki_connector instead of propagating, so a bad PAT that
        Example 1 already reported does not crash the script here.
    """
    logger.info("Example 3: Full Pipeline — Load → Chunk → Embed → Index → Retrieve")
    logger.info("Scenario: End-to-end ADO Wiki ingest ready for RAG.")

    org = os.getenv("AZURE_DEVOPS_ORG")
    project = os.getenv("AZURE_DEVOPS_PROJECT")
    pat = os.getenv("AZURE_DEVOPS_PAT")
    if not all([org, project, pat]):
        logger.error(
            "ADO credentials required for full pipeline",
            required_env=["AZURE_DEVOPS_ORG", "AZURE_DEVOPS_PROJECT", "AZURE_DEVOPS_PAT"],
        )
        return

    query = "How do I get started and what are the main components of the copilot lite?"
    tracer = get_tracer()
    with tracer.trace("ado_wiki_pipeline", input=query) as trace:
        # ── Step 1: Load ─────────────────────────────────────────────────────
        with trace.span("loading") as span:
            logger.info(
                "Step 1 - Load: AzureDevOpsWikiConnector fetching Innovation and GenAI Lab pages",
                wiki=_DEMO_WIKI_ID,
                pages=_DEMO_PAGE_PATHS,
            )
            connector = AzureDevOpsWikiConnector(
                organization=org,
                project=project,
                pat=pat,
                wiki_ids=[_DEMO_WIKI_ID],
                page_paths=_DEMO_PAGE_PATHS,
                page_paths_recursive=_DEMO_PAGE_PATHS_RECURSIVE,
                ssl_cert_path=get_ssl_cert_path(),
            )
            # Broad catch is deliberate and mirrors Example 1: auth, network
            # and SSL failures all downgrade to a logged, skipped pipeline.
            try:
                raw_docs = connector.load()
            except Exception as exc:
                logger.error(
                    "Azure DevOps Wiki load failed — skipping full pipeline",
                    error=str(exc),
                    common_causes=[
                        "401 Unauthorized: AZURE_DEVOPS_PAT is invalid or expired",
                        "403 Forbidden: PAT lacks Wiki Read scope",
                        "404 Not Found: AZURE_DEVOPS_ORG, AZURE_DEVOPS_PROJECT, or wiki ID is wrong",
                        "SSL error: set ssl_cert_path to your corporate CA bundle",
                    ],
                )
                span.set_output({"pages_loaded": 0, "error": str(exc)})
                trace.set_output({"pages_loaded": 0, "skipped": True})
                return
            logger.info("Step 1 complete", pages_loaded=len(raw_docs))
            span.set_output({"pages_loaded": len(raw_docs)})

        if not raw_docs:
            logger.warning("No pages loaded — check ADO credentials and wiki IDs.")
            trace.set_output({"pages_loaded": 0, "skipped": True})
            return

        # ── Step 2: Chunk ─────────────────────────────────────────────────────
        with trace.span("chunking") as span:
            logger.info("Step 2 - Chunk: WikiPageChunker splitting at Markdown headings")
            chunker = WikiPageChunker(max_chunk_size=1500, min_chunk_size=200)
            all_chunks: List[Document] = []
            for doc in raw_docs:
                chunks = chunker.chunk_markdown(doc.content, metadata=doc.metadata)
                for i, chunk in enumerate(chunks):
                    all_chunks.append(Document(
                        id=f"{doc.id}_chunk_{i}",
                        content=chunk.text,
                        metadata={
                            **chunk.metadata,
                            "chunk_index": i,
                            "chunk_count": len(chunks),
                        },
                    ))
            logger.info("Step 2 complete", pages=len(raw_docs), chunks=len(all_chunks))
            span.set_output({"pages": len(raw_docs), "chunks": len(all_chunks)})

        # ── Step 3: Embed ─────────────────────────────────────────────────────
        with trace.span("embedding") as span:
            logger.info("Step 3 - Embed: generating embeddings for all chunks")
            for chunk_doc in all_chunks:
                chunk_doc.embedding = embedder.embed_text(chunk_doc.content)
            logger.info("Step 3 complete", embeddings_generated=len(all_chunks))
            span.set_output({"embeddings": len(all_chunks)})

        # ── Step 4: Index ─────────────────────────────────────────────────────
        with trace.span("indexing") as span:
            logger.info("Step 4 - Index: adding chunks to Azure Search", index=store.index_name)
            # Embeddings were attached in Step 3, so the store must not regenerate them.
            store.add_documents(all_chunks, generate_embeddings=False)
            logger.info("Step 4 complete", chunks_indexed=len(all_chunks), index=store.index_name)
            span.set_output({"chunks_indexed": len(all_chunks)})

        # ── Step 5: Retrieve ──────────────────────────────────────────────────
        with trace.span("retrieval", input=query) as span:
            logger.info("Step 5 - Retrieve", query=query)
            query_embedding = embedder.embed_text(query)
            retriever = VectorRetriever(store)
            results = retriever.retrieve(RetrievalQuery(embedding=query_embedding, top_k=5))
            log_results(results, "Top 5 results")
            span.set_output({
                "results": len(results),
                "top_score": round(results[0].score, 4) if results else 0,
            })

        trace.set_output({
            "pages_loaded": len(raw_docs),
            "chunks_indexed": len(all_chunks),
            "query_results": len(results),
        })

    logger.info(
        "Pipeline complete",
        pages=len(raw_docs),
        chunks=len(all_chunks),
        results=len(results),
        note="Wiki chunks are ready to feed into a RAG LLM prompt",
    )
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main():
    """Run the three wiki examples in order.

    Example 1 loads pages, Example 2 chunks them (skipped with a warning if
    nothing was loaded), and Example 3 runs the full pipeline. The process
    exits with status 1 if the embedder or the vector store cannot be
    created; the ``ado_wiki_docs`` index is created before Example 3 runs.
    """
    divider = "********************"
    index_name = "ado_wiki_docs"

    logger.info("Azure DevOps Wiki Examples — load, chunk, and index wiki pages")

    # Example 1: Load pages from Azure DevOps Wiki
    logger.info(divider)
    loaded_pages = example_ado_wiki_connector()
    logger.info(divider)

    # Example 2: Chunk pages with WikiPageChunker.chunk_markdown()
    if loaded_pages:
        example_wiki_chunker(loaded_pages)
    else:
        logger.warning(
            "Skipping WikiPageChunker example — no pages loaded",
            hint="Set AZURE_DEVOPS_ORG, AZURE_DEVOPS_PROJECT, AZURE_DEVOPS_PAT in .env",
        )
    logger.info(divider)

    # Example 3: Full pipeline (requires Azure + ADO credentials)
    logger.info("Setting up full pipeline (requires Azure + ADO credentials)")
    embedder = get_embedder()
    if not embedder:
        logger.error(
            "Embedder required for full pipeline",
            required_env=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_EMBEDDING_MODEL"],
        )
        sys.exit(1)

    store = get_vector_store(index_name)
    if not store:
        logger.error(
            "Azure Search index required for full pipeline",
            required_env=["AZURE_AI_SEARCH_ENDPOINT", "AZURE_AI_SEARCH_API_KEY"],
        )
        sys.exit(1)

    index_builder = AzureAISearchIndexBuilder(
        endpoint=os.getenv("AZURE_AI_SEARCH_ENDPOINT"),
        api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"),
        index_name=index_name,
        embedding_dimension=int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536")),
        ssl_cert_path=get_ssl_cert_path(),
    )
    index_builder.create_index()

    logger.info(divider)
    example_full_pipeline(embedder, store)
    logger.info(divider)


if __name__ == "__main__":
    main()
chunking_example.py
"""
Text Chunking Examples for RAG Applications
Demonstrates 7 chunking strategies:
1. Fixed-size (token-based with tiktoken)
2. Semantic (sentence boundaries)
3. Recursive (hierarchical splitting)
4. Sentence (sentence grouping)
5. Markdown (header-aware)
6. Code (function/class boundaries)
7. Combined (multi-strategy pipeline)
Requirements:
- tiktoken (for token-based chunking)
- nltk (for sentence-based chunking) - auto-downloaded on first run
"""
from pathlib import Path
from gmf_forge_ai_data.chunkers import (
FixedSizeChunker,
SemanticChunker,
RecursiveChunker,
SentenceChunker,
MarkdownChunker,
CodeChunker,
)
from gmf_forge_ai_shared_core.observability import BasicLogger
logger = BasicLogger("chunking-example")
# Check for NLTK - REQUIRED for sentence-based chunkers.
# NLTK_AVAILABLE ends up True only when a punkt tokenizer is present locally,
# either because it was already installed or because a download succeeded.
logger.info("Checking for NLTK punkt tokenizer")
NLTK_AVAILABLE = False
try:
    import nltk
    # Use data.find() — pure local filesystem check, no network call
    try:
        nltk.data.find('tokenizers/punkt_tab')
    except LookupError:
        # Fall back to the older resource name; a second LookupError is
        # handled by the outer `except LookupError` below.
        nltk.data.find('tokenizers/punkt')
    NLTK_AVAILABLE = True
    logger.info("NLTK punkt tokenizer found")
except ImportError:
    logger.warning("NLTK not installed", hint="pip install nltk")
except LookupError:
    # NLTK installed but punkt/punkt_tab missing - download it
    try:
        logger.info("NLTK found but punkt tokenizer missing — downloading (~1.8 MB one-time)")
        # nltk.download() signals failure by returning False rather than by
        # raising, so the return value has to be checked explicitly.
        try:
            downloaded = nltk.download('punkt_tab', quiet=False)
        except Exception:
            downloaded = False
        if not downloaded:
            downloaded = nltk.download('punkt', quiet=False)
        if not downloaded:
            raise RuntimeError("nltk.download() reported failure for both punkt_tab and punkt")
        NLTK_AVAILABLE = True
        logger.info("NLTK punkt tokenizer downloaded successfully")
    except Exception as e:
        logger.error("Failed to download punkt tokenizer", error=str(e), hint='python -c \'import nltk; nltk.download("punkt_tab")\'')
except Exception as e:
    logger.error("NLTK check failed", error=str(e))
# Helper functions
def log_section(title: str):
    """Write *title* to the log wrapped in ``=== ... ===`` markers."""
    banner = "=== {} ===".format(title)
    logger.info(banner)
def check_nltk_required() -> bool:
    """Return True when NLTK is usable; otherwise log a skip warning and return False."""
    if NLTK_AVAILABLE:
        return True
    logger.warning("SKIPPED: Requires NLTK punkt tokenizer", hint="pip install nltk")
    return False
def display_chunks(chunks, preview_length=100, show_details=True):
    """Log the chunk count and, optionally, one detail record per chunk.

    Args:
        chunks: Chunk objects exposing ``text`` and ``metadata``.
        preview_length: Number of leading characters shown in each preview
            (newlines are replaced by spaces).
        show_details: When False, only the total count is logged.
    """
    logger.info("Chunks created", count=len(chunks))
    if not show_details:
        return

    # (metadata key, field name in the log record), copied only when present
    optional_fields = (
        ("token_count", "tokens"),
        ("sentence_count", "sentences"),
        ("code_blocks", "code_blocks"),
    )
    for number, piece in enumerate(chunks, start=1):
        details = {"chunk": number, "length": len(piece.text)}
        for source_key, field_name in optional_fields:
            if source_key in piece.metadata:
                details[field_name] = piece.metadata[source_key]
        if "headers" in piece.metadata:
            details["headers"] = [header["text"] for header in piece.metadata["headers"]]
        snippet = piece.text[:preview_length].replace("\n", " ")
        logger.info("chunk", **details, preview=snippet)
# Sample text for examples 1-4 (general AI text used by multiple strategies);
# also fed to the fixed-size chunker in the combined pipeline example.
SAMPLE_TEXT = """
Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to natural
intelligence displayed by animals including humans. AI research has been defined as the field
of study of intelligent agents, which refers to any system that perceives its environment and
takes actions that maximize its chance of achieving its goals.
The term artificial intelligence is applied when a machine mimics cognitive functions that humans
associate with the human mind, such as learning and problem solving. As machines become increasingly
capable, mental facilities once thought to require intelligence are removed from the definition.
For example, optical character recognition is no longer perceived as an exemplar of artificial
intelligence having become a routine technology.
Modern machine learning algorithms are trained on large datasets to recognize patterns and make
predictions. Deep learning, a subset of machine learning, uses neural networks with multiple
layers to process complex data. These systems have achieved remarkable success in areas such as
image recognition, natural language processing, and game playing.
"""
def example_fixed_size_chunking():
    """Example 1: chunk SAMPLE_TEXT with FixedSizeChunker (150 tokens, 20 overlap, cl100k_base)."""
    log_section("Example 1: Fixed-Size Token-Based Chunking")
    token_chunker = FixedSizeChunker(
        chunk_size=150,
        chunk_overlap=20,
        encoding_name="cl100k_base",
        logger=logger,
    )
    pieces = token_chunker.chunk(SAMPLE_TEXT, metadata={"source": "ai_article"})
    display_chunks(pieces)
def example_semantic_chunking():
    """Example 2: chunk SAMPLE_TEXT with SemanticChunker using NLTK sentence splitting.

    Skipped (after the section header is logged) when NLTK is unavailable.
    """
    log_section("Example 2: Semantic Sentence-Based Chunking")
    if not check_nltk_required():
        return

    # Imported here so the module still loads when NLTK is not installed.
    import nltk

    sentence_aware = SemanticChunker(
        sentence_tokenizer=nltk.sent_tokenize,
        max_chunk_size=300,
        min_chunk_size=50,
        logger=logger,
    )
    pieces = sentence_aware.chunk(SAMPLE_TEXT, metadata={"source": "ai_article"})
    display_chunks(pieces, preview_length=150)
def example_recursive_chunking():
    """Example 3: chunk SAMPLE_TEXT with RecursiveChunker (size 400, overlap 50)."""
    log_section("Example 3: Recursive Hierarchical Chunking")
    hierarchical = RecursiveChunker(chunk_size=400, chunk_overlap=50, logger=logger)
    pieces = hierarchical.chunk(SAMPLE_TEXT, metadata={"source": "ai_article"})
    logger.info("Splitting hierarchy: paragraph -> sentence -> word -> character")
    display_chunks(pieces, preview_length=120)
def example_sentence_chunking():
    """Example 4: group sentences of SAMPLE_TEXT with SentenceChunker in two configurations.

    Variant A limits chunks by size only; variant B additionally fixes the
    number of sentences per chunk. Only chunk counts are logged. Skipped
    (after the section header is logged) when NLTK is unavailable.
    """
    log_section("Example 4: Sentence Grouping Chunking")
    if not check_nltk_required():
        return

    # Imported here so the module still loads when NLTK is not installed.
    import nltk

    variants = (
        ("A. Size-based grouping (max 300 chars)",
         {"max_chunk_size": 300}),
        ("B. Fixed count (2 sentences per chunk)",
         {"max_chunk_size": 1000, "sentences_per_chunk": 2}),
    )
    for heading, options in variants:
        logger.info(heading)
        grouper = SentenceChunker(
            sentence_tokenizer=nltk.sent_tokenize,
            logger=logger,
            **options,
        )
        pieces = grouper.chunk(SAMPLE_TEXT, metadata={"source": "ai_article"})
        display_chunks(pieces, show_details=False)
# Sample markdown document for example 5 (level 1-3 headings, a numbered list
# and a bullet list); also reused by the combined pipeline example.
MARKDOWN_TEXT = """
# Introduction to Machine Learning
Machine learning is a fascinating field that combines statistics, computer science, and domain expertise.
## What is Machine Learning?
Machine learning is a method of data analysis that automates analytical model building. It is a branch
of artificial intelligence based on the idea that systems can learn from data, identify patterns, and
make decisions with minimal human intervention.
### Types of Machine Learning
There are three main types:
1. **Supervised Learning**: The algorithm learns from labeled training data
2. **Unsupervised Learning**: The algorithm finds patterns in unlabeled data
3. **Reinforcement Learning**: The algorithm learns through trial and error
## Deep Learning
Deep learning is a subset of machine learning that uses neural networks with multiple layers. These
networks can automatically discover complex patterns in large datasets.
### Applications
- Computer vision
- Natural language processing
- Speech recognition
- Autonomous vehicles
"""
def example_markdown_chunking():
    """Example 5: chunk ``MARKDOWN_TEXT`` with the Markdown chunker and display the result."""
    log_section("Example 5: Markdown Header-Aware Chunking")

    markdown_options = {
        "max_chunk_size": 500,
        "combine_headers": True,
        "min_header_level": 2,
        "logger": logger,
    }
    md_chunker = MarkdownChunker(**markdown_options)
    display_chunks(md_chunker.chunk(MARKDOWN_TEXT, metadata={"source": "ml_guide"}))
# Sample Python code used by Example 6 and by the code leg of Example 7
PYTHON_CODE = """
import os
from typing import List, Dict
class DataProcessor:
def __init__(self, config: Dict):
self.config = config
self.results = []
def process_data(self, data: List[str]) -> List[str]:
'''Process a list of data elements.'''
processed = []
for item in data:
result = self._transform(item)
processed.append(result)
return processed
def _transform(self, item: str) -> str:
'''Transform a single data item.'''
return item.upper().strip()
def main():
'''Main entry point.'''
processor = DataProcessor({"mode": "production"})
data = ["hello", "world", "test"]
results = processor.process_data(data)
print(results)
if __name__ == "__main__":
main()
"""
def example_code_chunking():
    """Example 6: chunk ``PYTHON_CODE`` with the code chunker and display the result."""
    log_section("Example 6: Code Function/Class Chunking")

    python_chunker = CodeChunker(
        language="python",
        max_chunk_size=500,
        include_imports=True,
        logger=logger,
    )
    code_pieces = python_chunker.chunk(
        PYTHON_CODE,
        metadata={"source": "data_processor.py"},
    )
    # A longer preview than the other examples use, so more of each code chunk is shown.
    display_chunks(code_pieces, preview_length=150)
def example_combined_pipeline():
    """Example 7: run three chunkers over three content types and log the counts."""
    log_section("Example 7: Combined Chunking Pipeline")

    # Documentation goes through the Markdown chunker.
    docs_part = MarkdownChunker(max_chunk_size=400, logger=logger).chunk(
        MARKDOWN_TEXT, metadata={"type": "documentation"}
    )
    # Source code goes through the code chunker.
    code_part = CodeChunker(max_chunk_size=500, language="python", logger=logger).chunk(
        PYTHON_CODE, metadata={"type": "code"}
    )
    # Plain article text goes through the fixed-size chunker.
    article_part = FixedSizeChunker(chunk_size=200, chunk_overlap=30, logger=logger).chunk(
        SAMPLE_TEXT, metadata={"type": "article"}
    )

    # Merge the three result sets and report how many chunks each one produced.
    combined = docs_part + code_part + article_part
    summary = {
        "total": len(combined),
        "documentation": len(docs_part),
        "code": len(code_part),
        "articles": len(article_part),
    }
    logger.info("Combined pipeline complete", **summary)
if __name__ == "__main__":
    # Script entry point: run the seven chunking examples in order and stop at
    # the first one that raises, logging the error and printing its traceback.
    log_section("Text Chunking Examples")
    logger.info("Demonstrating 7 chunking strategies for RAG applications")

    try:
        # Built inside the try so a failed name lookup is reported like any other error.
        examples_in_order = (
            example_fixed_size_chunking,
            example_semantic_chunking,
            example_recursive_chunking,
            example_sentence_chunking,
            example_markdown_chunking,
            example_code_chunking,
            example_combined_pipeline,
        )
        for run_example in examples_in_order:
            run_example()
        logger.info("All examples completed successfully")
    except Exception as exc:
        logger.error("Example failed", error=str(exc))
        import traceback

        traceback.print_exc()
connectors_example.py
"""
Example: Data Connectors (Module 6)
Demonstrates all 3 connectors plus the full ingest-to-retrieval pipeline:
1. FilesystemConnector — Load .md files from the monorepo docs
2. SharePointConnector — Load from SharePoint via Microsoft Graph API
3. BlobStorageConnector — Load from Azure Blob Storage
4. Full Pipeline — load → chunk → embed → index → retrieve
How this connects to existing examples:
- Uses the same embedder, Azure Search config, and helper patterns
as all previous examples (retrieval_example.py, query_example.py, etc.)
- Introduces the RAG ingest path: content sources → Document objects
ready to feed the rest of the pipeline
Flow demonstrated:
connector.load() (FilesystemConnector / SharePoint / Blob)
→ RecursiveChunker (split into indexable chunks)
→ AzureOpenAIEmbeddings (embed each chunk)
→ AzureAISearchVectorStore.add_documents()
→ VectorRetriever.retrieve()
→ results printed
Prerequisites:
- Azure AI Search + Azure OpenAI credentials in .env
(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT,
AZURE_OPENAI_EMBEDDING_MODEL, AZURE_AI_SEARCH_ENDPOINT, AZURE_AI_SEARCH_API_KEY)
Optional (for SharePoint and Blob examples):
- SHAREPOINT_TENANT_ID, SHAREPOINT_CLIENT_ID, SHAREPOINT_CLIENT_SECRET,
SHAREPOINT_SITE_ID, SHAREPOINT_FOLDER_PATH
- AZURE_STORAGE_ACCOUNT_NAME, AZURE_STORAGE_ACCESS_KEY, AZURE_BLOB_CONTAINER_NAME, AZURE_STORAGE_PREFIX (optional)
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
from dotenv import load_dotenv
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
from gmf_forge_ai_data.connectors import (
FilesystemConnector,
SharePointConnector,
BlobStorageConnector,
)
from gmf_forge_ai_data.chunkers import RecursiveChunker
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, Document, SearchResult
from gmf_forge_ai_data.indexing import AzureAISearchIndexBuilder
from gmf_forge_ai_data.retrieval import RetrievalQuery, VectorRetriever
# ── Environment ───────────────────────────────────────────────────────────────
# Load settings from the .env file that sits next to this script. override=True is
# passed to load_dotenv (in python-dotenv this lets values from the file replace
# variables that are already set in the process environment).
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path, override=True)
# Three directory levels above this script's folder; the corporate CA bundle is
# looked up under certs/ in that directory.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# Module-wide logger used by every helper and example below.
logger = BasicLogger(__name__)
# ── Config helpers (same pattern as all other examples) ──────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the ``CORPORATE_CERT`` path as a string, or None when that file does not exist."""
    if not CORPORATE_CERT.exists():
        return None
    return str(CORPORATE_CERT)
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build the Azure OpenAI embedding client from environment variables.

    Reads AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY and
    AZURE_OPENAI_EMBEDDING_MODEL (all required) plus
    AZURE_OPENAI_EMBEDDING_MODEL_VERSION (defaults to "2024-02-01").

    Returns:
        The configured embedder, or None (after logging a warning) when any
        required variable is missing or empty.
    """
    env = os.getenv
    endpoint = env("AZURE_OPENAI_ENDPOINT")
    api_key = env("AZURE_OPENAI_API_KEY")
    deployment = env("AZURE_OPENAI_EMBEDDING_MODEL")
    api_version = env("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    if not (endpoint and api_key and deployment):
        logger.warning("Azure OpenAI embedding credentials not found in .env")
        return None

    # The CA bundle is optional: None is passed through when the file is absent.
    cert_path = get_ssl_cert_path()
    if cert_path:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    client = AzureOpenAIEmbeddings(
        endpoint=endpoint,
        api_key=api_key,
        deployment_name=deployment,
        api_version=api_version,
        ssl_cert_path=cert_path,
    )
    logger.info("Initialized embeddings", deployment=deployment)
    return client
def get_vector_store(index_name: str) -> Optional[AzureAISearchVectorStore]:
    """Create a vector store client for ``index_name`` from environment settings.

    Reads AZURE_AI_SEARCH_ENDPOINT and AZURE_AI_SEARCH_API_KEY (required) and
    AZURE_AI_SEARCH_EMBEDDING_DIMENSION (defaults to "1536").

    Args:
        index_name: Name of the Azure AI Search index to use.

    Returns:
        The store, or None (after logging a warning) when the endpoint or the
        API key is missing or empty.

    Raises:
        ValueError: If AZURE_AI_SEARCH_EMBEDDING_DIMENSION is not an integer string.
    """
    search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
    search_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
    # Parsed before the credential check, so a malformed value raises either way.
    dimension = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    if not (search_endpoint and search_key):
        logger.warning("Azure Search credentials not found in .env")
        return None

    vector_store = AzureAISearchVectorStore(
        endpoint=search_endpoint,
        index_name=index_name,
        api_key=search_key,
        embedding_dimension=dimension,
        document_type=Document,
    )
    logger.info("Connected to Azure Search index", index=index_name)
    return vector_store
def log_docs(docs: List[Document], label: str = "", max_content: int = 90):
    """Log one entry per document: file name, extension, size and a content preview.

    Args:
        docs: Documents to report.
        label: Optional heading; when non-empty it is logged first with the document count.
        max_content: Number of leading content characters used for the preview.
    """
    if label:
        logger.info(label, count=len(docs))

    for document in docs:
        meta = document.metadata
        text = document.content
        # Collapse whitespace runs (newlines included) so the preview stays on one line.
        snippet = " ".join(text[:max_content].split())
        if len(text) > max_content:
            snippet += "..."
        logger.info(
            "document",
            file=meta.get("file_name", document.id),
            extension=meta.get("extension", ""),
            size_bytes=meta.get("size_bytes", 0),
            preview=snippet,
        )
def log_results(results: List[SearchResult], label: str = ""):
    """Log rank, rounded score, source file and a 90-character preview for each result.

    Args:
        results: Search results to report.
        label: Optional heading; when non-empty it is logged first with the result count.
    """
    preview_chars = 90

    if label:
        logger.info(label, count=len(results))

    for hit in results:
        document = hit.document
        # Whitespace is collapsed so multi-line content yields a single-line preview.
        snippet = " ".join(document.content[:preview_chars].split())
        if len(document.content) > preview_chars:
            snippet += "..."
        logger.info(
            "result",
            rank=hit.rank,
            score=round(hit.score, 4),
            file=document.metadata.get("file_name", document.id),
            preview=snippet,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Filesystem Connector
# ══════════════════════════════════════════════════════════════════════════════
def example_filesystem_connector() -> List[Document]:
    """Load top-level Markdown files from ``WORKSPACE_ROOT`` and log them.

    Returns:
        The first three documents returned by the connector.
    """
    logger.info("Example 1: Filesystem Connector")
    logger.info(
        "Scenario: Knowledge base on disk — local docs, wikis, source code, text files. "
        "FilesystemConnector scans the tree and yields Documents."
    )

    # The monorepo's own .md files serve as a real content source.
    scan_root = WORKSPACE_ROOT
    fs_connector = FilesystemConnector(
        root_path=scan_root,
        extensions=[".md"],
        recursive=False,  # only files directly under the root
    )
    logger.info(
        "Scanning filesystem",
        root_path=str(scan_root),
        extensions=[".md"],
        recursive=False,
    )

    loaded = fs_connector.load()[:3]
    log_docs(loaded, "Loaded documents")

    # Dump the complete metadata of the first document as log fields.
    if loaded:
        logger.info("First document metadata", **loaded[0].metadata)

    logger.info("FilesystemConnector complete — Documents have stable IDs and metadata. Pass to a chunker next.")
    return loaded
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: SharePoint Connector
#
# HOW TO GET YOUR SHAREPOINT_SITE_ID
# ──────────────────────────────────
# Option A — Graph Explorer (browser, no code)
# 1. Go to https://developer.microsoft.com/graph/graph-explorer
# 2. Sign in with your work account.
# 3. Run this GET request to search for your site by keyword:
#
# GET https://graph.microsoft.com/v1.0/sites?search=InnovationLab
#
# Or to list ALL sites the app has access to:
#
# GET https://graph.microsoft.com/v1.0/sites?search=*
#
# 4. The response JSON contains:
# "id": "gmfinancial.sharepoint.com,<guid1>,<guid2>"
# That full string is your SHAREPOINT_SITE_ID.
#
# Option B — PowerShell (one-liner)
# $token = (Invoke-RestMethod -Uri "https://login.microsoftonline.com/<tenant>/oauth2/v2.0/token" `
# -Method Post -Body @{ grant_type="client_credentials"; client_id="<id>"; `
# client_secret="<secret>"; scope="https://graph.microsoft.com/.default" }).access_token
# Invoke-RestMethod -Uri "https://graph.microsoft.com/v1.0/sites?search=InnovationLab" `
# -Headers @{ Authorization="Bearer $token" } | Select-Object id
#
# FOLDER PATH NOTE
# ────────────────
# SHAREPOINT_FOLDER_PATH is a subfolder *inside* the default document library.
# Use "/" to load everything, or "/Policies" for a specific folder.
# Do NOT use "Shared Documents" — that is the library name, not a subfolder.
# ══════════════════════════════════════════════════════════════════════════════
def example_sharepoint_connector() -> Optional[List[Document]]:
    """Load up to three documents with ``SharePointConnector`` and log them.

    Returns:
        The loaded documents, or None when a required SHAREPOINT_* variable is
        missing or empty, or when ``connector.load()`` raises.
    """
    logger.info("Example 2: SharePoint Connector")
    logger.info(
        "Scenario: Knowledge in SharePoint — policies, reports, wiki pages. "
        "SharePointConnector authenticates via OAuth2 client credentials and downloads via Microsoft Graph API."
    )

    credentials = {
        "tenant_id": os.getenv("SHAREPOINT_TENANT_ID"),
        "client_id": os.getenv("SHAREPOINT_CLIENT_ID"),
        "client_secret": os.getenv("SHAREPOINT_CLIENT_SECRET"),
        "site_id": os.getenv("SHAREPOINT_SITE_ID"),
    }
    # The default targets one shallow folder so the example stays quick;
    # SHAREPOINT_FOLDER_PATH in .env overrides it.
    folder_path = os.getenv("SHAREPOINT_FOLDER_PATH", "/General/Document Library/Copilot Lite Agent")

    if not all(credentials.values()):
        logger.warning(
            "SharePoint credentials not found in .env - skipping live demo",
            required_vars=[
                "SHAREPOINT_TENANT_ID",
                "SHAREPOINT_CLIENT_ID",
                "SHAREPOINT_CLIENT_SECRET",
                "SHAREPOINT_SITE_ID",
            ],
            graph_permissions=["Sites.Read.All", "Files.Read.All"],
        )
        return None

    cert_path = get_ssl_cert_path()
    sp_connector = SharePointConnector(
        **credentials,
        folder_path=folder_path,
        ssl_cert_path=cert_path,
    )
    logger.info("Loading from SharePoint", site_id=credentials["site_id"], folder_path=folder_path)

    try:
        loaded = sp_connector.load()[:3]
    except Exception as exc:
        # Any failure is reported with likely causes and the demo is skipped.
        likely_causes = [
            "401 Unauthorized: app registration lacks Sites.Read.All / Files.Read.All or admin consent not granted",
            "404 Not Found: SHAREPOINT_SITE_ID or SHAREPOINT_FOLDER_PATH is wrong",
            "SSL error: set ssl_cert_path to your corporate CA bundle",
        ]
        logger.error(
            "SharePoint load failed - skipping live demo",
            error=str(exc),
            common_causes=likely_causes,
        )
        return None

    log_docs(loaded, "Loaded documents")
    logger.info("SharePointConnector complete — files downloaded and decoded via Graph API", count=len(loaded))
    return loaded
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Blob Storage Connector
# ══════════════════════════════════════════════════════════════════════════════
def example_blob_connector() -> Optional[List[Document]]:
    """Load up to three documents with ``BlobStorageConnector`` and log them.

    Returns:
        The loaded documents, or None when a required storage variable is
        missing or empty, or when ``connector.load()`` raises.
    """
    logger.info("Example 3: Blob Storage Connector")
    logger.info(
        "Scenario: Azure Blob Storage as landing zone for reports, PDFs, or text files. "
        "BlobStorageConnector lists and downloads blobs via the azure-storage-blob SDK."
    )

    account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
    access_key = os.getenv("AZURE_STORAGE_ACCESS_KEY")
    container_name = os.getenv("AZURE_BLOB_CONTAINER_NAME")
    # A prefix keeps the demo quick: without one, load() lists and downloads every
    # blob in the container before the [:3] slice can trim the result.
    # AZURE_STORAGE_PREFIX in .env selects a different prefix.
    prefix = os.getenv("AZURE_STORAGE_PREFIX", "Anti")

    required_values = (account_name, access_key, container_name)
    if not all(required_values):
        logger.warning(
            "Azure Blob Storage credentials not found in .env - skipping live demo",
            required_vars=[
                "AZURE_STORAGE_ACCOUNT_NAME",
                "AZURE_STORAGE_ACCESS_KEY",
                "AZURE_BLOB_CONTAINER_NAME",
            ],
            optional_vars=["AZURE_STORAGE_PREFIX (folder prefix)"],
            hint="Install SDK if needed: pip install azure-storage-blob",
        )
        return None

    cert_path = get_ssl_cert_path()
    blob_connector = BlobStorageConnector(
        account_name=account_name,
        access_key=access_key,
        container_name=container_name,
        prefix=prefix,
        ssl_cert_path=cert_path,
    )

    display_name = account_name or "(connection string)"
    logger.info(
        "Loading from Blob Storage",
        account=display_name,
        container=container_name,
        prefix=prefix or "(entire container)",
    )

    try:
        loaded = blob_connector.load()[:3]
    except Exception as exc:
        # Any failure is reported with likely causes and the demo is skipped.
        likely_causes = [
            "AuthenticationFailed: AZURE_STORAGE_ACCESS_KEY is invalid or expired",
            "ResourceNotFound: AZURE_BLOB_CONTAINER_NAME does not exist",
            "SSL error: set ssl_cert_path to your corporate CA bundle",
        ]
        logger.error(
            "Blob Storage load failed - skipping live demo",
            error=str(exc),
            common_causes=likely_causes,
        )
        return None

    log_docs(loaded, "Loaded documents")
    logger.info("BlobStorageConnector complete — blobs streamed and text extracted", count=len(loaded))
    return loaded
# ══════════════════════════════════════════════════════════════════════════════
# Full Pipeline: load → chunk → embed → index → retrieve
# ══════════════════════════════════════════════════════════════════════════════
def example_full_pipeline(
    embedder: AzureOpenAIEmbeddings,
    store: AzureAISearchVectorStore,
    pre_loaded_docs: Optional[List[Document]] = None,
):
    """
    Run the ingest path once and query the result, tracing every stage.

    Stages, each inside its own span of the "connector_pipeline" trace:
        1. loading   - first three of ``pre_loaded_docs``, or of a FilesystemConnector scan
        2. chunking  - RecursiveChunker (size 300, overlap 50) over each document
        3. embedding - ``embedder.embed_text`` on every chunk
        4. indexing  - ``store.add_documents`` with the precomputed embeddings
        5. retrieval - VectorRetriever, top 5 hits for a fixed query

    Args:
        embedder: Client used to embed the chunks and the query.
        store: Vector store that receives the chunks and serves the query.
        pre_loaded_docs: Documents to use instead of scanning the filesystem.
            Any non-None value, including an empty list, disables the scan.

    Returns:
        None. Returns early, after a warning, when no documents are available.
    """
    logger.info("Full Pipeline: Load -> Chunk -> Embed -> Index -> Retrieve")
    logger.info("Scenario: End-to-end ingest path — from raw source files to a searchable vector index.")

    query = "How is this repository structured and what packages does it contain?"
    tracer = get_tracer()

    with tracer.trace("connector_pipeline", input=query) as trace:
        # Step 1: Load
        with trace.span("loading") as span:
            if pre_loaded_docs is None:
                logger.info("Step 1 - Load: FilesystemConnector scanning monorepo .md files")
                fs_connector = FilesystemConnector(
                    root_path=WORKSPACE_ROOT,
                    extensions=[".md"],
                    recursive=False,
                )
                raw_docs = fs_connector.load()[:3]
                logger.info("Step 1 complete", docs_loaded=len(raw_docs), source=WORKSPACE_ROOT.name)
            else:
                raw_docs = pre_loaded_docs[:3]
                # Heuristic: an https:// "source" metadata value is taken to mean SharePoint.
                has_remote_source = any(
                    d.metadata.get("source", "").startswith("https://") for d in raw_docs
                )
                source = "sharepoint" if has_remote_source else "filesystem"
                logger.info("Step 1 - Load: using pre-loaded sample docs", source=source, docs_loaded=len(raw_docs))
            span.set_output({"docs_loaded": len(raw_docs)})

        if not raw_docs:
            logger.warning("No .md files found - skipping pipeline.")
            trace.set_output({"docs_loaded": 0, "skipped": True})
            return

        # Step 2: Chunk
        with trace.span("chunking") as span:
            logger.info("Step 2 - Chunk: RecursiveChunker", chunk_size=300, overlap=50)
            splitter = RecursiveChunker(chunk_size=300, chunk_overlap=50)
            all_chunks: List[Document] = []
            for parent in raw_docs:
                pieces = splitter.chunk(parent.content)
                # Each chunk keeps its parent's metadata plus its position among the siblings.
                all_chunks.extend(
                    Document(
                        id=f"{parent.id}_chunk_{position}",
                        content=piece.text,
                        metadata={
                            **parent.metadata,
                            "chunk_index": position,
                            "chunk_count": len(pieces),
                        },
                    )
                    for position, piece in enumerate(pieces)
                )
            logger.info("Step 2 complete", docs=len(raw_docs), chunks=len(all_chunks))
            span.set_output({"docs": len(raw_docs), "chunks": len(all_chunks)})

        # Step 3: Embed
        with trace.span("embedding") as span:
            logger.info("Step 3 - Embed: generating embeddings for all chunks")
            for chunk_doc in all_chunks:
                chunk_doc.embedding = embedder.embed_text(chunk_doc.content)
            logger.info("Step 3 complete", embeddings_generated=len(all_chunks))
            span.set_output({"embeddings": len(all_chunks)})

        # Step 4: Index
        with trace.span("indexing") as span:
            logger.info("Step 4 - Index: adding chunks to Azure Search", index=store.index_name)
            # Embeddings were attached in step 3, so the store is told not to generate them.
            store.add_documents(all_chunks, generate_embeddings=False)
            logger.info("Step 4 complete", chunks_indexed=len(all_chunks), index=store.index_name)
            span.set_output({"chunks_indexed": len(all_chunks)})

        # Step 5: Retrieve
        with trace.span("retrieval", input=query) as span:
            logger.info("Step 5 - Retrieve", query=query)
            query_embedding = embedder.embed_text(query)
            retriever = VectorRetriever(store)
            results = retriever.retrieve(RetrievalQuery(embedding=query_embedding, top_k=5))
            log_results(results, "Top results")
            top_score = round(results[0].score, 4) if results else 0
            span.set_output({"results": len(results), "top_score": top_score})

        trace.set_output(
            {
                "docs_loaded": len(raw_docs),
                "chunks_indexed": len(all_chunks),
                "query_results": len(results),
            }
        )

    logger.info(
        "Pipeline complete",
        files=len(raw_docs),
        chunks=len(all_chunks),
        results=len(results),
        note="Passages are ready to feed into a RAG LLM prompt",
    )
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main():
    """Run the three connector demos, then the full ingest pipeline.

    The pipeline uses up to three SharePoint documents when that demo returned
    any, otherwise up to three filesystem documents. Exits with status 1 when
    the embedder or the vector store cannot be configured.
    """
    logger.info("Connector Examples (Module 6) — 3 data source connectors + full ingest pipeline")

    # Individual connector demos, separated by a visual divider in the log.
    logger.info("********************")
    fs_docs = example_filesystem_connector()
    logger.info("********************")
    sp_docs = example_sharepoint_connector()
    logger.info("********************")
    example_blob_connector()
    logger.info("********************")

    # Prefer SharePoint documents; fall back to filesystem documents when the
    # SharePoint demo was skipped or returned nothing.
    sharepoint_sample = (sp_docs or [])[:3]
    pipeline_docs = sharepoint_sample if sharepoint_sample else (fs_docs or [])[:3]
    logger.info(
        "Setting up full pipeline (requires Azure credentials)",
        source="sharepoint" if sp_docs else "filesystem",
        sample_docs=len(pipeline_docs),
    )

    # The pipeline needs both an embedder and a vector store.
    embedder = get_embedder()
    if not embedder:
        logger.error(
            "Embedder required for full pipeline",
            required_env=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_EMBEDDING_MODEL"],
        )
        sys.exit(1)

    index_name = "connector_docs"
    store = get_vector_store(index_name)
    if not store:
        logger.error(
            "Azure Search index required for full pipeline",
            required_env=["AZURE_AI_SEARCH_ENDPOINT", "AZURE_AI_SEARCH_API_KEY"],
        )
        sys.exit(1)

    # Create the target index before the pipeline writes to it.
    index_builder = AzureAISearchIndexBuilder(
        endpoint=os.getenv("AZURE_AI_SEARCH_ENDPOINT"),
        api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"),
        index_name=index_name,
        embedding_dimension=int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536")),
        ssl_cert_path=get_ssl_cert_path(),
    )
    index_builder.create_index()

    example_full_pipeline(embedder, store, pre_loaded_docs=pipeline_docs)

    logger.info("---")
    connector_notes = [
        "FilesystemConnector: scans local directory trees for text/PDF/DOCX files",
        "SharePointConnector: fetches files via Microsoft Graph API (OAuth2 app-only)",
        "BlobStorageConnector: lists and downloads blobs from Azure Blob Storage",
    ]
    logger.info(
        "Summary",
        connectors=connector_notes,
        recommended_pipeline="connector.load() -> chunker.chunk() -> embedder.embed() -> store.add_documents()",
    )
# Run the connector demos only when this file is executed as a script.
if __name__ == "__main__":
    main()
context_example.py
"""
Example: Context Processing Strategies (Module 5)
Demonstrates all 5 post-retrieval context management techniques:
1. RelevanceFilter — Drop results below a score threshold
2. ContextDeduplicator — Remove near-duplicate passages
3. ContextReranker — Reorder by LLM-assessed relevance
4. ContextCompressor — Extract only query-relevant sentences per chunk
5. ContextWindowManager — Fit everything into a token budget
How this connects to existing examples:
- Retrieves from the same 'ai_ml_knowledge' index used in retrieval_example.py
and query_example.py (run azure_ai_search_vector_store_example.py first)
- Uses the same embedder, Azure Search config, and AIMLDocument schema
- Designed to slot directly after retrieval in a RAG pipeline
Flow demonstrated:
retrieve (VectorRetriever)
→ 1. RelevanceFilter
→ 2. ContextDeduplicator
→ 3. ContextReranker (LLM)
→ 4. ContextCompressor (LLM)
→ 5. ContextWindowManager
→ ready for LLM prompt assembly
Prerequisites:
- Run azure_ai_search_vector_store_example.py first to populate 'ai_ml_knowledge'
- Azure AI Search + Azure OpenAI credentials in .env
(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT,
AZURE_OPENAI_EMBEDDING_MODEL)
"""
import asyncio
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from dotenv import load_dotenv
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, Document, SearchResult
from gmf_forge_ai_data.retrieval import RetrievalQuery, VectorRetriever
from gmf_forge_ai_data.context import (
RelevanceFilter,
ContextDeduplicator,
ContextReranker,
ContextCompressor,
ContextWindowManager,
WindowedContext,
)
from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway
from gmf_forge_ai_shared_core.llm_gateway.providers import AzureOpenAIProvider
from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
# ── Environment ─────────────────────────────────────────────────────────────────────────────
# Load settings from the .env file that sits next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Three directory levels above this script's folder; the corporate CA bundle is
# looked up under certs/ in that directory.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# Module-wide logger and metrics collector.
logger = BasicLogger(__name__)
metrics = BasicMetricsCollector()
# ── Document schema (same as retrieval_example.py / query_example.py) ────────
@dataclass
class AIMLDocument(Document):
    """Document subclass that adds ``topic`` and ``category`` string fields."""

    # Both extra fields default to the empty string.
    topic: str = ""
    category: str = ""
# ── Config helpers (same pattern as all other examples) ──────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the ``CORPORATE_CERT`` path as a string when the file exists, else None."""
    cert_present = CORPORATE_CERT.exists()
    if cert_present:
        return str(CORPORATE_CERT)
    return None
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Create the Azure OpenAI embedding client described by the environment.

    Required variables: AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY,
    AZURE_OPENAI_EMBEDDING_MODEL. Optional: AZURE_OPENAI_EMBEDDING_MODEL_VERSION
    (defaults to "2024-02-01").

    Returns:
        The embedder, or None (after logging a warning) when a required
        variable is missing or empty.
    """
    settings = {
        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "deployment_name": os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"),
    }
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    if not all(settings.values()):
        logger.warning("Azure OpenAI embedding credentials not found in .env")
        return None

    # The CA bundle is optional: None is passed through when the file is absent.
    cert_path = get_ssl_cert_path()
    if cert_path:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    client = AzureOpenAIEmbeddings(
        **settings,
        api_version=api_version,
        ssl_cert_path=cert_path,
    )
    logger.info("Initialized embeddings", deployment=settings["deployment_name"])
    return client
def get_llm_gateway() -> Optional[UnifiedLLMGateway]:
    """Build a ``UnifiedLLMGateway`` backed by an Azure OpenAI provider.

    Reads AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY and AZURE_OPENAI_DEPLOYMENT
    (all required) plus AZURE_OPENAI_API_VERSION (defaults to "2025-01-01-preview").

    Returns:
        The gateway, or None (after logging a warning) when any required
        variable is missing or empty.
    """
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
    api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")

    if not endpoint or not api_key or not deployment:
        # Warn the same way get_embedder() and get_vector_store() do, so a missing
        # LLM configuration shows up in the log instead of being a silent None.
        logger.warning("Azure OpenAI LLM credentials not found in .env")
        return None

    # The CA bundle is optional: None is passed through when the file is absent.
    ssl_cert = get_ssl_cert_path()
    provider = AzureOpenAIProvider(
        endpoint=endpoint,
        api_key=api_key,
        deployment_name=deployment,
        api_version=api_version,
        ssl_cert_path=ssl_cert,
    )
    gateway = UnifiedLLMGateway(default_provider=provider)
    logger.info("Initialized LLM gateway", deployment=deployment)
    return gateway
def get_vector_store(
    index_name: str = "ai_ml_knowledge",
) -> Optional[AzureAISearchVectorStore]:
    """Connect to an Azure AI Search index that stores ``AIMLDocument`` records.

    Reads AZURE_AI_SEARCH_ENDPOINT and AZURE_AI_SEARCH_API_KEY (required) and
    AZURE_AI_SEARCH_EMBEDDING_DIMENSION (defaults to "1536").

    Args:
        index_name: Index to connect to.

    Returns:
        The store, or None (after logging a warning) when credentials are
        missing, or when constructing the store or counting its documents raises.

    Raises:
        ValueError: If AZURE_AI_SEARCH_EMBEDDING_DIMENSION is not an integer string.
    """
    search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
    search_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
    # Parsed outside the try block, so a malformed value is not swallowed below.
    dimension = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    if not (search_endpoint and search_key):
        logger.warning("Azure Search credentials not found in .env")
        return None

    try:
        vector_store = AzureAISearchVectorStore(
            endpoint=search_endpoint,
            index_name=index_name,
            api_key=search_key,
            embedding_dimension=dimension,
            document_type=AIMLDocument,
        )
        # count() is called inside the try, so a failure there is reported
        # through the same warning as a failed construction.
        logger.info(
            "Connected to Azure Search index",
            index=index_name,
            documents=vector_store.count(),
        )
    except Exception as exc:
        logger.warning("Could not connect to index", index=index_name, error=str(exc))
        return None
    return vector_store
def retrieve(
    query: str,
    embedder: AzureOpenAIEmbeddings,
    store: AzureAISearchVectorStore,
    top_k: int = 8,
) -> List[SearchResult]:
    """Embed ``query`` and return what ``VectorRetriever`` finds for it.

    Args:
        query: Text to search for.
        embedder: Client used to embed the query text.
        store: Vector store to search.
        top_k: Value passed as ``top_k`` in the retrieval query.

    Returns:
        The results returned by the retriever, unmodified.
    """
    query_vector = embedder.embed_text(query)
    vector_retriever = VectorRetriever(store)
    request = RetrievalQuery(embedding=query_vector, top_k=top_k)
    return vector_retriever.retrieve(request)
def log_results(results: List[SearchResult], label: str = "", max_content: int = 100):
    """Log rank, rounded score, topic and a content preview for each result.

    Args:
        results: Search results to report.
        label: Optional heading; when non-empty it is logged first with the result count.
        max_content: Number of leading content characters used for the preview.
    """
    if label:
        logger.info(label, count=len(results))

    for hit in results:
        text = hit.document.content
        # Newlines become spaces so each preview stays on one line.
        snippet = text[:max_content].replace("\n", " ")
        if len(text) > max_content:
            snippet += "..."
        logger.info(
            "result",
            rank=hit.rank,
            score=round(hit.score, 4),
            # "?" stands in for documents that have no topic attribute.
            topic=getattr(hit.document, "topic", "?"),
            preview=snippet,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Relevance Filter
# ══════════════════════════════════════════════════════════════════════════════
def example_relevance_filter(results: List[SearchResult]):
    """Apply ``RelevanceFilter`` at two score floors and log what each one keeps.

    Args:
        results: Retrieved results to filter; the list itself is only read here.
    """
    logger.info("Example 1: Relevance Filter")
    logger.info(
        "Scenario: Vector search top-k results may include weakly related docs. "
        "Drop anything below a score floor before passing to LLM."
    )

    log_results(results, "Before filtering")

    # Two floors, the second stricter than the first, to show the effect.
    score_floors = [0.82, 0.88]
    for floor in score_floors:
        survivors = RelevanceFilter(min_score=floor).filter(results)
        logger.info(
            "RelevanceFilter applied",
            min_score=floor,
            kept=len(survivors),
            dropped=len(results) - len(survivors),
        )
        for hit in survivors:
            logger.info(
                "kept result",
                rank=hit.rank,
                score=round(hit.score, 4),
                topic=getattr(hit.document, "topic", "?"),
            )

    logger.info("RelevanceFilter complete — removes weak matches with zero LLM calls")
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: Context Deduplicator
# ══════════════════════════════════════════════════════════════════════════════
def example_deduplicator(results: List[SearchResult]):
    """Remove near-duplicate passages using character trigram Jaccard.

    A copy of the first result (same content, id suffixed with "_dup", score
    lowered by 0.01) is appended to the input so the deduplicator has a
    duplicate to remove. The demo is skipped with a warning when ``results``
    is empty, because there is nothing to copy.

    Args:
        results: Retrieved results; the list itself is not modified.
    """
    logger.info("Example 2: Context Deduplicator")
    logger.info(
        "Scenario: Overlapping chunks can return near-identical passages wasting context tokens. "
        "Remove duplicates by n-gram similarity before the LLM."
    )

    # Guard: results[0] below would raise IndexError on an empty list.
    if not results:
        logger.warning("No retrieval results available - skipping deduplicator demo")
        return

    # One threshold for the transparency check, the log fields and the real run.
    threshold = 0.85

    # Inject a synthetic near-duplicate of the first result.
    original = results[0]
    duplicate = SearchResult(
        # Rebuilt with the original's own document class; assumes that class
        # accepts topic/category keyword arguments, as AIMLDocument does.
        document=type(original.document)(
            id=original.document.id + "_dup",
            content=original.document.content,
            topic=getattr(original.document, "topic", ""),
            category=getattr(original.document, "category", ""),
        ),
        score=original.score - 0.01,
        rank=len(results),
    )
    results_with_dup = results + [duplicate]

    # Compute and log the Jaccard score so the threshold is transparent.
    # NOTE(review): _fingerprint and _jaccard are private helpers of
    # ContextDeduplicator; this demo depends on them staying available.
    deduper_check = ContextDeduplicator(similarity_threshold=threshold)
    fp_orig = deduper_check._fingerprint(original.document.content)
    fp_dup = deduper_check._fingerprint(duplicate.document.content)
    jaccard = deduper_check._jaccard(fp_orig, fp_dup)
    logger.info(
        "Duplicate injected",
        topic=getattr(original.document, "topic", "?"),
        jaccard=round(jaccard, 4),
        threshold=threshold,
        input_count=len(results_with_dup),
    )

    # A separate instance performs the actual deduplication, as before.
    deduper = ContextDeduplicator(similarity_threshold=threshold)
    unique = deduper.deduplicate(results_with_dup)
    removed = len(results_with_dup) - len(unique)
    logger.info("ContextDeduplicator applied", threshold=threshold, kept=len(unique), removed=removed)
    log_results(unique, "After deduplication")

    logger.info("ContextDeduplicator complete — removes redundant passages with zero LLM calls")
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Context Reranker
# ══════════════════════════════════════════════════════════════════════════════
async def example_reranker(
    query: str,
    results: List[SearchResult],
    gateway: UnifiedLLMGateway,
):
    """Rerank the first five results with ``ContextReranker`` and log position changes.

    Args:
        query: The user query the results were retrieved for.
        results: Retrieved results in vector-similarity order.
        gateway: LLM gateway handed to the reranker.

    Returns:
        The reranked results as returned by ``ContextReranker.rerank``.
    """
    logger.info("Example 3: Context Reranker")
    logger.info(
        "Scenario: Vector similarity scores are tightly clustered and don't show true relevance. "
        "Ask the LLM to rank retrieved docs by true relevance."
    )
    logger.info("Query", query=query)

    top_five = results[:5]
    log_results(top_five, "Before reranking (vector order)")

    llm_reranker = ContextReranker(gateway, temperature=0.0)
    # The reranker receives its own copy of the candidate list.
    reranked = await llm_reranker.rerank(query, list(top_five), top_k=5)
    log_results(reranked, "After reranking (LLM order)")

    # Report every document whose position differs from the vector order.
    original_ids = [hit.document.id for hit in top_five]
    reranked_ids = [hit.document.id for hit in reranked]
    if original_ids == reranked_ids:
        logger.info("LLM agreed with vector similarity ranking for this query")
    else:
        for new_pos, hit in enumerate(reranked):
            old_pos = original_ids.index(hit.document.id)
            topic = getattr(hit.document, "topic", "?")
            if old_pos != new_pos:
                logger.info(
                    "LLM reranking changed position",
                    topic=topic,
                    old_pos=old_pos,
                    new_pos=new_pos,
                )

    logger.info("ContextReranker complete — improves result ordering using language understanding")
    return reranked
# ══════════════════════════════════════════════════════════════════════════════
# Example 4: Context Compressor
# ══════════════════════════════════════════════════════════════════════════════
async def example_compressor(
    query: str,
    results: List[SearchResult],
    gateway: UnifiedLLMGateway,
):
    """Compress the first three results with ``ContextCompressor`` and log the savings.

    Token counts in the log are estimates: content length divided by four,
    rounded, with a floor of one.

    Args:
        query: The user query the results were retrieved for.
        results: Retrieved results; only the first three are compressed.
        gateway: LLM gateway handed to the compressor.

    Returns:
        The compressed results as returned by ``ContextCompressor.compress``.
    """

    def _estimate_tokens(text: str) -> int:
        # Rough estimate used only for logging: four characters per token, at least 1.
        return max(1, round(len(text) / 4))

    logger.info("Example 4: Context Compressor")
    logger.info(
        "Scenario: Chunks contain background sentences that don't help answer "
        "the query, wasting context window and diluting model attention. "
        "Per-chunk LLM extraction keeps only relevant sentences."
    )
    logger.info("Query", query=query)

    # Compression is demonstrated on the first 3 results only.
    sample = results[:3]
    llm_compressor = ContextCompressor(gateway, temperature=0.0)
    compressed = await llm_compressor.compress(query, sample)

    for before, after in zip(sample, compressed):
        tokens_before = _estimate_tokens(before.document.content)
        tokens_after = _estimate_tokens(after.document.content)
        # tokens_before is at least 1, so the division is always defined.
        saved_pct = round((1 - tokens_after / tokens_before) * 100)
        logger.info(
            "Compressed chunk",
            topic=getattr(before.document, "topic", "?"),
            tokens_before=tokens_before,
            tokens_after=tokens_after,
            reduction_pct=saved_pct,
            compressed=after.document.content,
        )

    logger.info("ContextCompressor complete - reduces token usage and focuses LLM on evidence")
    return compressed
# ══════════════════════════════════════════════════════════════════════════════
# Example 5: Context Window Manager
# ══════════════════════════════════════════════════════════════════════════════
def example_window_manager(results: List[SearchResult]):
    """Fit the results into three token budgets and log what each window keeps.

    The budgets are 25%, 50% and 200% of the estimated total token count of
    ``results``, each rounded to a multiple of 25 (minimum 25).

    Args:
        results: Search results handed to ``ContextWindowManager.fit`` for every budget.
    """
    logger.info("Example 5: Context Window Manager")
    logger.info(
        "Scenario: Even after filtering and compression, total context may exceed the token budget. "
        "Keep as many full docs as fit; optionally truncate the last."
    )

    def _approx_tokens(text: str) -> int:
        # Rough estimate: ~4 characters per token, never below 1.
        return max(1, round(len(text) / 4))

    def _round25(n: float) -> int:
        # Snap to the nearest multiple of 25, but never go below 25.
        return max(25, round(n / 25) * 25)

    total_tokens = sum(_approx_tokens(item.document.content) for item in results)
    budgets = [_round25(total_tokens * fraction) for fraction in (0.25, 0.50, 2.00)]
    logger.info("Token budget test", total_tokens=total_tokens, budgets=budgets)

    for budget in budgets:
        window = ContextWindowManager(max_tokens=budget, allow_truncation=True).fit(results)

        # When the window reports truncation, the last kept result is flagged as the truncated one.
        truncated_rank = window.results[-1].rank if window.truncated else None
        fitted = [
            {
                "rank": kept.rank,
                "topic": getattr(kept.document, "topic", "?"),
                "tokens": _approx_tokens(kept.document.content),
                "truncated": window.truncated and kept.rank == truncated_rank,
            }
            for kept in window.results
        ]

        logger.info(
            "ContextWindowManager applied",
            max_tokens=budget,
            included=len(window.results),
            dropped=window.dropped,
            truncated=window.truncated,
            tokens_used=window.total_tokens,
            budget=window.budget,
            utilization_pct=round(window.total_tokens / window.budget * 100),
            docs=fitted,
        )

    logger.info("ContextWindowManager complete — prevents context overflow before prompt assembly")
# ══════════════════════════════════════════════════════════════════════════════
# Pipeline Demo: Full post-retrieval context pipeline
# ══════════════════════════════════════════════════════════════════════════════
async def example_full_pipeline(
    query: str,
    embedder: AzureOpenAIEmbeddings,
    store: AzureAISearchVectorStore,
    gateway: UnifiedLLMGateway,
):
    """
    Run every post-retrieval stage in sequence under a single trace.

    Stages, in order:
        1. retrieve              (top_k=8)
        2. RelevanceFilter       (min_score=0.82)
        3. ContextDeduplicator   (similarity_threshold=0.85)
        4. ContextReranker       (top_k=5, temperature=0.0)
        5. ContextCompressor     (temperature=0.0)
        6. ContextWindowManager  (max_tokens=1500, truncation allowed)

    Each stage runs inside its own span, records its output on that span,
    logs a one-line summary and increments ``pipeline.steps_run``. The final
    passages are logged individually; nothing is returned.
    """
    logger.info("Pipeline Demo: Full Post-Retrieval Context Pipeline")
    logger.info("Query", query=query)

    def _approx_tokens(items) -> int:
        # Rough estimate: ~4 characters per token, at least 1 per document.
        return sum(max(1, round(len(item.document.content) / 4)) for item in items)

    tracer = get_tracer()
    with tracer.trace("context_pipeline", input=query) as trace:
        # Step 1: Retrieve
        with trace.span("retrieval", input=query) as span:
            raw = retrieve(query, embedder, store, top_k=8)
            span.set_output({"doc_count": len(raw)})
        logger.info("Step 1 - Retrieve", results=len(raw), index="ai_ml_knowledge")
        metrics.increment("pipeline.steps_run")

        # Step 2: Relevance filter
        with trace.span("relevance_filter") as span:
            filtered = RelevanceFilter(min_score=0.82).filter(raw)
            span.set_output({"kept": len(filtered), "dropped": len(raw) - len(filtered)})
        logger.info(
            "Step 2 - Filter",
            kept=len(filtered),
            dropped=len(raw) - len(filtered),
            min_score=0.82,
        )
        metrics.increment("pipeline.steps_run")

        # Step 3: Deduplicate
        with trace.span("deduplication") as span:
            unique = ContextDeduplicator(similarity_threshold=0.85).deduplicate(filtered)
            span.set_output({"kept": len(unique), "removed": len(filtered) - len(unique)})
        logger.info(
            "Step 3 - Deduplicate",
            kept=len(unique),
            removed=len(filtered) - len(unique),
        )
        metrics.increment("pipeline.steps_run")

        # Step 4: Rerank
        with trace.span("reranker") as span:
            reranked = await ContextReranker(gateway, temperature=0.0).rerank(query, unique, top_k=5)
            span.set_output({"top_k": len(reranked)})
        logger.info("Step 4 - Rerank", top_k=len(reranked))
        metrics.increment("pipeline.steps_run")

        # Step 5: Compress
        with trace.span("compressor") as span:
            compressed = await ContextCompressor(gateway, temperature=0.0).compress(query, reranked)
            total_before = _approx_tokens(reranked)
            total_after = _approx_tokens(compressed)
            span.set_output({"tokens_before": total_before, "tokens_after": total_after})
        logger.info(
            "Step 5 - Compress",
            tokens_before=total_before,
            tokens_after=total_after,
            # total_before is 0 when nothing was reranked; max() avoids dividing by zero.
            reduction_pct=round((1 - total_after / max(total_before, 1)) * 100),
        )
        metrics.increment("pipeline.steps_run")

        # Step 6: Window
        with trace.span("window_manager") as span:
            window = ContextWindowManager(max_tokens=1500, allow_truncation=True).fit(compressed)
            span.set_output({"docs": len(window.results), "tokens_used": window.total_tokens})
        logger.info(
            "Step 6 - Window",
            docs=len(window.results),
            tokens_used=window.total_tokens,
            budget=window.budget,
            truncated=window.truncated,
        )
        metrics.increment("pipeline.steps_run")
        metrics.gauge("pipeline.final_passages", len(window.results))
        metrics.gauge("pipeline.tokens_used", window.total_tokens)

        trace.set_output({"passages": len(window.results), "tokens": window.total_tokens})

    for position, passage in enumerate(window.results, 1):
        logger.info(
            "final context passage",
            index=position,
            topic=getattr(passage.document, "topic", "?"),
            score=round(passage.score, 4),
            preview=passage.document.content[:150],
        )

    logger.info(
        "Context pipeline complete",
        retrieved=len(raw),
        focused_passages=len(window.results),
        tokens=window.total_tokens,
        note="Ready to insert into your LLM prompt template",
    )
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main() -> None:
    """Synchronous entry point: run the ``_main`` coroutine via ``asyncio.run``."""
    asyncio.run(_main())
async def _main():
    """Set up the gateway, embedder and store, then run every example in order.

    Exits the process with status 1 (via ``sys.exit``) when the LLM gateway,
    the embedder, or the vector store cannot be obtained. Otherwise performs
    one shared retrieval, feeds it to the individual examples, runs the full
    pipeline demo, and finishes by logging a summary plus collected metrics.
    """
    logger.info("Context Processing Examples (Module 5) — 5 post-retrieval context management strategies")

    # LLM required for reranker and compressor
    gateway = get_llm_gateway()
    if not gateway:
        logger.error(
            "LLM gateway required (Examples 3, 4, and Pipeline)",
            required_env=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_DEPLOYMENT"],
        )
        sys.exit(1)

    embedder = get_embedder()
    if not embedder:
        logger.error("Embedding provider required for retrieval")
        sys.exit(1)

    store = get_vector_store("ai_ml_knowledge")
    if not store:
        logger.error("Could not connect to Azure Search index", hint="Run azure_ai_search_vector_store_example.py first")
        sys.exit(1)

    # One query and one base retrieval are shared by all individual examples.
    query = "How do neural networks learn and what role does backpropagation play?"
    logger.info("Shared query for all examples", query=query)
    base_hits = retrieve(query, embedder, store, top_k=8)
    metrics.gauge("pipeline.raw_result_count", len(base_hits))
    log_results(base_hits, "Raw retrieval results")

    # Individual examples; the compressor consumes the reranker's output.
    example_relevance_filter(base_hits)
    example_deduplicator(base_hits)
    reranked_hits = await example_reranker(query, base_hits, gateway)
    await example_compressor(query, reranked_hits, gateway)
    example_window_manager(base_hits)

    # The pipeline demo performs its own retrieval rather than reusing base_hits.
    await example_full_pipeline(query, embedder, store, gateway)

    component_notes = [
        "RelevanceFilter: drops weak results by score threshold (no LLM)",
        "ContextDeduplicator: removes near-duplicates by n-gram Jaccard (no LLM)",
        "ContextReranker: reorders by LLM relevance scoring (LLM, temp=0.0)",
        "ContextCompressor: extracts query-relevant sentences (LLM, temp=0.0)",
        "ContextWindowManager: fits results into a token budget (no LLM)",
    ]
    logger.info(
        "Summary",
        components=component_notes,
        recommended_pipeline="retrieve -> filter -> deduplicate -> rerank -> compress -> window -> LLM prompt",
    )

    snapshot = metrics.get_metrics()
    logger.info("Metrics Summary", counters=snapshot["counters"], gauges=snapshot["gauges"])
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
cosmosdb_vector_store_example.py
"""
Example: Azure Cosmos DB NoSQL API Vector Store (Module 7a)
Demonstrates how to use AzureCosmosDBVectorStore for vector search in a RAG pipeline:
1. Connect to a Cosmos DB NoSQL API account
2. Index documents with real Azure OpenAI embeddings
3. Run vector, keyword, and hybrid searches
4. Use metadata filters to narrow results
5. Fetch, update, and delete documents
Prerequisites:
- Azure Cosmos DB account (NoSQL API)
- Azure OpenAI embeddings credentials in .env
Required .env variables:
AZURE_COSMOS_ENDPOINT https://your-account.documents.azure.com:443/
AZURE_COSMOS_KEY your Cosmos DB account key
AZURE_COSMOS_DATABASE e.g. rag_db (default: rag_db)
AZURE_COSMOS_CONTAINER e.g. tech_articles (default: tech_articles)
AZURE_OPENAI_ENDPOINT
AZURE_OPENAI_API_KEY
AZURE_OPENAI_EMBEDDING_MODEL deployment name for your embedding model
AZURE_OPENAI_EMBEDDING_MODEL_VERSION (default: 2024-02-01)
Optional:
AZURE_AI_SEARCH_EMBEDDING_DIMENSION (default: 1536)
"""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
import nltk
from gmf_forge_ai_data.chunkers import SentenceChunker
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import Document, AzureCosmosDBVectorStore
from gmf_forge_ai_data.indexing import CosmosDBIndexBuilder
# ── Environment ────────────────────────────────────────────────────────────────
# Load variables from the .env file located next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
logger = BasicLogger(__name__)
# Four directory levels above this file; the certificate bundle is looked up
# under <WORKSPACE_ROOT>/certs and is only used when it exists on disk
# (see get_ssl_cert_path).
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# ── Document schema ────────────────────────────────────────────────────────────
@dataclass
class TechArticle(Document):
    """Technology article with domain-specific fields.

    Extends ``Document`` with three extra fields. All of them carry defaults,
    so the sample data can set them by keyword.
    """
    # Topic bucket, e.g. "deep_learning" or "nlp"; used as a search filter in Example 5.
    category: str = ""
    # Free-text author or team name.
    author: str = ""
    # Publication year; 0 is the default when none is given.
    published_year: int = 0
@dataclass
class SupportTicket(Document):
    """Customer support ticket with triage and routing fields.

    Extends ``Document``; in Example 7 one instance is created per ticket
    chunk, with these fields copied from the source ticket.
    """
    severity: str = ""  # "critical", "high", "medium", "low"
    product: str = ""   # affected product or service name
    status: str = ""    # "open", "in_progress", "resolved"
# ── Config helpers ─────────────────────────────────────────────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the certificate bundle path as a string, or None if the file is absent."""
    if CORPORATE_CERT.exists():
        return str(CORPORATE_CERT)
    return None
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build an ``AzureOpenAIEmbeddings`` client from environment variables.

    Reads ``AZURE_OPENAI_ENDPOINT``, ``AZURE_OPENAI_API_KEY`` and
    ``AZURE_OPENAI_EMBEDDING_MODEL`` (all required) plus
    ``AZURE_OPENAI_EMBEDDING_MODEL_VERSION`` (defaults to ``2024-02-01``).

    Returns:
        The configured embedder, or ``None`` when any required variable is
        unset or empty. In that case an error is logged naming only the
        variables that are actually missing.
    """
    required = {
        "AZURE_OPENAI_ENDPOINT": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "AZURE_OPENAI_API_KEY": os.getenv("AZURE_OPENAI_API_KEY"),
        "AZURE_OPENAI_EMBEDDING_MODEL": os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"),
    }
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    # Report only the variables that are really absent, so the log points at
    # the actual misconfiguration instead of listing all three every time.
    missing = [name for name, value in required.items() if not value]
    if missing:
        logger.error(
            "Azure OpenAI embedding credentials missing",
            missing_vars=missing,
        )
        return None

    deployment = required["AZURE_OPENAI_EMBEDDING_MODEL"]
    ssl_cert = get_ssl_cert_path()
    if ssl_cert:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    embedder = AzureOpenAIEmbeddings(
        endpoint=required["AZURE_OPENAI_ENDPOINT"],
        api_key=required["AZURE_OPENAI_API_KEY"],
        deployment_name=deployment,
        api_version=api_version,
        ssl_cert_path=ssl_cert,
    )
    logger.info("Initialized embeddings", deployment=deployment)
    return embedder
# ── Sample documents ───────────────────────────────────────────────────────────
def build_sample_documents(embedder: AzureOpenAIEmbeddings) -> list:
    """Create the eight sample TechArticle documents and attach embeddings to them.

    Args:
        embedder: Client whose ``embed_batch`` is called once with every
            article's content.

    Returns:
        The list of articles, each with ``embedding`` set from the batch result.
    """
    articles = [
        TechArticle(
            id="neural_networks_intro",
            content=(
                "Neural networks are computing systems loosely inspired by the biological "
                "neural networks that constitute animal brains. They consist of layers of "
                "interconnected nodes, or neurons, which process data and learn patterns. "
                "Deep neural networks have many hidden layers, enabling them to learn "
                "complex representations from raw input data. Modern architectures have "
                "achieved state-of-the-art results in vision, language, and speech tasks."
            ),
            category="deep_learning",
            author="AI Research Team",
            published_year=2023,
            metadata={"difficulty": "beginner", "tags": ["neural_networks", "AI"]},
        ),
        TechArticle(
            id="backpropagation_explained",
            content=(
                "Backpropagation is the primary algorithm used to train neural networks. "
                "It computes the gradient of the loss function with respect to every "
                "weight by applying the chain rule of calculus. The gradients are then "
                "used by an optimisation algorithm — such as stochastic gradient descent "
                "or Adam — to update the weights. Efficient GPU implementations make "
                "backpropagation tractable even for networks with billions of parameters."
            ),
            category="deep_learning",
            author="ML Theory Group",
            published_year=2023,
            metadata={"difficulty": "intermediate", "tags": ["backpropagation", "training"]},
        ),
        TechArticle(
            id="transformers_architecture",
            content=(
                "The Transformer architecture, introduced by Vaswani et al. in 2017, "
                "relies entirely on self-attention mechanisms instead of recurrent layers. "
                "Multi-head attention allows the model to attend to information from "
                "different representation subspaces simultaneously. Positional encodings "
                "inject sequence-order information. Transformers power large language "
                "models such as GPT-4, BERT, and T5."
            ),
            category="nlp",
            author="NLP Research Team",
            published_year=2022,
            metadata={"difficulty": "intermediate", "tags": ["transformers", "attention"]},
        ),
        TechArticle(
            id="vector_databases",
            content=(
                "Vector databases are purpose-built systems for storing and querying "
                "high-dimensional embedding vectors. They use approximate nearest "
                "neighbour (ANN) algorithms — such as HNSW, IVF, or PQ — to return "
                "semantically similar documents in milliseconds. Popular options include "
                "Pinecone, Weaviate, Qdrant, Azure AI Search, Cosmos DB, and MongoDB "
                "Atlas. They are essential infrastructure for RAG applications."
            ),
            category="infrastructure",
            author="Data Engineering Team",
            published_year=2024,
            metadata={"difficulty": "beginner", "tags": ["vector_db", "RAG"]},
        ),
        TechArticle(
            id="retrieval_augmented_generation",
            content=(
                "Retrieval Augmented Generation (RAG) combines a retrieval component with "
                "a generative language model. When a user poses a question, a retriever "
                "fetches relevant documents from a knowledge base; the generator then "
                "conditions its response on both the question and the retrieved context. "
                "RAG reduces hallucinations, supports domain-specific knowledge, and "
                "enables the model to cite sources."
            ),
            category="nlp",
            author="AI Research Team",
            published_year=2024,
            metadata={"difficulty": "intermediate", "tags": ["RAG", "LLM"]},
        ),
        TechArticle(
            id="convolutional_networks",
            content=(
                "Convolutional Neural Networks (CNNs) use learnable convolutional filters "
                "to extract local spatial features from grid-structured data like images. "
                "Pooling layers reduce spatial dimensions while retaining dominant features. "
                "Architectures such as ResNet, VGG, and EfficientNet have set benchmarks "
                "on ImageNet. CNNs are also applied to 1-D signals in speech processing "
                "and time-series analysis."
            ),
            category="computer_vision",
            author="Vision Research Team",
            published_year=2022,
            metadata={"difficulty": "intermediate", "tags": ["CNN", "computer_vision"]},
        ),
        TechArticle(
            id="reinforcement_learning",
            content=(
                "Reinforcement learning (RL) trains agents to maximise cumulative reward "
                "through interaction with an environment. An agent selects actions, "
                "observes state transitions, and receives reward signals. Deep RL combines "
                "neural networks with RL algorithms — exemplified by DQN, PPO, and SAC. "
                "AlphaGo and AlphaFold are landmark applications. RL has also been used "
                "to fine-tune language models through RLHF."
            ),
            category="deep_learning",
            author="RL Research Team",
            published_year=2023,
            metadata={"difficulty": "advanced", "tags": ["reinforcement_learning", "agents"]},
        ),
        TechArticle(
            id="cosmos_db_vector_search",
            content=(
                "Azure Cosmos DB for NoSQL API supports native vector search using the "
                "VectorDistance() SQL function and a container-level vector embedding policy. "
                "Developers define the embedding path, distance function (cosine, euclidean, "
                "or dot product), and index type (flat, quantizedFlat, or diskANN) at "
                "container creation time. Vector queries are expressed in standard NoSQL SQL "
                "and can be combined with scalar filters, making it easy to add semantic "
                "search to existing Cosmos DB workloads without a separate vector store."
            ),
            category="infrastructure",
            author="Data Engineering Team",
            published_year=2024,
            metadata={"difficulty": "beginner", "tags": ["cosmos_db", "vector_db", "RAG"]},
        ),
    ]

    logger.info("Generating embeddings", count=len(articles))
    # A single batch call embeds every article's content.
    vectors = embedder.embed_batch([article.content for article in articles])
    for article, vector in zip(articles, vectors):
        article.embedding = vector
    logger.info("Embeddings generated", dimension=len(vectors[0]))
    return articles
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Index documents
# ══════════════════════════════════════════════════════════════════════════════
def example_index_documents(store: AzureCosmosDBVectorStore, docs: list) -> None:
    """Add the pre-embedded documents to the store and log the resulting counts.

    Args:
        store: Target vector store.
        docs: Documents passed unchanged to ``store.add_documents``.
    """
    logger.info("Starting example", example="1: Index Documents")
    logger.info(
        "Scenario",
        problem="Documents must be embedded and stored before they can be retrieved",
        solution="embed_batch() all content, then add_documents() to the store",
    )

    added_count = store.add_documents(docs)
    # Query the store only after the insert so the total reflects it.
    store_total = store.count()
    logger.info("Indexed documents", added=added_count, total=store_total)
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: Vector search
# ══════════════════════════════════════════════════════════════════════════════
def example_vector_search(store: AzureCosmosDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Embed a fixed query and log the top three vector-search hits.

    Args:
        store: Vector store to search.
        embedder: Used to turn the query text into an embedding.
    """
    logger.info("Starting example", example="2: Vector Search")
    logger.info(
        "Scenario",
        problem="Find semantically similar documents to a natural-language query",
        solution="Embed the query, then search with search_type='vector'",
        score_note="similarity = 1 − VectorDistance (cosine); rank=0 is always the closest match",
    )

    query_text = "backpropagation training gradients"
    logger.info("Query", text=query_text)

    hits = store.search(
        query_embedding=embedder.embed_text(query_text),
        top_k=3,
        search_type="vector",
    )
    for hit in hits:
        article: TechArticle = hit.document
        logger.info(
            "Vector result",
            rank=hit.rank,
            id=article.id,
            similarity=round(hit.score, 4),
            category=article.category,
            year=article.published_year,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Keyword search
# ══════════════════════════════════════════════════════════════════════════════
def example_keyword_search(store: AzureCosmosDBVectorStore) -> None:
    """Run a text-only search with ``search_type='keyword'`` and log the top three hits.

    Args:
        store: Vector store to search; no embedding is supplied for this query.
    """
    logger.info("Starting example", example="3: Keyword Search")
    logger.info(
        "Scenario",
        problem="Find documents containing specific terms (exact lexical match)",
        solution="Pass query text only with search_type='keyword'",
    )

    query_text = "RAG retrieval language model"
    logger.info("Query", text=query_text)

    for hit in store.search(query=query_text, top_k=3, search_type="keyword"):
        article: TechArticle = hit.document
        logger.info(
            "Keyword result",
            rank=hit.rank,
            id=article.id,
            score=round(hit.score, 4),
            category=article.category,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 4: Hybrid search
# ══════════════════════════════════════════════════════════════════════════════
def example_hybrid_search(store: AzureCosmosDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Search with both query text and its embedding (``search_type='hybrid'``).

    Args:
        store: Vector store to search.
        embedder: Used to embed the query text.
    """
    logger.info("Starting example", example="4: Hybrid Search")
    logger.info(
        "Scenario",
        problem="Balance semantic relevance with precise keyword matching",
        solution="Provide both query text and query_embedding with search_type='hybrid'",
    )

    query_text = "vector database ANN HNSW"
    logger.info("Query", text=query_text)

    # The same text is sent twice: once verbatim and once as an embedding.
    search_args = {
        "query": query_text,
        "query_embedding": embedder.embed_text(query_text),
        "top_k": 3,
        "search_type": "hybrid",
    }
    for hit in store.search(**search_args):
        article: TechArticle = hit.document
        logger.info(
            "Hybrid result",
            rank=hit.rank,
            id=article.id,
            score=round(hit.score, 4),
            category=article.category,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 5: Filtered vector search
# ══════════════════════════════════════════════════════════════════════════════
def example_filtered_search(store: AzureCosmosDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Run a vector search restricted to ``category == "deep_learning"``.

    Args:
        store: Vector store to search.
        embedder: Used to embed the query text.
    """
    logger.info("Starting example", example="5: Filtered Vector Search")
    logger.info(
        "Scenario",
        problem="Restrict results to a specific document subset before scoring",
        solution="Pass a filters dict alongside query_embedding",
        score_note="similarity = 1 − VectorDistance (cosine); rank=0 is always the closest match",
    )

    query_text = "neural network deep learning"
    logger.info("Query", text=query_text, filter="category=deep_learning")

    query_vector = embedder.embed_text(query_text)
    hits = store.search(
        query_embedding=query_vector,
        top_k=5,
        filters={"category": "deep_learning"},
        search_type="vector",
    )
    for hit in hits:
        article: TechArticle = hit.document
        logger.info(
            "Filtered result",
            rank=hit.rank,
            id=article.id,
            category=article.category,
            similarity=round(hit.score, 4),
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 6: CRUD operations
# ══════════════════════════════════════════════════════════════════════════════
def example_crud(store: AzureCosmosDBVectorStore) -> None:
    """Demonstrate get, metadata update, and delete on the store.

    Fetches ``backpropagation_explained``, sets ``reviewed=True`` in its
    metadata, re-reads it, then deletes ``reinforcement_learning``.

    Both reads are guarded: if the store returns nothing for the document,
    the corresponding detail log is skipped (and, for the post-update read,
    an error is logged) instead of raising ``AttributeError`` on ``None``.

    Args:
        store: Vector store holding the sample articles.
    """
    logger.info("Starting example", example="6: CRUD — get / update / delete")

    # get
    doc_id = "backpropagation_explained"
    fetched: Optional[TechArticle] = store.get_document(doc_id)
    if fetched:
        logger.info("Fetched document", id=fetched.id, author=fetched.author,
                    category=fetched.category)

    # update metadata
    store.update_document(doc_id, metadata={"reviewed": True})
    updated: Optional[TechArticle] = store.get_document(doc_id)
    if updated:
        logger.info("After metadata update", id=updated.id, metadata=updated.metadata)
    else:
        # Mirror the guard on the first read rather than dereferencing a missing document.
        logger.error("Document not found after metadata update", id=doc_id)

    # delete
    deleted = store.delete_documents(["reinforcement_learning"])
    logger.info("Deleted documents", deleted=deleted, remaining=store.count())
# ══════════════════════════════════════════════════════════════════════════════
# Example 7: End-to-End RAG Pipeline (Chunk → Embed → Index → Retrieve)
# ══════════════════════════════════════════════════════════════════════════════
def example_rag_pipeline(embedder: AzureOpenAIEmbeddings) -> None:
    """Demonstrate a complete RAG ingestion pipeline using SentenceChunker and SupportTicket schema.

    Steps: provision the ``support_tickets`` container, chunk four sample
    tickets, embed every chunk, index the chunks as ``SupportTicket``
    documents, then run one vector query and log the hits.

    The dedicated store opened here is closed in a ``finally`` block, so it
    is released even when chunking, embedding, indexing, or search raises.

    Args:
        embedder: Client used for ``embed_batch`` (chunks) and ``embed_text`` (query).

    Returns:
        None. Returns early, after logging an error, when
        ``AZURE_COSMOS_ENDPOINT`` or ``AZURE_COSMOS_KEY`` is not set.
    """
    logger.info("Starting example", example="7: End-to-End RAG Pipeline")
    logger.info("Scenario",
                description="Ingest multiple support tickets through the full RAG pipeline",
                pipeline="tickets → SentenceChunker → embed_batch → add_documents(SupportTicket) → vector_search",
                why="SentenceChunker preserves sentence boundaries so each chunk is self-contained and readable",
                note="SupportTicket index schema differs from TechArticle — demonstrates separate container per domain")

    # ── Dedicated store for support ticket container ─────────────────────────────
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    key = os.getenv("AZURE_COSMOS_KEY")
    database_name = os.getenv("AZURE_COSMOS_DATABASE", "rag_db")
    embedding_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    if not endpoint or not key:
        logger.error("AZURE_COSMOS_ENDPOINT / AZURE_COSMOS_KEY not set — skipping RAG pipeline example")
        return

    CosmosDBIndexBuilder(
        endpoint=endpoint,
        api_key=key,
        database_name=database_name,
        container_name="support_tickets",
        embedding_dimension=embedding_dim,
        ssl_cert_path=get_ssl_cert_path(),
    ).create_index()

    pipeline_store = AzureCosmosDBVectorStore(
        endpoint=endpoint,
        key=key,
        database_name=database_name,
        container_name="support_tickets",
        embedding_dimension=embedding_dim,
        document_type=SupportTicket,
        ssl_cert_path=get_ssl_cert_path(),
    )

    # Everything that uses the store runs inside try/finally so close() is
    # always reached, including when a pipeline step raises.
    try:
        logger.info("Cosmos DB pipeline store",
                    container="support_tickets", document_type="SupportTicket",
                    custom_fields=["severity", "product", "status"],
                    note="Separate container from 'tech_articles' — each domain gets its own schema")

        # ── Source tickets (each with its own severity, product, and status) ────────
        tickets = [
            {
                "id": "TICKET-4501",
                "severity": "critical",
                "product": "azure_openai",
                "status": "open",
                "text": (
                    "Our production RAG pipeline stopped returning results after the latest Azure OpenAI "
                    "deployment update. The embedding endpoint returns HTTP 429 throttling errors under normal "
                    "load that was well within our quota yesterday. Rolling back to the previous deployment "
                    "version has no effect. We suspect the rate-limit configuration was reset during the update. "
                    "This is blocking all customer-facing search functionality and needs urgent attention."
                ),
            },
            {
                "id": "TICKET-4502",
                "severity": "high",
                "product": "cosmos_db",
                "status": "in_progress",
                "text": (
                    "Vector search queries on our Cosmos DB for MongoDB vCore cluster are returning results "
                    "with unexpectedly low similarity scores since we increased the embedding dimension from "
                    "1536 to 3072. The HNSW index was recreated with the new dimension, but the similarity "
                    "scores dropped by roughly 40 percent. We verified that the embeddings stored in the "
                    "collection match the new dimension. Investigating whether the index efConstruction or M "
                    "parameters need tuning for the higher-dimensional vectors."
                ),
            },
            {
                "id": "TICKET-4503",
                "severity": "medium",
                "product": "azure_ai_search",
                "status": "open",
                "text": (
                    "Hybrid search in Azure AI Search is returning duplicate results when the same document "
                    "scores highly in both the vector and keyword components. Our current workaround is "
                    "client-side deduplication, but we would prefer a server-side solution. We are using the "
                    "2024-07-01 API version with semantic ranker enabled. The duplicates appear only when "
                    "the query text closely matches the indexed content verbatim."
                ),
            },
            {
                "id": "TICKET-4504",
                "severity": "low",
                "product": "cosmos_db",
                "status": "resolved",
                "text": (
                    "Request to increase the default connection pool size for the Cosmos DB Python SDK from "
                    "100 to 250 connections. Under sustained batch-indexing workloads the pool is exhausted "
                    "and new requests queue behind idle connections. We have confirmed that the vCore cluster "
                    "can handle the additional connections without hitting its own limits. This is a performance "
                    "optimisation rather than a functional issue."
                ),
            },
        ]
        logger.info("Source tickets", count=len(tickets),
                    severities=[t["severity"] for t in tickets],
                    products=[t["product"] for t in tickets])

        query = "Cosmos DB vector search low similarity scores after dimension change"
        tracer = get_tracer()
        with tracer.trace("rag_pipeline", input=query) as trace:
            with trace.span("chunking") as span:
                # ── Step 1: Chunk ──────────────────────────────────────────────────────
                logger.info("Step 1 - Chunking",
                            strategy="SentenceChunker", max_chunk_size=400, tokenizer="nltk.sent_tokenize")
                # Fetch the tokenizer data that nltk.sent_tokenize relies on.
                nltk.download("punkt_tab", quiet=True)
                chunker = SentenceChunker(
                    sentence_tokenizer=nltk.sent_tokenize,
                    max_chunk_size=400,
                )

                all_chunks_per_ticket: list = []
                for ticket in tickets:
                    chunks = chunker.chunk(
                        ticket["text"],
                        metadata={"source": ticket["id"], "product": ticket["product"]},
                    )
                    logger.info("Chunked ticket",
                                ticket_id=ticket["id"], severity=ticket["severity"],
                                num_chunks=len(chunks),
                                avg_chars=sum(len(c.text) for c in chunks) // max(len(chunks), 1))
                    all_chunks_per_ticket.append((ticket, chunks))
                span.set_output({"documents": len(tickets), "total_chunks": sum(len(c) for _, c in all_chunks_per_ticket)})

            # ── Step 2: Embed ──────────────────────────────────────────────────────────
            with trace.span("embedding") as span:
                all_docs: list = []
                for ticket, chunks in all_chunks_per_ticket:
                    texts = [c.text for c in chunks]
                    logger.info("Step 2 - Embedding", ticket_id=ticket["id"], chunks=len(texts))
                    embeddings = embedder.embed_batch(texts)

                    # ── Build typed documents ──────────────────────────────────────────
                    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
                        all_docs.append(
                            SupportTicket(
                                id=f"{ticket['id']}_chunk_{i}",
                                content=chunk.text,
                                embedding=emb,
                                severity=ticket["severity"],
                                product=ticket["product"],
                                status=ticket["status"],
                                metadata={**chunk.metadata, "chunk_index": i,
                                          "char_start": chunk.start_pos, "char_end": chunk.end_pos},
                            )
                        )
                span.set_output({"chunks": len(all_docs)})

            # ── Step 3: Index ──────────────────────────────────────────────────────────
            with trace.span("indexing") as span:
                logger.info("Step 3 - Indexing", document_type="SupportTicket", total_chunks=len(all_docs))
                # Embeddings were attached in Step 2, so the store is told not to generate them.
                pipeline_store.add_documents(all_docs, generate_embeddings=False)
                logger.info("Chunks indexed", count=len(all_docs))
                span.set_output({"chunks_indexed": len(all_docs)})

            # ── Step 4: Retrieve ───────────────────────────────────────────────────────
            with trace.span("retrieval", input=query) as span:
                logger.info("Step 4 - Retrieval", query=query, top_k=3)
                query_embedding = embedder.embed_text(query)
                results = pipeline_store.search(query_embedding=query_embedding, top_k=3)
                logger.info("Retrieved chunks", returned=len(results))
                for r in results:
                    # Named `hit` so it does not shadow the source-ticket dicts iterated above.
                    hit: SupportTicket = r.document
                    snippet = hit.content[:100].replace("\n", " ")
                    logger.info("Retrieved chunk",
                                rank=r.rank, id=hit.id, similarity=round(r.score, 4),
                                severity=hit.severity, product=hit.product, status=hit.status,
                                snippet=snippet)
                span.set_output({"results": len(results), "top_score": round(results[0].score, 4) if results else 0})

            trace.set_output({"chunks_indexed": len(all_docs), "query_results": len(results)})
    finally:
        pipeline_store.close()

    logger.info("RAG pipeline complete",
                chunks_indexed=len(all_docs),
                query_results=len(results),
                tip="Filter by severity='critical' or product='cosmos_db' to narrow support ticket search results")
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main() -> None:
    """Run all Cosmos DB vector store examples in sequence.

    Builds the embedder, provisions the container with
    ``CosmosDBIndexBuilder``, opens the ``TechArticle`` store, runs
    Examples 1–7, and logs a summary. Returns early (after logging an error)
    when the embedder cannot be built, when the Cosmos DB endpoint or key is
    missing, or when provisioning/connecting fails.

    The store is closed in a ``finally`` block so it is released even if one
    of the examples raises; the exception still propagates to the caller.
    """
    logger.info("COSMOS DB NOSQL API VECTOR STORE EXAMPLES",
                note="Demonstrating vector, keyword, hybrid search and CRUD with AzureCosmosDBVectorStore")

    # ── Embedder (required) ──────────────────────────────────────────────────
    embedder = get_embedder()
    if not embedder:
        logger.error("Cannot run examples without embedder — check .env")
        return

    # ── Cosmos DB connection ─────────────────────────────────────────────────
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    key = os.getenv("AZURE_COSMOS_KEY")
    database_name = os.getenv("AZURE_COSMOS_DATABASE", "rag_db")
    container_name = os.getenv("AZURE_COSMOS_CONTAINER", "tech_articles")

    if not endpoint or not key:
        logger.error(
            "AZURE_COSMOS_ENDPOINT / AZURE_COSMOS_KEY not set in .env",
            required_vars=["AZURE_COSMOS_ENDPOINT", "AZURE_COSMOS_KEY"],
            optional_vars=["AZURE_COSMOS_DATABASE", "AZURE_COSMOS_CONTAINER"],
        )
        return

    embedding_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    logger.info("Connecting to Cosmos DB",
                database=database_name, container=container_name)
    try:
        CosmosDBIndexBuilder(
            endpoint=endpoint,
            api_key=key,
            database_name=database_name,
            container_name=container_name,
            embedding_dimension=embedding_dim,
            ssl_cert_path=get_ssl_cert_path(),
        ).create_index()
        store = AzureCosmosDBVectorStore(
            endpoint=endpoint,
            key=key,
            database_name=database_name,
            container_name=container_name,
            embedding_dimension=embedding_dim,
            document_type=TechArticle,
            ssl_cert_path=get_ssl_cert_path(),
        )
    except RuntimeError as exc:
        logger.error("Failed to connect to Cosmos DB", error=str(exc))
        return
    except Exception as exc:
        # Top-level boundary of the example script: log anything unexpected and stop.
        logger.error("Unexpected error connecting to Cosmos DB", error=str(exc))
        return

    # From here on the store is open; make sure it is closed on every path.
    try:
        logger.info("Connected", documents=store.count())

        # ── Build documents ───────────────────────────────────────────────────────
        docs = build_sample_documents(embedder)

        # ── Run examples ──────────────────────────────────────────────────────────
        print()
        example_index_documents(store, docs)
        print()
        example_vector_search(store, embedder)
        print()
        example_keyword_search(store)
        print()
        example_hybrid_search(store, embedder)
        print()
        example_filtered_search(store, embedder)
        print()
        example_crud(store)
        print()
        example_rag_pipeline(embedder)
        print()
    finally:
        store.close()

    logger.info("SUMMARY")
    logger.info("AzureCosmosDBVectorStore",
                description="Vector search on Azure Cosmos DB NoSQL API",
                search_types=["vector", "keyword", "hybrid"],
                filtering="Pass filters dict to any search call",
                crud="get_document / update_document / delete_documents",
                rag_pipeline="Chunk → embed_batch → add_documents → vector_search")
    logger.info("Production Tips",
                tip_1="Use CosmosDBIndexBuilder to provision the container before first use — keep schema management separate from document operations",
                tip_2="Use hybrid search as the default — best balance of semantic + lexical recall",
                tip_3="Keep embedding_dimension consistent with your Azure OpenAI deployment",
                tip_4="Reuse the embedder instance across examples to avoid re-initialising the HTTP client",
                tip_5="SentenceChunker(nltk.sent_tokenize, max_chunk_size=400) keeps chunks at clean sentence boundaries")
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
document_intelligence_layout_example.py
"""
Example: Document Intelligence Layout (Module 8)
Demonstrates how to use the DocumentIntelligenceLayout class to analyse
documents and feed the markdown output into the RAG pipeline:
1. analyze_file — Analyse a local PDF/DOCX/image and get markdown back
2. analyze_bytes — Analyse raw bytes (e.g. from BlobStorageConnector)
3. analyze_url — Analyse a publicly accessible document URL
4. Full Pipeline — analyze → MarkdownChunker → embed → index → retrieve
How this connects to existing examples:
- The LayoutResult.markdown is passed directly to MarkdownChunker, which
splits on headers produced by Document Intelligence.
- Chunks are embedded and stored using the same AzureOpenAIEmbeddings and
AzureAISearchVectorStore patterns used in all other examples.
- Uses the same .env credentials as connectors_example.py and
retrieval_example.py.
Cost awareness:
- Azure Document Intelligence (prebuilt-layout) is billed per page.
- This example tracks total pages analysed via BasicMetricsCollector and
reports an estimated cost at completion. Update _COST_PER_PAGE_USD if your
pricing tier differs (see https://azure.microsoft.com/pricing/details/ai-document-intelligence/).
Prerequisites:
- Azure Document Intelligence endpoint in .env
(AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT)
Authentication is via DefaultAzureCredential (managed identity).
Locally: run az login before executing this example.
In Azure: assign the "Cognitive Services User" role to the workload's
managed identity on the Document Intelligence resource.
- Azure OpenAI + Azure AI Search credentials in .env for the full pipeline
(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_EMBEDDING_MODEL,
AZURE_AI_SEARCH_ENDPOINT, AZURE_AI_SEARCH_API_KEY)
Optional env vars:
DOCUMENT_INTELLIGENCE_SAMPLE_FILE — path to local document used by
Examples 1, 2, and 4 (PDF/DOCX/PPTX/
XLSX or image). Defaults to
sample_document.pdf in this directory.
DOCUMENT_INTELLIGENCE_SAMPLE_URL — URL used by Example 3. Defaults to
a public Azure sample PDF.
"""
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_data.layout import DocumentIntelligenceLayout, LayoutResult
from gmf_forge_ai_data.chunkers import MarkdownChunker
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, Document
from gmf_forge_ai_data.indexing import AzureAISearchIndexBuilder
from gmf_forge_ai_data.retrieval import RetrievalQuery, VectorRetriever
from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
# ── Environment ───────────────────────────────────────────────────────────────
# Load credentials from the .env file that sits next to this script.
env_path = Path(__file__).with_name(".env")
load_dotenv(env_path)

# Corporate CA bundle, used on networks that perform SSL inspection. The
# workspace root is four ``.parent`` hops up from this file.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT.joinpath("certs", "gmf_and_public_cas.pem")

# Shared observability handles used by every example in this script.
logger = BasicLogger(__name__)
metrics = BasicMetricsCollector()

# Per-page price assumed for Azure DI prebuilt-layout (S0 tier, pay-as-you-go).
# Adjust for commitment tiers or price changes:
# https://azure.microsoft.com/pricing/details/ai-document-intelligence/
# NOTE(review): verify this figure against the current pricing page before
# relying on the reported cost estimate.
_COST_PER_PAGE_USD = 0.001
# ── Config helpers ────────────────────────────────────────────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the corporate CA bundle path as a string, or ``None`` if the file is absent."""
    if not CORPORATE_CERT.exists():
        return None
    return str(CORPORATE_CERT)
def get_sample_path() -> Path:
    """Locate the sample document used by the file-based examples.

    The ``DOCUMENT_INTELLIGENCE_SAMPLE_FILE`` environment variable may hold
    an absolute path, or a path relative to this script's directory, to any
    supported document (PDF, DOCX, PPTX, XLSX, or image). When the variable
    is unset or empty, ``sample_document.pdf`` in this script's directory is
    returned. The path is not checked for existence here.
    """
    examples_dir = Path(__file__).parent
    configured = os.getenv("DOCUMENT_INTELLIGENCE_SAMPLE_FILE")
    if not configured:
        return examples_dir / "sample_document.pdf"

    candidate = Path(configured)
    # Relative values are anchored at the examples directory, not the CWD.
    return candidate if candidate.is_absolute() else examples_dir / candidate
def get_sample_url() -> str:
    """Return the sample document URL.

    Reads ``DOCUMENT_INTELLIGENCE_SAMPLE_URL`` from the environment and falls
    back to a public Azure sample PDF when the variable is unset *or empty*
    (an empty ``KEY=`` line in .env would otherwise yield an unusable ``""``
    URL). Replace the default with an internal Blob SAS URL or SharePoint
    direct-download link for real use.
    """
    default_url = (
        "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-REST-api-samples/master/curl/form-recognizer/sample-layout.pdf"
    )
    # ``or`` (rather than a getenv default) so an empty value also falls back,
    # matching how get_sample_path() treats an empty variable.
    return os.getenv("DOCUMENT_INTELLIGENCE_SAMPLE_URL") or default_url
def get_layout_client() -> Optional[DocumentIntelligenceLayout]:
    """Build a ``DocumentIntelligenceLayout`` client from environment settings.

    Returns ``None`` (after logging a warning) when
    ``AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT`` is not set, or when constructing
    the client raises ``ImportError``.
    """
    di_endpoint = os.getenv("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
    if not di_endpoint:
        logger.warning(
            "AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT not set in .env — skipping",
        )
        return None

    # API-key auth is used when AZURE_DOCUMENT_INTELLIGENCE_KEY is present;
    # otherwise the client falls back to DefaultAzureCredential (managed
    # identity / az login). Cognitive Services Multiservices accounts require
    # the managed identity path; standalone DI resources can use either.
    api_key = os.getenv("AZURE_DOCUMENT_INTELLIGENCE_KEY")
    if api_key:
        auth_mode = "api_key"
    else:
        auth_mode = "DefaultAzureCredential (managed identity)"

    try:
        client = DocumentIntelligenceLayout(
            endpoint=di_endpoint,
            api_key=api_key or None,  # normalise "" to None
            logger=logger,
        )
        logger.info("Document Intelligence client ready", auth=auth_mode)
        return client
    except ImportError as err:
        logger.warning(
            "Required package not installed — skipping all Document Intelligence examples",
            error=str(err),
        )
        return None
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build an API-key authenticated ``AzureOpenAIEmbeddings`` from the environment.

    Returns ``None`` (after logging a warning) unless ``AZURE_OPENAI_ENDPOINT``,
    ``AZURE_OPENAI_API_KEY`` and ``AZURE_OPENAI_EMBEDDING_MODEL`` are all set.
    ``AZURE_OPENAI_EMBEDDING_MODEL_VERSION`` is optional (default ``2024-02-01``).
    """
    required = {
        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "deployment_name": os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"),
    }
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    if not all(required.values()):
        logger.warning("Azure OpenAI embedding credentials not found in .env — skipping")
        return None

    return AzureOpenAIEmbeddings(
        **required,
        api_version=api_version,
        ssl_cert_path=get_ssl_cert_path(),
        logger=logger,
    )
def get_vector_store(index_name: str) -> Optional[AzureAISearchVectorStore]:
    """Build an ``AzureAISearchVectorStore`` for *index_name* from the environment.

    Returns ``None`` (after logging a warning) when the search endpoint or
    API key is missing. ``AZURE_AI_SEARCH_EMBEDDING_DIMENSION`` defaults to
    1536; it is parsed with ``int()`` before the credential check, so a
    non-numeric value raises ``ValueError``.
    """
    search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
    search_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
    dimension = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))

    if search_endpoint and search_key:
        return AzureAISearchVectorStore(
            endpoint=search_endpoint,
            index_name=index_name,
            api_key=search_key,
            embedding_dimension=dimension,
        )

    logger.warning("Azure AI Search credentials not found in .env — skipping")
    return None
def log_result(result: LayoutResult, latency_ms: float = 0.0) -> None:
    """Record per-call metrics for *result* and log a summary plus a preview.

    Args:
        result: Layout analysis output; its ``page_count``, ``markdown`` and
            ``metadata`` are read.
        latency_ms: Wall-clock duration of the analysis call. Only recorded
            in the latency histogram when greater than zero.
    """
    pages = result.page_count
    meta = result.metadata
    estimated_cost = pages * _COST_PER_PAGE_USD

    metrics.increment("layout.analyze.calls")
    metrics.increment("layout.analyze.pages", value=pages)
    if latency_ms > 0:
        metrics.histogram("layout.analyze.latency_ms", latency_ms)
    metrics.increment("layout.analyze.estimated_cost_usd", value=estimated_cost)

    logger.info(
        "Layout analysis complete",
        page_count=pages,
        markdown_length=len(result.markdown),
        source=meta.get("source", ""),
        model_id=meta.get("model_id", ""),
        analyzed_at=meta.get("analyzed_at", ""),
        latency_ms=round(latency_ms, 1),
        estimated_cost_usd=round(estimated_cost, 4),
    )

    # First 400 characters, flattened to a single line for the log record.
    logger.info("Markdown preview", preview=result.markdown[:400].replace("\n", " ↵ "))
# ── Example 1: Analyse a local file ──────────────────────────────────────────
def example_analyze_file():
    """Run ``analyze_file`` on the local sample document and log the result.

    Returns early (after a warning) when no Document Intelligence client can
    be built, when the sample file does not exist, or when the analysis call
    raises.
    """
    logger.info("=== Example 1: analyze_file ===")

    layout = get_layout_client()
    if not layout:
        return

    # The document comes from DOCUMENT_INTELLIGENCE_SAMPLE_FILE (.env) or the
    # default path; any local PDF, DOCX, PPTX, XLSX, or image file works.
    document = get_sample_path()
    if not document.exists():
        logger.warning(
            "Sample file not found — skipping Example 1",
            path=str(document),
            hint="Set DOCUMENT_INTELLIGENCE_SAMPLE_FILE in .env or place a PDF at the path above",
        )
        return

    with get_tracer().trace("layout.analyze_file", input=str(document)) as trace:
        started = time.perf_counter()
        try:
            result = layout.analyze_file(document)
            latency_ms = 1000 * (time.perf_counter() - started)
            metrics.increment("layout.analyze.success")
            trace.set_output(
                {
                    "pages": result.page_count,
                    "markdown_chars": len(result.markdown),
                    "source": result.metadata.get("source", ""),
                }
            )
        except Exception as err:
            # Best-effort example: count the failure and move on.
            metrics.increment("layout.analyze.errors")
            logger.warning("analyze_file failed — skipping", error=str(err))
            return

    log_result(result, latency_ms)
# ── Example 2: Analyse raw bytes ─────────────────────────────────────────────
def example_analyze_bytes():
    """Run ``analyze_bytes`` on the sample document's raw content.

    This mirrors how BlobStorageConnector returns content: in a real pipeline
    the bytes would come from BlobStorageConnector or SharePointConnector,
    whereas here a local file is read to simulate that.

    Returns early (after a warning) when no client can be built, when the
    sample file does not exist, or when the analysis call raises.
    """
    logger.info("=== Example 2: analyze_bytes ===")

    layout = get_layout_client()
    if not layout:
        return

    # Path comes from DOCUMENT_INTELLIGENCE_SAMPLE_FILE (.env), or defaults
    # to sample_document.pdf in the examples directory.
    document = get_sample_path()
    if not document.exists():
        logger.warning(
            "Sample file not found — skipping Example 2",
            path=str(document),
            hint="Set DOCUMENT_INTELLIGENCE_SAMPLE_FILE in .env or place a PDF at the path above",
        )
        return

    payload = document.read_bytes()

    with get_tracer().trace("layout.analyze_bytes", input=document.name) as trace:
        started = time.perf_counter()
        try:
            result = layout.analyze_bytes(payload, filename=document.name)
            latency_ms = 1000 * (time.perf_counter() - started)
            metrics.increment("layout.analyze.success")
            trace.set_output(
                {
                    "pages": result.page_count,
                    "bytes_in": len(payload),
                    "markdown_chars": len(result.markdown),
                }
            )
        except Exception as err:
            # Best-effort example: count the failure and move on.
            metrics.increment("layout.analyze.errors")
            logger.warning("analyze_bytes failed — skipping", error=str(err))
            return

    log_result(result, latency_ms)
# ── Example 3: Analyse a URL ─────────────────────────────────────────────────
def example_analyze_url():
    """Run ``analyze_url`` on the configured sample URL and log the result.

    Returns early when no Document Intelligence client can be built or when
    the analysis call raises (the failure is logged as a warning).
    """
    logger.info("=== Example 3: analyze_url ===")

    layout = get_layout_client()
    if not layout:
        return

    # Taken from DOCUMENT_INTELLIGENCE_SAMPLE_URL (.env) or a public Azure
    # sample PDF. Use an internal Blob SAS URL or SharePoint direct-download
    # link for real use.
    document_url = get_sample_url()

    with get_tracer().trace("layout.analyze_url", input=document_url) as trace:
        started = time.perf_counter()
        try:
            result = layout.analyze_url(document_url)
            latency_ms = 1000 * (time.perf_counter() - started)
            metrics.increment("layout.analyze.success")
            trace.set_output(
                {
                    "pages": result.page_count,
                    "markdown_chars": len(result.markdown),
                    "url": document_url,
                }
            )
        except Exception as err:
            # Best-effort example: count the failure and move on.
            metrics.increment("layout.analyze.errors")
            logger.warning("analyze_url failed — skipping", error=str(err))
            return

    log_result(result, latency_ms)
# ── Example 4: Full pipeline ─────────────────────────────────────────────────
def example_full_pipeline():
    """
    Full RAG ingest pipeline:
    analyze_file → MarkdownChunker → embed → index (Azure AI Search) → retrieve

    This is the recommended pattern for document ingestion with Document
    Intelligence. The markdown produced by Azure DI preserves headers, tables,
    and page structure, which MarkdownChunker uses to create semantically
    meaningful chunks.

    The function returns early (with a warning) when credentials or the sample
    file are missing, when the document yields no chunks, or when no vector
    store is configured. Any exception raised by a pipeline step is caught,
    counted under ``layout.analyze.errors`` and logged as a warning.
    """
    logger.info("=== Example 4: Full pipeline — analyze → chunk → embed → index → retrieve ===")
    layout = get_layout_client()
    embedder = get_embedder()
    if not layout or not embedder:
        logger.warning("Skipping full pipeline — missing credentials")
        return

    sample_path = get_sample_path()
    if not sample_path.exists():
        logger.warning(
            "Sample file not found — skipping Example 4",
            path=str(sample_path),
            hint="Set DOCUMENT_INTELLIGENCE_SAMPLE_FILE in .env or place a PDF at the path above",
        )
        return

    index_name = "doc-intelligence-layout"
    tracer = get_tracer()
    try:
        with tracer.trace("layout.full_pipeline", input=str(sample_path)) as trace:
            # ── Step 1: Analyse document → markdown ───────────────────────────
            with trace.span("document_intelligence_analyze", input=str(sample_path)) as span:
                logger.info("Step 1: Analysing document with Document Intelligence")
                t0 = time.perf_counter()
                result = layout.analyze_file(sample_path)
                latency_ms = (time.perf_counter() - t0) * 1000
                # Computed once and reused by the metric, span, log and trace.
                estimated_cost = result.page_count * _COST_PER_PAGE_USD
                metrics.increment("layout.analyze.calls")
                metrics.increment("layout.analyze.pages", value=result.page_count)
                metrics.increment("layout.analyze.success")
                metrics.histogram("layout.analyze.latency_ms", latency_ms)
                metrics.increment(
                    "layout.analyze.estimated_cost_usd",
                    value=estimated_cost,
                )
                span.set_output({
                    "pages": result.page_count,
                    "markdown_chars": len(result.markdown),
                    "latency_ms": round(latency_ms, 1),
                    "estimated_cost_usd": round(estimated_cost, 4),
                })
            logger.info(
                "Step 1 complete",
                pages=result.page_count,
                markdown_chars=len(result.markdown),
                latency_ms=round(latency_ms, 1),
                estimated_cost_usd=round(estimated_cost, 4),
            )

            # ── Step 2: Chunk the markdown at header boundaries ───────────────
            with trace.span("markdown_chunker") as span:
                logger.info("Step 2: Chunking markdown with MarkdownChunker")
                # Azure DI output has well-formed headers — split at level-2
                # headings to keep sections together, with a 1500-char size
                # cap per chunk.
                chunker = MarkdownChunker(max_chunk_size=1500, min_header_level=2)
                chunks = chunker.chunk(result.markdown, metadata=result.metadata)
                metrics.gauge("pipeline.chunks_created", len(chunks))
                span.set_output({"chunks": len(chunks)})
            logger.info("Step 2 complete", chunks=len(chunks))
            for i, chunk in enumerate(chunks[:3]):
                preview = chunk.text[:120].replace("\n", " ↵ ")
                logger.info(f" chunk[{i}]", chars=len(chunk.text), preview=preview)

            # Nothing to embed: stop here with an explicit message rather than
            # letting an empty batch fail further down (embeddings[0] below
            # would raise IndexError on an empty result).
            if not chunks:
                logger.warning(
                    "Skipping embed/index/retrieve steps — document produced no chunks",
                    markdown_chars=len(result.markdown),
                )
                return

            # ── Step 3: Embed each chunk ───────────────────────────────────────
            with trace.span("embedding") as span:
                logger.info("Step 3: Embedding chunks")
                chunk_texts = [c.text for c in chunks]
                embeddings = embedder.embed_batch(chunk_texts)
                # zip() below would silently drop chunks on a count mismatch,
                # so fail loudly instead (caught by the outer handler).
                if len(embeddings) != len(chunks):
                    raise ValueError(
                        f"Embedding count ({len(embeddings)}) does not match "
                        f"chunk count ({len(chunks)})"
                    )
                metrics.gauge("pipeline.embeddings_created", len(embeddings))
                span.set_output({
                    "count": len(embeddings),
                    "dimension": len(embeddings[0]),
                })
            logger.info(
                "Step 3 complete",
                count=len(embeddings),
                dimension=len(embeddings[0]),
            )

            # ── Step 4: Index into Azure AI Search ────────────────────────────
            store = get_vector_store(index_name)
            if not store:
                logger.warning("Skipping index/retrieve steps — no vector store configured")
                return

            # Provision the index schema before writing documents to it.
            search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
            search_api_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
            emb_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))
            if search_endpoint and search_api_key:
                builder = AzureAISearchIndexBuilder(
                    endpoint=search_endpoint,
                    api_key=search_api_key,
                    index_name=index_name,
                    embedding_dimension=emb_dim,
                    ssl_cert_path=get_ssl_cert_path(),
                )
                builder.create_index()

            with trace.span("indexing") as span:
                logger.info("Step 4: Indexing documents")
                documents = [
                    Document(
                        id=f"{sample_path.stem}_chunk_{i}",
                        content=chunk.text,
                        embedding=embedding,
                        timestamp=datetime.now(),
                        # On key collisions the document-level metadata wins.
                        metadata={**chunk.metadata, **result.metadata},
                    )
                    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings))
                ]
                store.add_documents(documents)
                metrics.increment("pipeline.documents_indexed", value=len(documents))
                span.set_output({"indexed": len(documents), "index": index_name})
            logger.info(
                "Step 4 complete",
                documents_indexed=len(documents),
                index=index_name,
            )

            # ── Step 5: Retrieve ───────────────────────────────────────────────
            with trace.span("retrieval") as span:
                logger.info("Step 5: Retrieving with a test query")
                query_text = "What are the key policies described in this document?"
                query_embedding = embedder.embed_text(query_text)
                retriever = VectorRetriever(vector_store=store)
                results = retriever.retrieve(
                    RetrievalQuery(text=query_text, embedding=query_embedding, top_k=3)
                )
                metrics.gauge("pipeline.retrieval_results", len(results))
                span.set_output({
                    "results": len(results),
                    "top_score": round(results[0].score, 4) if results else 0,
                })
            logger.info("Step 5 complete", results=len(results))
            for res in results:
                preview = res.document.content[:120].replace("\n", " ↵ ")
                logger.info(
                    f" result[rank={res.rank}]",
                    score=round(res.score, 4),
                    preview=preview,
                )

            trace.set_output({
                "pages_analysed": result.page_count,
                "chunks": len(chunks),
                "indexed": len(documents),
                "retrieved": len(results),
                "estimated_cost_usd": round(estimated_cost, 4),
            })
    except Exception as exc:
        # Best-effort example: any step failure is counted and logged, not raised.
        metrics.increment("layout.analyze.errors")
        logger.warning("Full pipeline failed — skipping", error=str(exc))
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
    """Run the four examples in order, then log a metrics and latency summary."""
    logger.info(
        "Document Intelligence Layout Examples",
        description="analyze_file / analyze_bytes / analyze_url / full pipeline",
    )

    for run_example in (
        example_analyze_file,
        example_analyze_bytes,
        example_analyze_url,
        example_full_pipeline,
    ):
        run_example()

    # ── Metrics summary ───────────────────────────────────────────────────
    snapshot = metrics.get_metrics()
    counters = snapshot.get("counters", {})
    gauges = snapshot.get("gauges", {})
    histograms = snapshot.get("histograms", {})

    pages_total = counters.get("layout.analyze.pages", 0)
    logger.info(
        "Metrics Summary",
        di_calls=counters.get("layout.analyze.calls", 0),
        di_pages_analysed=pages_total,
        di_errors=counters.get("layout.analyze.errors", 0),
        estimated_total_cost_usd=round(pages_total * _COST_PER_PAGE_USD, 4),
        cost_per_page_usd=_COST_PER_PAGE_USD,
    )

    # Latency statistics are only reported when at least one call was timed.
    latencies = histograms.get("layout.analyze.latency_ms")
    if latencies:
        logger.info(
            "Latency Summary",
            calls=len(latencies),
            avg_ms=round(sum(latencies) / len(latencies), 1),
            min_ms=round(min(latencies), 1),
            max_ms=round(max(latencies), 1),
        )

    if counters:
        logger.info("All Counters", **dict(sorted(counters.items())))
    if gauges:
        logger.info(
            "All Gauges",
            **{name: round(value, 4) for name, value in sorted(gauges.items())},
        )

    logger.info("All examples complete")


if __name__ == "__main__":
    main()
embeddings_example.py
"""
Example: Basic usage of Azure OpenAI embeddings.
This example demonstrates how to use the embeddings module to:
1. Generate embeddings for single texts
2. Process batches of texts efficiently
3. Use optional logging with shared-core
"""
import os
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings, BatchEmbeddings
# Load environment variables from .env file
# Load environment variables from the .env file next to this script.
env_path = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=env_path)

# Corporate SSL certificate bundle (for corporate networks with SSL
# inspection); the workspace root is four ``.parent`` hops up from this file.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT.joinpath("certs", "gmf_and_public_cas.pem")

from gmf_forge_ai_shared_core.observability import BasicLogger

# Module-wide logger; also records which .env path was used.
logger = BasicLogger("embeddings-example")
logger.info("Loaded .env", path=str(env_path))
def get_ssl_cert_path() -> Optional[str]:
    """Return the corporate SSL certificate path, or ``None`` if the file does not exist.

    The return annotation is ``Optional[str]`` because the missing-file case
    yields ``None``; callers pass the value straight through as
    ``ssl_cert_path``.
    """
    if CORPORATE_CERT.exists():
        return str(CORPORATE_CERT)
    return None
def get_embedder(logger=None) -> AzureOpenAIEmbeddings:
    """
    Create an ``AzureOpenAIEmbeddings`` instance configured from the environment.

    The ``AZURE_USE_MANAGED_IDENTITY`` env var picks the authentication mode:

    * **API key** (default): reads ``AZURE_OPENAI_API_KEY``.
    * **Managed identity**: enabled by ``AZURE_USE_MANAGED_IDENTITY=true``
      (``1`` and ``yes`` are accepted too, case-insensitively). The token
      provider requests the **Cognitive Services** scope::

          https://cognitiveservices.azure.com/.default

    Note: Azure AI Search uses a *different* scope
    (``https://search.azure.com/.default``) — each service requires its own
    token provider.

    Unset endpoint / deployment / API key variables fall back to placeholder
    values, so this function itself never returns ``None``.

    Args:
        logger: Optional logger forwarded to the embeddings client.
    """
    shared_kwargs = dict(
        endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://your-resource.openai.azure.com"),
        deployment_name=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL", "text-embedding-3-large"),
        api_version=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01"),
        ssl_cert_path=get_ssl_cert_path(),
        logger=logger,
    )

    managed_identity = os.getenv("AZURE_USE_MANAGED_IDENTITY", "").lower() in ("1", "true", "yes")
    if not managed_identity:
        return AzureOpenAIEmbeddings(
            api_key=os.getenv("AZURE_OPENAI_API_KEY", "your-api-key"),
            **shared_kwargs,
        )

    # Managed identity — imported lazily so azure-identity is only needed on
    # this path. Scope: https://cognitiveservices.azure.com/.default
    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://cognitiveservices.azure.com/.default",
    )
    return AzureOpenAIEmbeddings(token_provider=token_provider, **shared_kwargs)
def example_single_embedding():
    """Example 1: embed one sentence and log its dimension and leading values."""
    logger.info("Starting example", example="1: Single Text Embedding")
    embedder = get_embedder(logger=logger)

    text = "What is the capital of France?"
    vector = embedder.embed_text(text)

    logger.info(
        "Embedding result",
        text=text,
        dimension=len(vector),
        first_5_values=vector[:5],
        model=embedder.get_model_name(),
    )
def example_batch_embedding():
    """Example 2: embed several texts with a single ``embed_batch`` call."""
    logger.info("Starting example", example="2: Batch Embedding")
    embedder = get_embedder(logger=logger)

    # Sample questions to embed together.
    questions = [
        "What is machine learning?",
        "How does neural network training work?",
        "What are transformers in NLP?",
        "Explain retrieval-augmented generation.",
        "What is semantic search?",
    ]

    # Method 1: hand the whole list to the provider in one call.
    logger.info("Method 1: Direct batch embedding")
    vectors = embedder.embed_batch(questions)
    logger.info(
        "Batch embedding result",
        count=len(vectors),
        dimension=len(vectors[0]),
    )
def example_large_batch_processing():
    """Example 3: embed a 250-item collection through the ``BatchEmbeddings`` wrapper."""
    logger.info("Starting example", example="3: Large-Scale Batch Processing")

    # BatchEmbeddings wraps the provider with batching and progress tracking.
    batch_embedder = BatchEmbeddings(
        provider=get_embedder(logger=logger),
        batch_size=100,  # process 100 texts per API call
        show_progress=True,  # show progress messages
        logger=logger,
    )

    # Simulated large document collection (250 documents).
    corpus = [
        f"This is document number {i} with some sample content."
        for i in range(250)
    ]
    logger.info("Processing documents", count=len(corpus))

    vectors = batch_embedder.embed_batch(corpus)
    logger.info(
        "Successfully generated embeddings",
        count=len(vectors),
        dimension=batch_embedder.get_embedding_dimension(),
    )
def example_with_custom_progress():
    """Example 4: report batch progress through a user-supplied callback."""
    logger.info("Starting example", example="4: Custom Progress Tracking")

    def progress_callback(current: int, total: int):
        """Log progress as counts plus a percentage rounded to one decimal."""
        percentage = (current / total) * 100
        logger.info("Progress", current=current, total=total, pct=round(percentage, 1))

    batch_embedder = BatchEmbeddings(
        provider=get_embedder(),
        batch_size=50,
        show_progress=False,  # disable default progress output
        progress_callback=progress_callback,  # use the custom callback instead
    )

    samples = [f"Sample text {i}" for i in range(150)]
    vectors = batch_embedder.embed_batch(samples)
    logger.info("Completed", count=len(vectors))
def example_error_handling():
    """Example 5: feed invalid inputs to the embedder and log each ``ValueError``.

    Only ``ValueError`` is handled; any other exception propagates to the
    caller. A call that raises nothing produces no log entry.
    """
    logger.info("Starting example", example="5: Error Handling")
    embedder = get_embedder()

    # (test number, deferred call) pairs, executed in order.
    invalid_calls = (
        (1, lambda: embedder.embed_text("")),  # empty text
        (2, lambda: embedder.embed_text(None)),  # None text
        (3, lambda: embedder.embed_batch([])),  # empty batch
        (4, lambda: embedder.embed_batch(["valid text", "", "another valid text"])),  # invalid entry
    )

    for test_number, attempt in invalid_calls:
        try:
            attempt()
        except ValueError as err:
            logger.info("Caught expected error", test=test_number, error=str(err))
if __name__ == "__main__":
    # Script entry point: report configuration, then run all five examples
    # when the required settings are present.
    logger.info("Azure OpenAI Embeddings Examples")

    # Check for SSL certificate
    ssl_cert = get_ssl_cert_path()
    if ssl_cert:
        logger.info("SSL Certificate found", path=ssl_cert)
    else:
        logger.warning("SSL Certificate not found - using default SSL verification")
        logger.info("If you get SSL errors in corporate networks, add certificate to certs/")

    # Debug: show loaded environment variables (the API key itself is never
    # logged, only whether it is set).
    logger.info(
        "Environment Configuration",
        endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "NOT SET"),
        api_key_set=bool(os.getenv("AZURE_OPENAI_API_KEY")),
        embedding_model=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL", "NOT SET"),
        model_version=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "NOT SET"),
    )

    # Check if credentials are set (API key or managed identity). Only the
    # settings that are actually absent are reported, so the warning points
    # at what needs fixing.
    use_managed_identity = os.getenv("AZURE_USE_MANAGED_IDENTITY", "").lower() in ("1", "true", "yes")
    missing_vars = [
        name
        for name in ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_EMBEDDING_MODEL")
        if not os.getenv(name)
    ]
    if not (os.getenv("AZURE_OPENAI_API_KEY") or use_managed_identity):
        missing_vars.append("AZURE_OPENAI_API_KEY (or set AZURE_USE_MANAGED_IDENTITY=true)")

    if missing_vars:
        logger.warning(
            "Required environment variables not set",
            missing_vars=missing_vars,
        )
    else:
        logger.info("Configuration loaded successfully")
        logger.info("Running Example 1: Single Text Embedding")
        try:
            example_single_embedding()
            logger.info("Example 1 passed")
            example_batch_embedding()
            logger.info("Example 2 passed")
            example_large_batch_processing()
            logger.info("Example 3 passed")
            example_with_custom_progress()
            logger.info("Example 4 passed")
            example_error_handling()
            logger.info("Example 5 passed")
        except Exception as e:
            # Top-level boundary: the first failing example stops the run.
            logger.error("Example failed", error=str(e), model=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"))

    logger.info("Example script completed")
index_builder_example.py
"""
Example: Indexing module — schema provisioning for vector stores.
This example demonstrates the separation of *infrastructure* concerns
(index / container creation) from *application* concerns (document CRUD
and search). The workflow is:
1. Use an ``IndexBuilder`` to create the backend schema once (or whenever
the schema changes).
2. Use the matching ``VectorStore`` for all runtime document operations.
The three builders covered here:
- ``AzureAISearchIndexBuilder`` — Azure AI Search HNSW index
- ``CosmosDBIndexBuilder`` — Cosmos DB NoSQL vector container
- ``MongoDBIndexBuilder`` — MongoDB Atlas vector + text indexes
Required .env variables
-----------------------
Azure AI Search:
AZURE_AI_SEARCH_ENDPOINT — e.g. https://my-search.search.windows.net
AZURE_AI_SEARCH_API_KEY
Azure Cosmos DB:
AZURE_COSMOS_ENDPOINT — e.g. https://my-account.documents.azure.com:443/
AZURE_COSMOS_KEY
MongoDB Atlas:
MONGODB_CONNECTION_STRING — e.g. mongodb+srv://user:pass@cluster.mongodb.net/
MONGODB_DATABASE — target database name (default: rag_db)
MONGODB_COLLECTION — target collection name (default: documents)
Optional:
AZURE_AI_SEARCH_EMBEDDING_DIMENSION — default 1536
"""
import os
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_data.indexing import (
AzureAISearchIndexBuilder,
CosmosDBIndexBuilder,
MongoDBIndexBuilder,
)
from gmf_forge_ai_shared_core.observability import BasicLogger
# ── Environment & globals ────────────────────────────────────────────────────
# Load settings from the .env file next to this script.
env_path = Path(__file__).with_name(".env")
load_dotenv(env_path)

# Module-wide logger shared by every example below.
logger = BasicLogger(__name__)

# Corporate CA bundle location; the workspace root is four ``.parent`` hops
# up from this file.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT.joinpath("certs", "gmf_and_public_cas.pem")
def _ssl() -> Optional[str]:
    """Give the corporate SSL cert path as a string when the file exists, else ``None``."""
    if not CORPORATE_CERT.exists():
        return None
    return str(CORPORATE_CERT)
# ── Custom document type used across all examples ────────────────────────────
@dataclass
class PolicyDocument:
    """
    Policy-management document schema handed to ``AzureAISearchIndexBuilder``
    as ``document_type``.

    This is a standalone dataclass (it has no base class here) that declares
    the universal fields (id / content / embedding / timestamp / metadata)
    followed by policy-specific ones.

    All extra fields (beyond id / content / embedding / timestamp / metadata)
    will be registered as filterable fields in the respective index.
    NOTE(review): that registration happens inside the builders, which are
    not visible in this file — confirm against the builder implementation.

    Every field has a default, so ``PolicyDocument()`` is constructible with
    no arguments.
    """
    # --- universal fields ---
    id: str = ""
    content: str = ""
    # Defaults to None rather than an empty list (hence the type: ignore).
    embedding: list = None # type: ignore[assignment]
    timestamp: Optional[datetime] = None
    # Defaults to None rather than an empty dict (hence the type: ignore).
    metadata: dict = None # type: ignore[assignment]
    # --- custom policy-management fields ---
    department: str = ""
    policy_id: str = ""
    effective_date: Optional[datetime] = None
    version: str = ""
# ============================================================================
# Example 1: Azure AI Search index builder
# ============================================================================
def example_azure_ai_search_builder():
    """
    Provision an Azure AI Search HNSW index with full developer control
    over vector parameters and field mapping, then demonstrate schema
    migration via ``create_or_replace_index()``.

    Demonstrates
    ------------
    - ``create_index()`` — idempotent creation
    - ``index_exists()`` — existence check
    - ``list_indexes()`` — list all indexes on the service
    - ``create_or_replace_index()`` — destructive schema replacement
    - ``delete_index()`` — clean-up

    When the search endpoint or API key is missing, only the builder API is
    logged (via ``_demo_azure_builder_api``) and no Azure call is made.
    Once the index has been created, the clean-up ``delete_index()`` runs in
    a ``finally`` block so a failure in a later step does not leave the
    example index behind on the service; the original exception still
    propagates.
    """
    logger.info("Starting example", example="AzureAISearchIndexBuilder")
    endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
    api_key = os.getenv("AZURE_AI_SEARCH_API_KEY")
    embedding_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))
    if not endpoint or not api_key:
        logger.warning(
            "Azure AI Search credentials not configured",
            missing=["AZURE_AI_SEARCH_ENDPOINT", "AZURE_AI_SEARCH_API_KEY"],
            action="Showing builder API without executing Azure calls",
        )
        _demo_azure_builder_api(embedding_dim)
        return

    index_name = "indexing-example-policy-docs"

    # Step 1 — Infrastructure: provision the index
    builder = AzureAISearchIndexBuilder(
        endpoint=endpoint,
        api_key=api_key,
        index_name=index_name,
        embedding_dimension=embedding_dim,
        document_type=PolicyDocument,
        # HNSW parameters — tune for recall vs. memory trade-off
        hnsw_m=4,  # graph edges per node (default 4)
        hnsw_ef_construction=400,  # candidates during build (default 400)
        hnsw_ef_search=500,  # candidates at query time (default 500)
        metric="cosine",
        ssl_cert_path=_ssl(),
    )
    logger.info("Creating Azure AI Search index", index_name=index_name)
    builder.create_index()
    logger.info("Index created (or already existed)", index_name=index_name)

    try:
        # Step 2 — Verify existence and list
        exists = builder.index_exists()
        logger.info("Index existence check", index_name=index_name, exists=exists)
        all_indexes = builder.list_indexes()
        logger.info("All indexes on service", count=len(all_indexes), indexes=all_indexes)

        # Step 3 — Schema migration: replace the index
        logger.info(
            "Replacing index (schema migration scenario)",
            note="create_or_replace_index() drops + recreates the index",
        )
        builder.create_or_replace_index()
        logger.info("Index replaced", index_name=index_name)
    finally:
        # Step 4 — Clean up (also on failure of steps 2-3)
        builder.delete_index()
        logger.info("Index deleted (example clean-up)", index_name=index_name)

    # Step 5 — Application step (not executed — credentials vary)
    logger.info(
        "Application step pattern",
        note="After create_index(), use AzureAISearchVectorStore for document ops",
        code="store = AzureAISearchVectorStore(endpoint=..., index_name=..., api_key=...)",
    )
def _demo_azure_builder_api(embedding_dim: int) -> None:
    """Log a sample builder constructor call and its methods; no Azure request is made.

    Args:
        embedding_dim: Value interpolated into the ``embedding_dimension``
            line of the logged constructor snippet.
    """
    constructor_snippet = "\n".join(
        [
            "AzureAISearchIndexBuilder(",
            " endpoint='https://my-search.search.windows.net',",
            " api_key='...',",
            " index_name='policy_docs',",
            f" embedding_dimension={embedding_dim},",
            " document_type=PolicyDocument,",
            " hnsw_m=4,",
            " hnsw_ef_construction=400,",
            " hnsw_ef_search=500,",
            " metric='cosine',",
            ")",
        ]
    )
    builder_methods = [
        "builder.create_index() # idempotent",
        "builder.create_or_replace_index() # destructive replace",
        "builder.delete_index() # raises if not exists",
        "builder.index_exists() -> bool",
        "builder.list_indexes() -> List[str]",
    ]
    logger.info(
        "AzureAISearchIndexBuilder usage",
        constructor=constructor_snippet,
        methods=builder_methods,
    )
# ============================================================================
# Example 2: Cosmos DB index builder
# ============================================================================
def example_cosmos_db_builder():
    """
    Provision a Cosmos DB NoSQL container with a vector embedding policy
    and configurable indexing policy.

    When ``AZURE_COSMOS_ENDPOINT`` or ``AZURE_COSMOS_KEY`` is unset, only the
    builder API is logged and no Cosmos DB call is made.

    Demonstrates
    ------------
    - ``create_index()`` — creates database + container (idempotent)
    - ``index_exists()`` — container existence check
    - ``list_indexes()`` — list all containers in the database
    - Distance function and vector index type selection

    The example container is deleted in a ``finally`` block so that a failure
    in an intermediate step does not leave it provisioned.
    """
    logger.info("Starting example", example="CosmosDBIndexBuilder")
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    api_key = os.getenv("AZURE_COSMOS_KEY")
    if not endpoint or not api_key:
        logger.warning(
            "Cosmos DB credentials not configured",
            missing=["AZURE_COSMOS_ENDPOINT", "AZURE_COSMOS_KEY"],
            action="Showing builder API without executing Cosmos DB calls",
        )
        _demo_cosmos_builder_api()
        return

    database_name = "rag_db_example"
    container_name = "indexing_example_policy_docs"

    # ------------------------------------------------------------------ #
    # Step 1 — Infrastructure: provision the container                   #
    # ------------------------------------------------------------------ #
    builder = CosmosDBIndexBuilder(
        endpoint=endpoint,
        api_key=api_key,
        database_name=database_name,
        container_name=container_name,
        embedding_dimension=1536,
        distance_function="cosine",
        # quantizedFlat → lower memory; diskANN → higher recall on large datasets
        vector_index_type="quantizedFlat",
        partition_key="/id",
        # throughput=400, # omit for serverless accounts; set for provisioned RU/s accounts
        ssl_cert_path=_ssl(),
    )
    logger.info(
        "Creating Cosmos DB database and container",
        database=database_name,
        container=container_name,
    )
    builder.create_index()
    logger.info("Container created (or already existed)", container=container_name)

    # From here on the container exists, so clean-up must run on failure too.
    try:
        # -------------------------------------------------------------- #
        # Step 2 — Existence check and list                              #
        # -------------------------------------------------------------- #
        exists = builder.index_exists()
        logger.info("Container existence check", container=container_name, exists=exists)
        all_containers = builder.list_indexes()
        logger.info(
            "All containers in database",
            count=len(all_containers),
            containers=all_containers,
        )

        # -------------------------------------------------------------- #
        # Step 3 — Application step note                                 #
        # -------------------------------------------------------------- #
        logger.info(
            "Application step pattern",
            note="After create_index(), use AzureCosmosDBVectorStore for document ops",
            code=(
                "store = AzureCosmosDBVectorStore(\n"
                " endpoint=..., api_key=...,\n"
                f" database_name='{database_name}',\n"
                f" container_name='{container_name}',\n"
                " embedding_dimension=1536,\n"
                ")"
            ),
        )
    finally:
        # -------------------------------------------------------------- #
        # Step 4 — Clean up (delete container only, keep database)       #
        # -------------------------------------------------------------- #
        builder.delete_index()
        logger.info("Container deleted (example clean-up)", container=container_name)
def _demo_cosmos_builder_api() -> None:
    """Log a sample ``CosmosDBIndexBuilder`` call and its account prerequisite.

    Nothing is sent to Cosmos DB; both values are plain text for the log.
    """
    # Sample constructor call shown to the reader.
    constructor_snippet = (
        "CosmosDBIndexBuilder(\n"
        " endpoint='https://my-account.documents.azure.com:443/',\n"
        " api_key='...',\n"
        " database_name='rag_db',\n"
        " container_name='policy_docs',\n"
        " embedding_dimension=1536,\n"
        " distance_function='cosine', # cosine | euclidean | dotproduct\n"
        " vector_index_type='quantizedFlat', # quantizedFlat | diskANN\n"
        " partition_key='/id',\n"
        " # throughput=400, # provisioned accounts only; omit for serverless\n"
        ")"
    )
    # CLI command the account needs before vector search can be used.
    account_prerequisite = (
        "Enable NoSQL Vector Search on the account first:\n"
        " az cosmosdb update --resource-group <RG> "
        "--name <ACCOUNT> --capabilities EnableNoSQLVectorSearch"
    )
    logger.info(
        "CosmosDBIndexBuilder usage",
        constructor=constructor_snippet,
        prerequisite=account_prerequisite,
    )
# ============================================================================
# Example 3: MongoDB Atlas index builder
# ============================================================================
def example_mongodb_builder():
    """
    Provision an Atlas Vector Search index and a MongoDB ``$text`` index on
    a collection, then demonstrate updating the vector index via
    ``create_or_replace_index()``.

    When ``MONGODB_CONNECTION_STRING`` is unset, only the builder API is
    logged and no MongoDB call is made.

    Demonstrates
    ------------
    - ``create_index()`` — creates vector index + text index
    - ``index_exists()`` — vector index existence check
    - ``list_indexes()`` — list all Atlas search indexes on collection
    - ``create_or_replace_index()`` — replaces vector index (text index unchanged)
    - ``delete_index()`` — removes the vector index
    - ``list_text_indexes()`` — list standard MongoDB indexes
    - Custom filter fields from ``PolicyDocument`` dataclass

    The example vector index is deleted in a ``finally`` block so that a
    failure in an intermediate step does not leave it behind.
    """
    logger.info("Starting example", example="MongoDBIndexBuilder")
    connection_string = os.getenv("MONGODB_CONNECTION_STRING")
    database_name = os.getenv("MONGODB_DATABASE", "rag_db")
    collection_name = os.getenv("MONGODB_COLLECTION", "documents")
    if not connection_string:
        logger.warning(
            "MongoDB credentials not configured",
            missing=["MONGODB_CONNECTION_STRING"],
            action="Showing builder API without executing MongoDB calls",
        )
        _demo_mongodb_builder_api(database_name, collection_name)
        return

    vector_index_name = "indexing_example_vector_index"

    # ------------------------------------------------------------------ #
    # Step 1 — Infrastructure: provision the indexes                     #
    # ------------------------------------------------------------------ #
    builder = MongoDBIndexBuilder(
        connection_string=connection_string,
        database_name=database_name,
        collection_name=collection_name,
        embedding_dimension=1536,
        document_type=PolicyDocument,  # extra fields → Atlas filter fields
        vector_index_name=vector_index_name,
        similarity="cosine",
        # Extra filter paths beyond dataclass inference (optional)
        extra_filter_paths=["metadata.source", "metadata.category"],
        ssl_cert_path=_ssl(),
    )
    logger.info(
        "Creating Atlas Vector Search index + text index",
        database=database_name,
        collection=collection_name,
        vector_index=vector_index_name,
    )
    builder.create_index()
    logger.info("Indexes created (or already existed)")

    # From here on the vector index exists, so clean-up must run on failure too.
    try:
        # -------------------------------------------------------------- #
        # Step 2 — Existence check and list                              #
        # -------------------------------------------------------------- #
        exists = builder.index_exists()
        logger.info("Vector index existence check", vector_index=vector_index_name, exists=exists)
        all_vector_indexes = builder.list_indexes()
        logger.info("All Atlas vector search indexes", indexes=all_vector_indexes)
        text_indexes = builder.list_text_indexes()
        logger.info("All standard MongoDB indexes", indexes=text_indexes)

        # -------------------------------------------------------------- #
        # Step 3 — Schema migration: replace the vector index            #
        # -------------------------------------------------------------- #
        logger.info(
            "Replacing vector index (schema migration scenario)",
            note="Text index is preserved — only the vector index is replaced",
        )
        builder.create_or_replace_index()
        logger.info("Vector index replaced", vector_index=vector_index_name)

        # -------------------------------------------------------------- #
        # Step 4 — Application step note                                 #
        # -------------------------------------------------------------- #
        logger.info(
            "Application step pattern",
            note="After create_index(), use MongoDBVectorStore for document ops",
            code=(
                "store = MongoDBVectorStore(\n"
                " connection_string=...,\n"
                f" database_name='{database_name}',\n"
                f" collection_name='{collection_name}',\n"
                f" vector_index_name='{vector_index_name}',\n"
                " embedding_dimension=1536,\n"
                ")"
            ),
            note2=(
                "Atlas indexes take a short time to become ready. "
                "Poll list_indexes() or check the Atlas UI before running searches."
            ),
        )
    finally:
        # -------------------------------------------------------------- #
        # Step 5 — Clean up                                              #
        # -------------------------------------------------------------- #
        builder.delete_index()
        logger.info("Vector index deleted (example clean-up)", vector_index=vector_index_name)
def _demo_mongodb_builder_api(database_name: str, collection_name: str) -> None:
    """Log a sample ``MongoDBIndexBuilder`` call; no MongoDB request is made.

    Args:
        database_name: Database name interpolated into the sample snippet.
        collection_name: Collection name interpolated into the sample snippet.
    """
    # Sample constructor call shown to the reader.
    constructor_snippet = (
        "MongoDBIndexBuilder(\n"
        " connection_string='mongodb+srv://user:pass@cluster.mongodb.net/',\n"
        f" database_name='{database_name}',\n"
        f" collection_name='{collection_name}',\n"
        " embedding_dimension=1536,\n"
        " document_type=PolicyDocument, # extra fields → Atlas filter fields\n"
        " vector_index_name='vector_index',\n"
        " similarity='cosine', # cosine | euclidean | dotProduct\n"
        " extra_filter_paths=['metadata.source'],\n"
        ")"
    )
    # Reminder that index provisioning is not instantaneous.
    provisioning_note = (
        "Indexes created via create_search_index() are provisioned asynchronously. "
        "Allow a short wait before running $vectorSearch queries."
    )
    logger.info(
        "MongoDBIndexBuilder usage",
        constructor=constructor_snippet,
        atlas_note=provisioning_note,
    )
# ============================================================================
# Example 4: Builder vs. vector store — side-by-side pattern
# ============================================================================
def example_separation_pattern():
    """
    Log a side-by-side comparison of the historical pattern (index creation
    bundled into the vector store) and the builder pattern.

    Documentation-only: the function only emits log records.
    """
    logger.info("Starting example", example="Separation pattern explanation")

    # --- Historical pattern ------------------------------------------------
    old_code = (
        "store = AzureAISearchVectorStore(...)\n"
        "store.create_index('policy_docs') # ← coupled: schema + data ops in one class"
    )
    old_problem = (
        "- Index parameters (HNSW m, ef_construction, metric) are not exposed\n"
        "- Every application instance could inadvertently mutate the schema\n"
        "- DevOps pipeline cannot provision schema separately from app deployment"
    )
    logger.info(
        "OLD pattern — index creation inside vector store",
        code=old_code,
        problem=old_problem,
    )

    # --- Builder pattern ---------------------------------------------------
    provisioning_step = (
        "# Infrastructure pipeline (CI/CD / one-time setup)\n"
        "builder = AzureAISearchIndexBuilder(\n"
        " endpoint=..., api_key=..., index_name='policy_docs',\n"
        " hnsw_m=4, hnsw_ef_construction=400, metric='cosine',\n"
        ")\n"
        "builder.create_index()"
    )
    runtime_step = (
        "# Application runtime\n"
        "store = AzureAISearchVectorStore(\n"
        " endpoint=..., api_key=..., index_name='policy_docs',\n"
        ")\n"
        "store.add_documents(chunks)\n"
        "results = store.search(query_embedding=embedding, top_k=5)"
    )
    pattern_benefits = [
        "Explicit HNSW tuning (m, ef_construction, ef_search, metric)",
        "Schema changes tracked separately in version control",
        "Zero risk of accidental index mutation at runtime",
        "Same builder interface across Azure AI Search, Cosmos DB, MongoDB",
    ]
    logger.info(
        "NEW pattern — builder owns provisioning, store owns document ops",
        step_1=provisioning_step,
        step_2=runtime_step,
        benefits=pattern_benefits,
    )
# ============================================================================
# main
# ============================================================================
def main():
    """Run all indexing examples in order and log the total elapsed time."""
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system clock is adjusted while examples run.
    start = time.perf_counter()
    logger.info("Indexing module examples starting")

    example_azure_ai_search_builder()
    example_cosmos_db_builder()
    example_mongodb_builder()
    example_separation_pattern()

    elapsed_ms = (time.perf_counter() - start) * 1000
    logger.info(
        "All indexing examples completed",
        duration_ms=round(elapsed_ms, 1),
        builders_demonstrated=3,
        pattern="builder (schema) + vector_store (document ops)",
    )
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
mongodb_vector_store_example.py
"""
Example: MongoDB Atlas Vector Search Vector Store (Module 7b)
Demonstrates how to use MongoDBVectorStore for vector search in a RAG pipeline:
1. Connect to a MongoDB Atlas cluster with Vector Search enabled
2. Index documents with real Azure OpenAI embeddings
3. Run vector, keyword, and hybrid searches
4. Use metadata filters to narrow results
5. Fetch, update, and delete documents
Index creation is handled by MongoDBIndexBuilder::
builder = MongoDBIndexBuilder(connection_string, database_name, collection_name, ...)
builder.create_index() # provisions vector search index + text index
Required .env variables (entries that list a default may be omitted):
MONGODB_CONNECTION_STRING mongodb+srv://user:pass@cluster.mongodb.net/
MONGODB_DATABASE e.g. rag_db (default: rag_db)
MONGODB_COLLECTION e.g. tech_articles (default: tech_articles)
MONGODB_VECTOR_INDEX e.g. vector_index (default: vector_index)
AZURE_OPENAI_ENDPOINT
AZURE_OPENAI_API_KEY
AZURE_OPENAI_EMBEDDING_MODEL deployment name for your embedding model
AZURE_OPENAI_EMBEDDING_MODEL_VERSION (default: 2024-02-01)
Optional:
AZURE_AI_SEARCH_EMBEDDING_DIMENSION (default: 1536)
"""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
import nltk
from gmf_forge_ai_data.chunkers import SemanticChunker
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import Document, MongoDBVectorStore
from gmf_forge_ai_data.indexing import MongoDBIndexBuilder
# ── Environment ────────────────────────────────────────────────────────────────
# Load variables from the .env file located next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Module-level structured logger shared by every example below.
logger = BasicLogger(__name__)
# WORKSPACE_ROOT is reached by four ``.parent`` hops from this file.
# CORPORATE_CERT is the CA bundle that get_ssl_cert_path() returns when the
# file exists on disk.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# ── Document schema ────────────────────────────────────────────────────────────
@dataclass
class TechArticle(Document):
    """Technology article with domain-specific fields.

    Extends ``Document`` with three extra fields, each with a default value.
    """
    # Topic label; the sample data uses values such as "deep_learning" and "nlp".
    category: str = ""
    # Authoring team or person.
    author: str = ""
    # Publication year; defaults to 0. Used as a filter field in the examples.
    published_year: int = 0
@dataclass
class ResearchPaper(Document):
    """Academic research paper with citation and classification fields.

    Extends ``Document`` with three extra fields, each with a default value.
    In the RAG pipeline example one instance is created per text chunk.
    """
    field: str = "" # research field e.g. "nlp", "computer_vision", "rag"
    institution: str = "" # publishing institution or university
    year: int = 0 # publication year
# ── Config helpers ─────────────────────────────────────────────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the ``CORPORATE_CERT`` path as a string, or ``None`` if the file is absent."""
    if not CORPORATE_CERT.exists():
        return None
    return str(CORPORATE_CERT)
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build an ``AzureOpenAIEmbeddings`` client from environment variables.

    Reads ``AZURE_OPENAI_ENDPOINT``, ``AZURE_OPENAI_API_KEY`` and
    ``AZURE_OPENAI_EMBEDDING_MODEL`` (all required) plus
    ``AZURE_OPENAI_EMBEDDING_MODEL_VERSION`` (defaults to ``2024-02-01``).

    Returns:
        The configured embedder, or ``None`` when any required variable is
        unset or empty. In that case an error naming the unset variables is
        logged.
    """
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    deployment = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL")
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")

    # Report only the variables that are actually unset, so the log points
    # at the real gap instead of always listing all three.
    required = {
        "AZURE_OPENAI_ENDPOINT": endpoint,
        "AZURE_OPENAI_API_KEY": api_key,
        "AZURE_OPENAI_EMBEDDING_MODEL": deployment,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        logger.error(
            "Azure OpenAI embedding credentials missing",
            missing_vars=missing,
        )
        return None

    ssl_cert = get_ssl_cert_path()
    if ssl_cert:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)
    embedder = AzureOpenAIEmbeddings(
        endpoint=endpoint,
        api_key=api_key,
        deployment_name=deployment,
        api_version=api_version,
        ssl_cert_path=ssl_cert,
    )
    logger.info("Initialized embeddings", deployment=deployment)
    return embedder
# ── Sample documents ───────────────────────────────────────────────────────────
def build_sample_documents(embedder: AzureOpenAIEmbeddings) -> list:
    """Create the eight sample ``TechArticle`` objects and attach embeddings.

    Args:
        embedder: Client whose ``embed_batch`` is called once with every
            article's ``content``.

    Returns:
        The articles, each with its ``embedding`` attribute set.
    """
    articles = [
        TechArticle(
            id="neural_networks_intro",
            content=(
                "Neural networks are computing systems loosely inspired by the biological "
                "neural networks that constitute animal brains. They consist of layers of "
                "interconnected nodes, or neurons, which process data and learn patterns. "
                "Deep neural networks have many hidden layers, enabling them to learn "
                "complex representations from raw input data. Modern architectures have "
                "achieved state-of-the-art results in vision, language, and speech tasks."
            ),
            category="deep_learning",
            author="AI Research Team",
            published_year=2023,
            metadata={"difficulty": "beginner", "tags": ["neural_networks", "AI"]},
        ),
        TechArticle(
            id="backpropagation_explained",
            content=(
                "Backpropagation is the primary algorithm used to train neural networks. "
                "It computes the gradient of the loss function with respect to every "
                "weight by applying the chain rule of calculus. The gradients are then "
                "used by an optimisation algorithm — such as stochastic gradient descent "
                "or Adam — to update the weights. Efficient GPU implementations make "
                "backpropagation tractable even for networks with billions of parameters."
            ),
            category="deep_learning",
            author="ML Theory Group",
            published_year=2023,
            metadata={"difficulty": "intermediate", "tags": ["backpropagation", "training"]},
        ),
        TechArticle(
            id="transformers_architecture",
            content=(
                "The Transformer architecture, introduced by Vaswani et al. in 2017, "
                "relies entirely on self-attention mechanisms instead of recurrent layers. "
                "Multi-head attention allows the model to attend to information from "
                "different representation subspaces simultaneously. Positional encodings "
                "inject sequence-order information. Transformers power large language "
                "models such as GPT-4, BERT, and T5."
            ),
            category="nlp",
            author="NLP Research Team",
            published_year=2022,
            metadata={"difficulty": "intermediate", "tags": ["transformers", "attention"]},
        ),
        TechArticle(
            id="vector_databases",
            content=(
                "Vector databases are purpose-built systems for storing and querying "
                "high-dimensional embedding vectors. They use approximate nearest "
                "neighbour (ANN) algorithms — such as HNSW, IVF, or PQ — to return "
                "semantically similar documents in milliseconds. Popular options include "
                "Pinecone, Weaviate, Qdrant, Azure AI Search, Cosmos DB, and MongoDB "
                "Atlas. They are essential infrastructure for RAG applications."
            ),
            category="infrastructure",
            author="Data Engineering Team",
            published_year=2024,
            metadata={"difficulty": "beginner", "tags": ["vector_db", "RAG"]},
        ),
        TechArticle(
            id="retrieval_augmented_generation",
            content=(
                "Retrieval Augmented Generation (RAG) combines a retrieval component with "
                "a generative language model. When a user poses a question, a retriever "
                "fetches relevant documents from a knowledge base; the generator then "
                "conditions its response on both the question and the retrieved context. "
                "RAG reduces hallucinations, supports domain-specific knowledge, and "
                "enables the model to cite sources."
            ),
            category="nlp",
            author="AI Research Team",
            published_year=2024,
            metadata={"difficulty": "intermediate", "tags": ["RAG", "LLM"]},
        ),
        TechArticle(
            id="convolutional_networks",
            content=(
                "Convolutional Neural Networks (CNNs) use learnable convolutional filters "
                "to extract local spatial features from grid-structured data like images. "
                "Pooling layers reduce spatial dimensions while retaining dominant features. "
                "Architectures such as ResNet, VGG, and EfficientNet have set benchmarks "
                "on ImageNet. CNNs are also applied to 1-D signals in speech processing "
                "and time-series analysis."
            ),
            category="computer_vision",
            author="Vision Research Team",
            published_year=2022,
            metadata={"difficulty": "intermediate", "tags": ["CNN", "computer_vision"]},
        ),
        TechArticle(
            id="reinforcement_learning",
            content=(
                "Reinforcement learning (RL) trains agents to maximise cumulative reward "
                "through interaction with an environment. An agent selects actions, "
                "observes state transitions, and receives reward signals. Deep RL combines "
                "neural networks with RL algorithms — exemplified by DQN, PPO, and SAC. "
                "AlphaGo and AlphaFold are landmark applications. RL has also been used "
                "to fine-tune language models through RLHF."
            ),
            category="deep_learning",
            author="RL Research Team",
            published_year=2023,
            metadata={"difficulty": "advanced", "tags": ["reinforcement_learning", "agents"]},
        ),
        TechArticle(
            id="mongodb_atlas_vector_search",
            content=(
                "MongoDB Atlas Vector Search extends the Atlas platform with ANN vector "
                "search support using the HNSW algorithm. Developers define a vector "
                "index on a collection field (e.g. 'embedding') through the Atlas UI or "
                "Atlas CLI and then query it with the $vectorSearch aggregation stage. "
                "Metadata pre-filtering, hybrid search, and score projection are "
                "available out-of-the-box, making it straightforward to build RAG "
                "pipelines on existing MongoDB data."
            ),
            category="infrastructure",
            author="Data Engineering Team",
            published_year=2024,
            metadata={"difficulty": "beginner", "tags": ["mongodb", "vector_db", "RAG"]},
        ),
    ]

    logger.info("Generating embeddings", count=len(articles))
    # One batch call covers every article; vectors come back in input order
    # (assumed from the pairwise zip below — confirm against embed_batch).
    vectors = embedder.embed_batch([article.content for article in articles])
    for article, vector in zip(articles, vectors):
        article.embedding = vector
    logger.info("Embeddings generated", dimension=len(vectors[0]))
    return articles
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Index documents
# ══════════════════════════════════════════════════════════════════════════════
def example_index_documents(store: MongoDBVectorStore, docs: list) -> None:
    """Add ``docs`` to ``store`` and log the add result and the store count."""
    logger.info("Starting example", example="1: Index Documents")
    logger.info(
        "Scenario",
        problem="Documents must be embedded and stored before they can be retrieved",
        solution="embed_batch() all content, then add_documents() to the store",
    )
    num_added = store.add_documents(docs)
    total_in_store = store.count()
    logger.info("Indexed documents", added=num_added, total=total_in_store)
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: Vector search
# ══════════════════════════════════════════════════════════════════════════════
def example_vector_search(store: MongoDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Embed a fixed query and log the top three vector-search matches."""
    logger.info("Starting example", example="2: Vector Search")
    logger.info(
        "Scenario",
        problem="Find semantically similar documents to a natural-language query",
        solution="Embed the query, then search with search_type='vector'",
    )
    query_text = "neural network deep learning backpropagation"
    logger.info("Query", text=query_text)
    hits = store.search(
        query_embedding=embedder.embed_text(query_text),
        top_k=3,
        search_type="vector",
    )
    for hit in hits:
        # The store is expected to return TechArticle documents — confirm
        # against the document_type it was built with.
        article: TechArticle = hit.document
        logger.info(
            "Vector result",
            rank=hit.rank,
            id=article.id,
            score=round(hit.score, 4),
            category=article.category,
            year=article.published_year,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Keyword search
# ══════════════════════════════════════════════════════════════════════════════
def example_keyword_search(store: MongoDBVectorStore) -> None:
    """Run a text-only search with a fixed query and log the top three matches."""
    logger.info("Starting example", example="3: Keyword Search")
    logger.info(
        "Scenario",
        problem="Find documents containing specific terms (exact lexical match)",
        solution="Pass query text only with search_type='keyword'",
    )
    query_text = "transformer attention BERT GPT"
    logger.info("Query", text=query_text)
    hits = store.search(query=query_text, top_k=3, search_type="keyword")
    for hit in hits:
        article: TechArticle = hit.document
        logger.info(
            "Keyword result",
            rank=hit.rank,
            id=article.id,
            score=round(hit.score, 4),
            category=article.category,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 4: Hybrid search
# ══════════════════════════════════════════════════════════════════════════════
def example_hybrid_search(store: MongoDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Search with both query text and its embedding; log the top three matches."""
    logger.info("Starting example", example="4: Hybrid Search")
    logger.info(
        "Scenario",
        problem="Balance semantic relevance with precise keyword matching",
        solution="Provide both query text and query_embedding with search_type='hybrid'",
    )
    query_text = "RAG vector database infrastructure"
    logger.info("Query", text=query_text)
    query_vector = embedder.embed_text(query_text)
    search_kwargs = {
        "query": query_text,
        "query_embedding": query_vector,
        "top_k": 3,
        "search_type": "hybrid",
    }
    for hit in store.search(**search_kwargs):
        article: TechArticle = hit.document
        logger.info(
            "Hybrid result",
            rank=hit.rank,
            id=article.id,
            score=round(hit.score, 4),
            category=article.category,
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 5: Filtered vector search
# ══════════════════════════════════════════════════════════════════════════════
def example_filtered_search(store: MongoDBVectorStore, embedder: AzureOpenAIEmbeddings) -> None:
    """Run a vector search restricted by a ``published_year >= 2024`` filter."""
    logger.info("Starting example", example="5: Filtered Vector Search")
    logger.info(
        "Scenario",
        problem="Restrict results to documents matching a metadata condition",
        solution="Pass a filters dict alongside query_embedding",
    )
    query_text = "machine learning algorithms"
    logger.info("Query", text=query_text, filter="published_year >= 2024")
    query_vector = embedder.embed_text(query_text)
    # The filter value is an (operator, operand) tuple keyed by field name.
    year_filter = {"published_year": (">=", 2024)}
    hits = store.search(
        query_embedding=query_vector,
        top_k=5,
        filters=year_filter,
        search_type="vector",
    )
    for hit in hits:
        article: TechArticle = hit.document
        logger.info(
            "Filtered result",
            rank=hit.rank,
            id=article.id,
            year=article.published_year,
            score=round(hit.score, 4),
        )
# ══════════════════════════════════════════════════════════════════════════════
# Example 6: CRUD operations
# ══════════════════════════════════════════════════════════════════════════════
def example_crud(store: MongoDBVectorStore) -> None:
    """Demonstrate fetching, updating and deleting documents in ``store``.

    Fetches ``transformers_architecture``, sets ``highlighted`` in its
    metadata, re-fetches it, then deletes ``convolutional_networks``.
    A missing document is logged as a warning and the update step is skipped
    instead of failing on attribute access of ``None``.
    """
    logger.info("Starting example", example="6: CRUD — get / update / delete")

    # get
    doc_id = "transformers_architecture"
    fetched: Optional[TechArticle] = store.get_document(doc_id)
    if fetched:
        logger.info("Fetched document", id=fetched.id, author=fetched.author,
                    category=fetched.category)

        # update metadata
        store.update_document(doc_id, metadata={"highlighted": True})
        updated: Optional[TechArticle] = store.get_document(doc_id)
        if updated:
            logger.info("After metadata update", id=updated.id, metadata=updated.metadata)
        else:
            logger.warning("Document not found after update", id=doc_id)
    else:
        # Nothing to update; continue so the delete step is still shown.
        logger.warning("Document not found — skipping update", id=doc_id)

    # delete
    deleted = store.delete_documents(["convolutional_networks"])
    logger.info("Deleted documents", deleted=deleted, remaining=store.count())
# ══════════════════════════════════════════════════════════════════════════════
# Example 7: End-to-End RAG Pipeline (Chunk → Embed → Index → Retrieve)
# ══════════════════════════════════════════════════════════════════════════════
def example_rag_pipeline(embedder: AzureOpenAIEmbeddings) -> None:
    """Demonstrate a complete RAG ingestion pipeline using SemanticChunker and ResearchPaper schema.

    Steps: chunk three in-memory papers, embed each paper's chunks, index
    them into the ``research_papers`` collection, then run one vector query.
    Each stage runs in its own tracing span under a ``rag_pipeline`` trace.

    Args:
        embedder: Client used for ``embed_batch`` (chunks) and ``embed_text``
            (query).

    Returns early, after logging an error, when ``MONGODB_CONNECTION_STRING``
    is unset. The pipeline store is closed in a ``finally`` block so that it
    is released when a stage raises.
    """
    logger.info("Starting example", example="7: End-to-End RAG Pipeline")
    logger.info("Scenario",
                description="Ingest multiple research papers through the full RAG pipeline",
                pipeline="papers → SemanticChunker → embed_batch → add_documents(ResearchPaper) → vector_search",
                why="SemanticChunker groups sentences at natural topic boundaries, maximising chunk coherence",
                note="ResearchPaper index schema differs from TechArticle — demonstrates separate collection per domain")

    # ── Dedicated store for research papers index ────────────────────────────
    connection_string = os.getenv("MONGODB_CONNECTION_STRING")
    database_name = os.getenv("MONGODB_DATABASE", "rag_db")
    vector_index_name = os.getenv("MONGODB_VECTOR_INDEX", "vector_index")
    embedding_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))
    if not connection_string:
        logger.error("MONGODB_CONNECTION_STRING not set — skipping RAG pipeline example")
        return

    MongoDBIndexBuilder(
        connection_string=connection_string,
        database_name=database_name,
        collection_name="research_papers",
        embedding_dimension=embedding_dim,
        vector_index_name=vector_index_name,
        ssl_cert_path=get_ssl_cert_path(),
    ).create_index()
    pipeline_store = MongoDBVectorStore(
        connection_string=connection_string,
        database_name=database_name,
        collection_name="research_papers",
        vector_index_name=vector_index_name,
        embedding_dimension=embedding_dim,
        document_type=ResearchPaper,
        ssl_cert_path=get_ssl_cert_path(),
    )

    # Everything below uses the open store; close it even when a stage fails.
    try:
        logger.info("MongoDB pipeline store",
                    collection="research_papers", document_type="ResearchPaper",
                    custom_fields=["field", "institution", "year"],
                    note="Separate collection from 'tech_articles' — each domain gets its own index schema")

        # ── Source papers (each with its own field, institution, and year) ──
        papers = [
            {
                "id": "paper_rag_survey",
                "field": "nlp",
                "institution": "stanford_nlp",
                "year": 2024,
                "text": (
                    "Retrieval-augmented generation combines parametric knowledge stored in large language "
                    "model weights with non-parametric knowledge retrieved from an external corpus at "
                    "inference time. This survey examines 87 RAG systems published between 2020 and 2024, "
                    "categorising them along three axes: retriever architecture, integration strategy, and "
                    "evaluation methodology. Dense passage retrieval using bi-encoder models remains the most "
                    "popular retriever, though recent work on late-interaction models such as ColBERT shows "
                    "promising improvements in recall. Integration strategies range from simple prompt "
                    "concatenation to cross-attention fusion between retriever and generator representations. "
                    "The survey identifies chunk granularity and overlap as under-explored hyper-parameters "
                    "that significantly affect downstream answer quality."
                ),
            },
            {
                "id": "paper_vector_indexing",
                "field": "information_retrieval",
                "institution": "microsoft_research",
                "year": 2025,
                "text": (
                    "Approximate nearest-neighbour search is the computational backbone of vector databases. "
                    "This paper benchmarks four index families — HNSW, IVF-PQ, ScaNN, and DiskANN — on "
                    "a 100-million-vector dataset drawn from production embedding workloads. HNSW achieves "
                    "the highest recall at low latency but consumes the most memory due to its graph structure. "
                    "IVF-PQ compresses vectors via product quantisation, trading recall for a 10x reduction "
                    "in memory. DiskANN extends HNSW to disk-resident data, making billion-scale search "
                    "feasible on commodity hardware. The authors propose a hybrid index that dynamically "
                    "routes queries to HNSW for high-precision needs and IVF-PQ for cost-sensitive workloads."
                ),
            },
            {
                "id": "paper_chunking_strategies",
                "field": "rag",
                "institution": "platform_ai_lab",
                "year": 2025,
                "text": (
                    "Chunking strategy is a critical but often overlooked component of RAG pipelines. "
                    "This paper evaluates five chunking approaches — fixed-size, recursive, sentence-based, "
                    "semantic, and document-structure-aware — across four knowledge-intensive QA benchmarks. "
                    "Semantic chunking, which groups sentences by topic similarity, achieves the best answer "
                    "accuracy on long-form documents such as technical manuals and legal filings. "
                    "Sentence-based chunking performs competitively on short, well-structured documents like "
                    "FAQ pages and knowledge-base articles. Recursive chunking offers a robust general-purpose "
                    "baseline with the lowest variance across domains. The study recommends tuning chunk size "
                    "between 300 and 500 characters with 10-15 percent overlap for most production deployments."
                ),
            },
        ]
        logger.info("Source papers", count=len(papers),
                    fields=[p["field"] for p in papers],
                    institutions=[p["institution"] for p in papers])

        query = "What chunking strategy works best for RAG pipelines?"
        tracer = get_tracer()
        with tracer.trace("rag_pipeline", input=query) as trace:
            with trace.span("chunking") as span:
                # ── Step 1: Chunk ────────────────────────────────────────
                logger.info("Step 1 - Chunking",
                            strategy="SemanticChunker", max_chunk_size=500, min_chunk_size=80,
                            tokenizer="nltk.sent_tokenize")
                # Fetch the tokenizer data that nltk.sent_tokenize relies on.
                nltk.download("punkt_tab", quiet=True)
                chunker = SemanticChunker(
                    sentence_tokenizer=nltk.sent_tokenize,
                    max_chunk_size=500,
                    min_chunk_size=80,
                )
                all_chunks_per_paper: list = []
                for paper in papers:
                    chunks = chunker.chunk(
                        paper["text"],
                        metadata={"source": paper["id"], "field": paper["field"]},
                    )
                    logger.info("Chunked paper",
                                paper_id=paper["id"], field=paper["field"], institution=paper["institution"],
                                num_chunks=len(chunks),
                                # max(..., 1) avoids division by zero for a paper with no chunks
                                avg_chars=sum(len(c.text) for c in chunks) // max(len(chunks), 1))
                    all_chunks_per_paper.append((paper, chunks))
                span.set_output({"documents": len(papers), "total_chunks": sum(len(c) for _, c in all_chunks_per_paper)})

            # ── Step 2: Embed ────────────────────────────────────────────
            with trace.span("embedding") as span:
                all_docs: list = []
                for paper, chunks in all_chunks_per_paper:
                    texts = [c.text for c in chunks]
                    logger.info("Step 2 - Embedding", paper_id=paper["id"], chunks=len(texts))
                    embeddings = embedder.embed_batch(texts)
                    # ── Build typed documents ────────────────────────────
                    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
                        all_docs.append(
                            ResearchPaper(
                                id=f"{paper['id']}_chunk_{i}",
                                content=chunk.text,
                                embedding=emb,
                                field=paper["field"],
                                institution=paper["institution"],
                                year=paper["year"],
                                metadata={**chunk.metadata, "chunk_index": i,
                                          "char_start": chunk.start_pos, "char_end": chunk.end_pos},
                            )
                        )
                span.set_output({"chunks": len(all_docs)})

            # ── Step 3: Index ────────────────────────────────────────────
            with trace.span("indexing") as span:
                logger.info("Step 3 - Indexing", document_type="ResearchPaper", total_chunks=len(all_docs))
                # Embeddings were computed in step 2, so the store must not regenerate them.
                pipeline_store.add_documents(all_docs, generate_embeddings=False)
                logger.info("Chunks indexed", count=len(all_docs))
                span.set_output({"chunks_indexed": len(all_docs)})

            # ── Step 4: Retrieve ─────────────────────────────────────────
            with trace.span("retrieval", input=query) as span:
                logger.info("Step 4 - Retrieval", query=query, top_k=3)
                query_embedding = embedder.embed_text(query)
                results = pipeline_store.search(query_embedding=query_embedding, top_k=3)
                logger.info("Retrieved chunks", returned=len(results))
                for r in results:
                    # Distinct name: `paper` above refers to the source dicts.
                    retrieved: ResearchPaper = r.document
                    snippet = retrieved.content[:100].replace("\n", " ")
                    logger.info("Retrieved chunk",
                                rank=r.rank, id=retrieved.id, score=round(r.score, 4),
                                field=retrieved.field, institution=retrieved.institution, year=retrieved.year,
                                snippet=snippet)
                span.set_output({"results": len(results), "top_score": round(results[0].score, 4) if results else 0})

            trace.set_output({"chunks_indexed": len(all_docs), "query_results": len(results)})
    finally:
        pipeline_store.close()

    logger.info("RAG pipeline complete",
                chunks_indexed=len(all_docs),
                query_results=len(results),
                tip="Filter by field='rag' or year=2025 to narrow research paper search results")
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main() -> None:
    """Run the MongoDB Atlas vector-store examples end to end.

    Steps:
        1. Build the embedder; log an error and return when it is unavailable.
        2. Read the MongoDB settings from the environment; log an error and
           return when ``MONGODB_CONNECTION_STRING`` is not set.
        3. Call ``MongoDBIndexBuilder(...).create_index()`` and open a
           ``MongoDBVectorStore`` typed with ``TechArticle``.
        4. Run each example in order. A failing example is logged and does
           not prevent the remaining ones from running.
        5. Close the store (always, even on an unexpected failure) and log a
           summary plus production tips.
    """
    logger.info("MONGODB ATLAS VECTOR SEARCH EXAMPLES",
                note="Demonstrating vector, keyword, hybrid search and CRUD with MongoDBVectorStore")

    # ── Embedder (required) ──────────────────────────────────────────────────
    embedder = get_embedder()
    if not embedder:
        logger.error("Cannot run examples without embedder — check .env")
        return

    # ── MongoDB connection ───────────────────────────────────────────────────
    connection_string = os.getenv("MONGODB_CONNECTION_STRING")
    database_name = os.getenv("MONGODB_DATABASE", "rag_db")
    collection_name = os.getenv("MONGODB_COLLECTION", "tech_articles")
    vector_index_name = os.getenv("MONGODB_VECTOR_INDEX", "vector_index")
    if not connection_string:
        logger.error(
            "MONGODB_CONNECTION_STRING not set in .env",
            required_vars=["MONGODB_CONNECTION_STRING"],
            optional_vars=["MONGODB_DATABASE", "MONGODB_COLLECTION", "MONGODB_VECTOR_INDEX"],
        )
        return

    # NOTE(review): the dimension is read from the Azure AI Search variable even
    # though this is the MongoDB example — presumably shared on purpose; confirm.
    embedding_dim = int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536"))
    # Resolve the certificate once; the builder and the store use the same value.
    ssl_cert_path = get_ssl_cert_path()

    logger.info("Connecting to MongoDB Atlas",
                database=database_name, collection=collection_name,
                vector_index=vector_index_name)
    MongoDBIndexBuilder(
        connection_string=connection_string,
        database_name=database_name,
        collection_name=collection_name,
        embedding_dimension=embedding_dim,
        vector_index_name=vector_index_name,
        ssl_cert_path=ssl_cert_path,
    ).create_index()
    store = MongoDBVectorStore(
        connection_string=connection_string,
        database_name=database_name,
        collection_name=collection_name,
        vector_index_name=vector_index_name,
        embedding_dimension=embedding_dim,
        document_type=TechArticle,
        ssl_cert_path=ssl_cert_path,
    )

    # Everything that touches the open store runs under try/finally so the
    # connection is released even if counting or document building fails.
    try:
        logger.info("Connected", documents=store.count())

        # ── Build documents ───────────────────────────────────────────────────
        docs = build_sample_documents(embedder)

        # ── Run examples ──────────────────────────────────────────────────────
        # Each entry is (label, zero-argument callable). The RAG pipeline example
        # receives only the embedder, not the shared store.
        examples = [
            ("1: Index Documents", lambda: example_index_documents(store, docs)),
            ("2: Vector Search", lambda: example_vector_search(store, embedder)),
            ("3: Keyword Search", lambda: example_keyword_search(store)),
            ("4: Hybrid Search", lambda: example_hybrid_search(store, embedder)),
            ("5: Filtered Search", lambda: example_filtered_search(store, embedder)),
            ("6: CRUD", lambda: example_crud(store)),
            ("7: RAG Pipeline", lambda: example_rag_pipeline(embedder)),
        ]
        for name, fn in examples:
            logger.info("")
            try:
                fn()
            except RuntimeError as exc:
                logger.error("Example failed", example=name, error=str(exc))
            except Exception as exc:
                # Top-level boundary of a demo script: log and keep going.
                logger.error("Example failed with unexpected error", example=name, error=str(exc))
        logger.info("")
    finally:
        store.close()

    logger.info("SUMMARY")
    logger.info("MongoDBVectorStore",
                description="Vector search on MongoDB Atlas using $vectorSearch aggregation",
                search_types=["vector", "keyword", "hybrid"],
                filtering="Pass filters dict to any search call",
                crud="get_document / update_document / delete_documents",
                rag_pipeline="Chunk → embed_batch → add_documents → vector_search")
    logger.info("Production Tips",
                tip_1="Use MongoDBIndexBuilder to provision vector search and text indexes before first use — keeps schema management separate from document operations",
                tip_2="Index name in .env (MONGODB_VECTOR_INDEX) must match the vector_index_name passed to MongoDBIndexBuilder",
                tip_3="Use hybrid search as the default — best balance of semantic + lexical recall",
                tip_4="Keep embedding_dimension consistent with your Azure OpenAI deployment",
                tip_5="Reuse the embedder instance across examples to avoid re-initialising the HTTP client",
                tip_6="SemanticChunker(nltk.sent_tokenize, max_chunk_size=500) produces coherent topic-aligned chunks")
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
query_example.py
"""
Example: Query Processing Strategies (Module 4)
Demonstrates all 5 query optimization techniques before retrieval:
1. Query Decomposer — Break complex queries into focused sub-queries
2. Query Router — Route queries to the right index automatically
3. Query Expander — Generate rephrasing variations for better recall
4. Query Rewriter — Clean and clarify queries before retrieval
5. HyDE Generator — Hypothetical Document Embeddings for better vector search
How this connects to existing examples:
- Uses the same 4 Azure AI Search indexes from azure_ai_search_vector_store_example.py
(legal_documents, products, financial_reports, ai_ml_knowledge)
- HyDE and Expander examples retrieve from ai_ml_knowledge
- Query Router example routes against all 4 indexes
Prerequisites:
- Run azure_ai_search_vector_store_example.py first to populate all 4 indexes
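- Azure AI Search + Azure OpenAI credentials in .env
(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT for the LLM gateway;
AZURE_AI_SEARCH_ENDPOINT, AZURE_AI_SEARCH_API_KEY and AZURE_OPENAI_EMBEDDING_MODEL
for the live-retrieval parts of the examples)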
"""
import asyncio
import os
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, Document
from gmf_forge_ai_data.retrieval import RetrievalQuery, VectorRetriever
from gmf_forge_ai_data.query import (
QueryDecomposer,
QueryRouter,
QueryExpander,
QueryRewriter,
HyDEGenerator,
)
from gmf_forge_ai_shared_core.llm_gateway import UnifiedLLMGateway
from gmf_forge_ai_shared_core.llm_gateway.providers import AzureOpenAIProvider
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
# ── Environment ──────────────────────────────────────────────────────────────
# Load settings from the .env file that sits next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Module-wide structured logger used by every helper and example below.
logger = BasicLogger(__name__)
# Corporate CA bundle location, resolved relative to this file's position in the
# workspace. It is only used when the file exists (see get_ssl_cert_path).
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# ── Document schemas (same as azure_ai_search_vector_store_example.py / retrieval_example.py) ──
@dataclass
class LegalDocument(Document):
    """Legal document schema: the base ``Document`` plus case-specific fields."""

    # Every custom field carries a default, so none of them is required at construction.
    case_number: str = ""
    court: str = ""
    jurisdiction: str = ""
    decision_date: Optional[datetime] = None
    case_type: str = ""
    source: str = ""
    page_number: Optional[int] = None
@dataclass
class ProductDocument(Document):
    """Product catalogue schema: the base ``Document`` plus e-commerce fields."""

    # Every custom field carries a default; note that in_stock defaults to True.
    sku: str = ""
    category: str = ""
    price: float = 0.0
    in_stock: bool = True
    brand: str = ""
    rating: float = 0.0
    review_count: int = 0
@dataclass
class FinancialDocument(Document):
    """Financial report schema: the base ``Document`` plus filing/reporting fields."""

    # Every custom field carries a default, so none of them is required at construction.
    document_type: str = ""  # filing type stored as plain text
    fiscal_year: int = 0
    quarter: str = ""
    company_ticker: str = ""
    sector: str = ""
    report_date: Optional[datetime] = None
    source: str = ""
    page_number: Optional[int] = None
@dataclass
class AIMLDocument(Document):
    """AI/ML knowledge-base schema: the base ``Document`` plus topic and category."""

    topic: str = ""  # read by the examples via getattr(r.document, "topic", "?")
    category: str = ""
# ── Config helpers ───────────────────────────────────────────────────────────
def get_ssl_cert_path() -> Optional[str]:
    """Return the corporate CA bundle path as a string, or ``None`` if the file is absent."""
    if not CORPORATE_CERT.exists():
        return None
    return str(CORPORATE_CERT)
def get_azure_search_config() -> Optional[dict]:
    """Read the Azure AI Search settings from the environment.

    Returns:
        A dict with ``endpoint``, ``api_key`` and ``embedding_dimension``
        (int, default 1536), or ``None`` after logging a warning when the
        endpoint or the API key is missing.
    """
    search_config = {
        "endpoint": os.getenv("AZURE_AI_SEARCH_ENDPOINT"),
        "api_key": os.getenv("AZURE_AI_SEARCH_API_KEY"),
        # Parsed before the credential check, so a non-numeric value raises
        # ValueError regardless of whether the credentials are present.
        "embedding_dimension": int(os.getenv("AZURE_AI_SEARCH_EMBEDDING_DIMENSION", "1536")),
    }
    if search_config["endpoint"] and search_config["api_key"]:
        return search_config
    logger.warning("Azure Search credentials not found in .env file",
                   missing=["AZURE_AI_SEARCH_ENDPOINT", "AZURE_AI_SEARCH_API_KEY"])
    return None
def get_embedder() -> Optional[AzureOpenAIEmbeddings]:
    """Build the Azure OpenAI embeddings client from environment variables.

    Returns:
        An ``AzureOpenAIEmbeddings`` instance, or ``None`` after logging a
        warning when the endpoint, API key or embedding deployment is missing.
    """
    required = {
        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "deployment_name": os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"),
    }
    model_version = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_VERSION", "2024-02-01")
    if not all(required.values()):
        logger.warning("Azure OpenAI embedding credentials not found in .env file")
        return None

    cert_path = get_ssl_cert_path()
    if cert_path:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    client = AzureOpenAIEmbeddings(
        **required,
        api_version=model_version,
        ssl_cert_path=cert_path,
    )
    logger.info("Initialized embeddings", deployment=required["deployment_name"])
    return client
def get_llm_gateway() -> Optional[UnifiedLLMGateway]:
    """Create a ``UnifiedLLMGateway`` whose default provider is Azure OpenAI.

    Returns:
        The gateway, or ``None`` (silently, no log entry) when the endpoint,
        API key or chat deployment is not set in the environment.
    """
    provider_settings = {
        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "deployment_name": os.getenv("AZURE_OPENAI_DEPLOYMENT"),
    }
    chat_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
    if not all(provider_settings.values()):
        return None

    azure_provider = AzureOpenAIProvider(
        **provider_settings,
        api_version=chat_api_version,
        ssl_cert_path=get_ssl_cert_path(),
    )
    llm_gateway = UnifiedLLMGateway(default_provider=azure_provider)
    logger.info("Initialized LLM gateway", deployment=provider_settings["deployment_name"])
    return llm_gateway
def get_vector_store(config: dict, index_name: str, doc_type) -> Optional[AzureAISearchVectorStore]:
    """Open an existing Azure AI Search index as a typed vector store.

    Args:
        config: Dict providing ``endpoint``, ``api_key`` and ``embedding_dimension``.
        index_name: Name of the index to connect to.
        doc_type: Document class handed to the store as ``document_type``.

    Returns:
        The connected store, or ``None`` after logging a warning when creating
        the store, counting its documents or logging the success raises.
    """
    connection_kwargs = {
        "endpoint": config["endpoint"],
        "index_name": index_name,
        "api_key": config["api_key"],
        "embedding_dimension": config["embedding_dimension"],
        "document_type": doc_type,
    }
    try:
        azure_store = AzureAISearchVectorStore(**connection_kwargs)
        # count() doubles as a connectivity probe: it fails if the index is unreachable.
        logger.info("Connected to index", index=index_name, documents=azure_store.count())
        return azure_store
    except Exception as e:
        logger.warning("Could not connect to index", index=index_name, error=str(e))
        return None
# ── Route definitions ────────────────────────────────────────────────────────
# Maps each index name to a comma-separated description of the topics it covers.
# The whole mapping is passed to QueryRouter as its `routes` argument.
INDEX_ROUTES = {
    "legal_documents": "Legal cases, court decisions, jurisdiction, antitrust, patent infringement, civil",
    "products": "Products, prices, inventory, electronics, camera, headphones, furniture, brand",
    "financial_reports": "Earnings, revenue, fiscal year, company financials, SEC filings, quarter, ticker",
    "ai_ml_knowledge": "Machine learning, AI, neural networks, deep learning, NLP, transformers, computer vision",
}
# ══════════════════════════════════════════════════════════════════════════════
# Example 1: Query Decomposer
# ══════════════════════════════════════════════════════════════════════════════
async def example_query_decomposer(gateway: UnifiedLLMGateway):
    """Example 1: split multi-part questions into focused sub-queries.

    Args:
        gateway: LLM gateway handed to ``QueryDecomposer``.
    """
    logger.info("Starting example", example="1: Query Decomposer")
    logger.info("Scenario",
                problem="A single broad query returns unfocused, mixed results",
                solution="Decompose into sub-queries, retrieve for each separately")

    # temperature=0.0 → deterministic splits (default); raise to ~0.2 for variety
    decomposer = QueryDecomposer(gateway, temperature=0.0)

    multi_part_questions = (
        "What are the antitrust violations and what patents were infringed in 2024?",
        "Tell me about machine learning algorithms and how deep learning works and what transformers do",
        "What is the revenue of Microsoft and how did Apple perform in fiscal year 2024?",
    )
    for question in multi_part_questions:
        # Each question is decomposed into at most three sub-queries.
        decomposition = await decomposer.decompose(question, max_sub_queries=3)
        logger.info("Decomposed query", original=question, sub_queries=decomposition.sub_queries)

    logger.info("Decomposed queries can now be run in parallel for precise results")
# ══════════════════════════════════════════════════════════════════════════════
# Example 2: Query Router
# ══════════════════════════════════════════════════════════════════════════════
async def example_query_router(gateway: UnifiedLLMGateway):
    """Example 2: route each query to one index and compare with the expected target.

    Args:
        gateway: LLM gateway handed to ``QueryRouter``.
    """
    logger.info("Starting example", example="2: Query Router")
    logger.info("Scenario",
                problem="Searching all indexes for every query is slow and noisy",
                solution="Route each query to the single most relevant index")

    # Log a 55-character preview of every route description.
    route_previews = {}
    for route_name, description in INDEX_ROUTES.items():
        route_previews[route_name] = description[:55]
    logger.info("Configured routes", routes=route_previews)

    # temperature=0.0 → deterministic routing (keep low — consistency matters here)
    router = QueryRouter(routes=INDEX_ROUTES, llm_gateway=gateway, temperature=0.0)

    # query → index the router is expected to choose
    expected_targets = {
        "What antitrust cases were filed in 2024?": "legal_documents",
        "Show me Canon DSLR cameras under $4000": "products",
        "What were Microsoft's Q1 2024 earnings?": "financial_reports",
        "How do convolutional neural networks work?": "ai_ml_knowledge",
        "Sony WH-1000XM5 price and availability": "products",
    }
    for question, expected in expected_targets.items():
        decision = await router.route(question)
        logger.info("Route decision", query=question[:47], target=decision.target,
                    confidence=round(decision.confidence, 3), expected=expected,
                    matched=decision.target == expected)

    logger.info("Query router directs queries to the right index without scanning all")
# ══════════════════════════════════════════════════════════════════════════════
# Example 3: Query Expander
# ══════════════════════════════════════════════════════════════════════════════
async def example_query_expander(
    gateway: UnifiedLLMGateway,
    vector_store: Optional[AzureAISearchVectorStore],
    embedder: Optional[AzureOpenAIEmbeddings],
):
    """Example 3: expand queries into variations and measure the extra recall.

    The first part only logs the expansions for a few sample queries. The
    second part runs only when both ``vector_store`` and ``embedder`` are
    available: it retrieves for the original query, then for every expansion,
    and counts how many documents the expansions add.

    Args:
        gateway: LLM gateway handed to ``QueryExpander``.
        vector_store: Store to retrieve from, or ``None`` to skip live retrieval.
        embedder: Embedding client, or ``None`` to skip live retrieval.
    """
    logger.info("Starting example", example="3: Query Expander")
    logger.info("Scenario",
                problem="A narrow query misses semantically related documents",
                solution="Expand to synonyms/rephrasing and merge results")
    queries_to_expand = [
        "machine learning techniques",
        "antitrust violations",
        "neural network models",
    ]
    # temperature=0.3 → creative variation (default); raise to ~0.7 for more diversity
    expander = QueryExpander(gateway, temperature=0.3)
    for query in queries_to_expand:
        result = await expander.expand(query, num_expansions=3)
        logger.info("Expanded query", original=query, expansions=result.expansions or [])

    if vector_store and embedder:
        logger.info("Live retrieval comparison", mode="original vs expanded query")
        query = "machine learning techniques"
        expanded = await expander.expand(query, num_expansions=2)
        # Normalise once: the rest of this function already treats expansions as
        # possibly empty/None, so the loop below must not iterate over None.
        expansions = expanded.expansions or []
        retriever = VectorRetriever(vector_store)

        # IDs of every document seen so far (original query + expansions).
        seen_ids = set()
        query_emb = embedder.embed_text(query)
        orig_results = retriever.retrieve(RetrievalQuery(embedding=query_emb, top_k=3))
        logger.info("Original results", query=query, count=len(orig_results))
        for r in orig_results:
            topic = getattr(r.document, "topic", "?")
            logger.info("Original result", rank=r.rank, topic=topic, content_preview=r.document.content[:60])
            seen_ids.add(r.document.id)

        # Count documents that only the expanded queries surfaced.
        new_docs = 0
        for exp_query in expansions:
            exp_emb = embedder.embed_text(exp_query)
            exp_results = retriever.retrieve(RetrievalQuery(embedding=exp_emb, top_k=3))
            for r in exp_results:
                if r.document.id not in seen_ids:
                    new_docs += 1
                    seen_ids.add(r.document.id)
        # The count covers all expansions; only the first one is shown as a sample.
        logger.info("Expansion result",
                    expansion=expansions[0] if expansions else query,
                    additional_unique_docs=new_docs)

    logger.info("Query expansion increases recall by covering synonym/rephrasing space")
# ══════════════════════════════════════════════════════════════════════════════
# Example 4: Query Rewriter
# ══════════════════════════════════════════════════════════════════════════════
async def example_query_rewriter(gateway: UnifiedLLMGateway):
    """Example 4: rewrite conversational queries, then list one rewrite's changes.

    Args:
        gateway: LLM gateway handed to ``QueryRewriter``.
    """
    logger.info("Starting example", example="4: Query Rewriter")
    logger.info("Scenario",
                problem="'tell me about the apple thing' retrieves noise",
                solution="Rewrite to precise retrieval language before search")

    # (raw query, optional context hint passed to the rewriter)
    test_cases = [
        ("tell me about the antitrust stuff", "legal documents"),
        ("show me machine learning docs please", "AI/ML knowledge base"),
        ("find me microsoft earnings info", "financial reports"),
        ("can you give me camera products", "product catalog"),
        ("What are neural networks?", None),
    ]

    # temperature=0.0 → deterministic rewrites (default); keep low for reproducibility
    rewriter = QueryRewriter(gateway, temperature=0.0)
    for raw_query, domain_hint in test_cases:
        outcome = await rewriter.rewrite(raw_query, context=domain_hint)
        logger.info("Rewrite result", original=outcome.original[:44],
                    rewritten=outcome.rewritten,
                    changed=outcome.rewritten != outcome.original)

    # The detailed change list is shown for the second-to-last case, which is
    # rewritten a second time here.
    detail_query, detail_context = test_cases[-2]
    logger.info("Changes detail", query=detail_query)
    detail = await rewriter.rewrite(detail_query, context=detail_context)
    for change in detail.changes:
        logger.info("Change", detail=change)

    logger.info("Rewriting improves retrieval precision on conversational queries")
# ══════════════════════════════════════════════════════════════════════════════
# Example 5: HyDE Generator
# ══════════════════════════════════════════════════════════════════════════════
async def example_hyde_generator(
    gateway: UnifiedLLMGateway,
    vector_store: Optional[AzureAISearchVectorStore],
    embedder: Optional[AzureOpenAIEmbeddings],
):
    """Example 5: generate hypothetical documents and compare HyDE with direct retrieval.

    A hypothetical document is generated and previewed for every query. The
    retrieval comparison runs only for queries whose domain contains "AI/ML",
    and only when both ``vector_store`` and ``embedder`` are available.

    Args:
        gateway: LLM gateway handed to ``HyDEGenerator``.
        vector_store: Store to retrieve from, or ``None`` to skip the comparison.
        embedder: Embedding client; also handed to ``HyDEGenerator``.
    """

    def _log_hits(message, hits):
        # One log entry per retrieved result, with a 50-character content preview.
        for hit in hits:
            logger.info(message, rank=hit.rank, score=round(hit.score, 4),
                        topic=getattr(hit.document, "topic", "?"),
                        content_preview=hit.document.content[:50])

    logger.info("Starting example", example="5: HyDE Generator (Hypothetical Document Embeddings)")
    logger.info("Scenario",
                problem="Short query embeds poorly vs. answer-length documents",
                solution="Generate a hypothetical answer, embed that instead")

    hyde = HyDEGenerator(llm_gateway=gateway, embedder=embedder)
    cases = (
        ("How do transformers work in NLP?", "AI/ML knowledge base"),
        ("What penalties apply for antitrust violations?", "legal documents"),
    )
    for query, domain in cases:
        logger.info("HyDE query", query=query, domain=domain)
        hypo = await hyde.generate(query, domain=domain)
        doc_text = hypo.hypothetical_doc
        # Preview is capped at 200 characters, with an ellipsis when truncated.
        preview = doc_text[:200] + ("..." if len(doc_text) > 200 else "")
        logger.info("Hypothetical document", word_count=len(doc_text.split()), preview=preview)

        if not (vector_store and embedder and "AI/ML" in domain):
            continue

        retriever = VectorRetriever(vector_store)
        direct_embedding = embedder.embed_text(query)
        std_results = retriever.retrieve(RetrievalQuery(embedding=direct_embedding, top_k=3))
        # generate_and_embed is a separate call from the generate() above.
        hypo_full = await hyde.generate_and_embed(query, domain=domain)
        hyde_results = retriever.retrieve(
            RetrievalQuery(embedding=hypo_full.embedding, top_k=3)
        )

        logger.info("Standard retrieval (query embedded directly)")
        _log_hits("Standard result", std_results)
        logger.info("HyDE retrieval (hypothetical doc embedded)")
        _log_hits("HyDE result", hyde_results)

        std_top = std_results[0].document.id if std_results else None
        hyde_top = hyde_results[0].document.id if hyde_results else None
        if std_top == hyde_top:
            logger.info("Same top document - both strategies agree on this query")
        else:
            logger.info("HyDE retrieved a different top document than standard search")

    logger.info("HyDE embeds a hypothetical answer instead of the raw query",
                tip="Best for short/ambiguous queries against a large document corpus")
# ══════════════════════════════════════════════════════════════════════════════
# Pipeline Demo: Combining all 5 strategies
# ══════════════════════════════════════════════════════════════════════════════
async def example_full_pipeline(
    gateway: UnifiedLLMGateway,
    vector_store: Optional[AzureAISearchVectorStore],
    embedder: Optional[AzureOpenAIEmbeddings],
):
    """
    Show the full query processing pipeline feeding into retrieval.

    Pipeline: raw query
        → Rewriter (clean)
        → Router (pick index)
        → Expander (variations)
        → Retriever (per variation, merge results)

    Every step runs inside its own span of a single ``query_pipeline`` trace.
    Retrieval is skipped (and recorded as skipped) when ``vector_store`` or
    ``embedder`` is missing.

    Args:
        gateway: LLM gateway shared by the rewriter, router and expander.
        vector_store: Store to retrieve from, or ``None`` to skip retrieval.
        embedder: Embedding client, or ``None`` to skip retrieval.
    """
    logger.info("Starting example", example="Pipeline Demo: Full Query Processing to Retrieval")
    raw_query = "tell me about deep learning stuff for images"
    logger.info("Raw user query", query=raw_query)

    rewriter = QueryRewriter(gateway, temperature=0.0)
    router = QueryRouter(routes=INDEX_ROUTES, llm_gateway=gateway, temperature=0.0)
    expander = QueryExpander(gateway, temperature=0.3)
    tracer = get_tracer()

    with tracer.trace("query_pipeline", input=raw_query) as trace:
        # ── Step 1: Rewrite ───────────────────────────────────────────────────
        with trace.span("rewrite", input=raw_query) as span:
            rewritten = await rewriter.rewrite(raw_query, context="AI/ML knowledge base")
            logger.info("Step 1 - Rewriter", original=raw_query, result=rewritten.rewritten)
            span.set_output({"rewritten": rewritten.rewritten})

        # ── Step 2: Route ─────────────────────────────────────────────────────
        with trace.span("route", input=rewritten.rewritten) as span:
            decision = await router.route(rewritten.rewritten)
            logger.info("Step 2 - Router", target=decision.target, confidence=round(decision.confidence, 3))
            span.set_output({"target": decision.target, "confidence": round(decision.confidence, 3)})

        # ── Step 3: Expand ────────────────────────────────────────────────────
        with trace.span("expand", input=rewritten.rewritten) as span:
            expanded = await expander.expand(rewritten.rewritten, num_expansions=2)
            # Treat a missing expansion list as empty so the concatenation below
            # cannot fail with "list + None".
            expansions = expanded.expansions or []
            all_queries = [rewritten.rewritten] + expansions
            logger.info("Step 3 - Expander", total_queries=len(all_queries), queries=all_queries)
            span.set_output({"total_queries": len(all_queries), "expansions": expansions})

        # ── Step 4: Retrieve ──────────────────────────────────────────────────
        with trace.span("retrieval", input=all_queries) as span:
            if vector_store and embedder:
                # NOTE(review): retrieval always uses the vector_store passed in;
                # the router decision is only logged, not used to pick a store.
                logger.info("Step 4 - Retrieve", index=decision.target, query_variations=len(all_queries))
                retriever = VectorRetriever(vector_store)
                seen_ids = set()
                all_results = []
                for q in all_queries:
                    emb = embedder.embed_text(q)
                    results = retriever.retrieve(RetrievalQuery(embedding=emb, top_k=3))
                    # Keep only the first occurrence of each document across variations.
                    for r in results:
                        if r.document.id not in seen_ids:
                            seen_ids.add(r.document.id)
                            all_results.append(r)
                all_results.sort(key=lambda r: r.score, reverse=True)
                logger.info("Merged results", unique_docs=len(all_results))
                for i, r in enumerate(all_results[:5], 1):
                    topic = getattr(r.document, "topic", "?")
                    logger.info("Merged result", rank=i, score=round(r.score, 4), topic=topic, content_preview=r.document.content[:55])
                span.set_output({"unique_docs": len(all_results), "top_score": round(all_results[0].score, 4) if all_results else 0})
            else:
                logger.info("Step 4 - Retrieve: skipped (no vector store available)")
                all_results = []
                span.set_output({"skipped": True})

        trace.set_output({
            "rewritten": rewritten.rewritten,
            "routed_to": decision.target,
            "query_variations": len(all_queries),
            "unique_results": len(all_results),
        })

    logger.info("Full pipeline complete",
                pipeline="noisy query -> cleaned -> routed -> expanded -> retrieved")
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main() -> None:
    """Synchronous entry point: run the async driver ``_main`` via ``asyncio.run``."""
    asyncio.run(_main())
async def _main():
    """Run all query-processing examples in order, then log a summary.

    The LLM gateway is mandatory: without it the process exits with status 1.
    Azure AI Search and the embedder are optional; when either is missing the
    examples receive ``None`` and skip their live-retrieval parts.
    """
    logger.info("QUERY PROCESSING EXAMPLES (Module 4)",
                note="Demonstrating 5 query optimization strategies for RAG pipelines")

    gateway = get_llm_gateway()
    if not gateway:
        logger.error("LLM gateway is required for all query examples",
                     missing_vars=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_DEPLOYMENT"])
        sys.exit(1)

    # The embedder is only built when search is configured, and the store only
    # when both are available.
    embedder = None
    vector_store = None
    config = get_azure_search_config()
    if config:
        embedder = get_embedder()
        if embedder:
            vector_store = get_vector_store(config, "ai_ml_knowledge", AIMLDocument)

    # Zero-argument factories so each coroutine is created right before it is awaited.
    example_runs = (
        lambda: example_query_decomposer(gateway),
        lambda: example_query_router(gateway),
        lambda: example_query_expander(gateway, vector_store, embedder),
        lambda: example_query_rewriter(gateway),
        lambda: example_hyde_generator(gateway, vector_store, embedder),
        lambda: example_full_pipeline(gateway, vector_store, embedder),
    )
    for start_example in example_runs:
        logger.info("")  # blank separator line before each example
        await start_example()
    logger.info("")

    logger.info("SUMMARY")
    component_notes = (
        ("QueryDecomposer",
         "Splits multi-part queries into focused sub-queries",
         "complex research questions, multi-concept queries"),
        ("QueryRouter",
         "Selects the single best index/retriever per query",
         "multi-domain systems (Legal, Finance, Products, AI)"),
        ("QueryExpander",
         "Generates synonym / rephrasing variations",
         "narrow queries that miss related documents"),
        ("QueryRewriter",
         "Cleans conversational or ambiguous queries",
         "user-facing chatbots, natural language inputs"),
        ("HyDEGenerator",
         "Generates hypothetical answer doc, embeds that instead of query",
         "short/vague queries, zero-shot retrieval"),
    )
    for component, description, best_for in component_notes:
        logger.info(component, description=description, best_for=best_for)

    logger.info("Production Tips",
                tip_1="Chain Rewriter -> Router -> Expander before every retrieval call",
                tip_2="Use Decomposer for complex multi-part research questions",
                tip_3="Use HyDE for high-precision zero-shot retrieval",
                tip_4="Pair Expander results with EnsembleRetriever for RRF fusion")
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
retrieval_example.py
"""
Example: Using various retrieval strategies.
This example demonstrates:
1. Basic retrievers (Vector, Keyword, Hybrid) across multiple document types
2. MMR retriever for diverse results
3. Parent Document retriever for better context
4. Ensemble retriever combining multiple strategies
5. Hierarchical retriever for two-stage search
6. Graph retriever leveraging entity relationships
7. SQL retriever for structured data
8. Multi-Index retriever for federated search
Prerequisites:
- Run azure_ai_search_vector_store_example.py first to create all indexes:
* legal_documents (LegalDocument)
* products (ProductDocument)
* financial_reports (FinancialDocument)
* ai_ml_knowledge (AIMLDocument)
- Azure AI Search credentials in .env
- Azure OpenAI embeddings credentials in .env
"""
import os
import sqlite3
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime
from typing import Optional
from dotenv import load_dotenv
from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
from gmf_forge_ai_data.vector_stores import AzureAISearchVectorStore, InMemoryVectorStore, Document
from gmf_forge_ai_data.retrieval import (
RetrievalQuery,
VectorRetriever,
KeywordRetriever,
HybridRetriever,
MMRRetriever,
ParentDocumentRetriever,
EnsembleRetriever,
HierarchicalRetriever,
GraphRetriever,
SQLRetriever,
SQLSchema,
MultiIndexRetriever,
SourceConfig,
)
# Load environment variables from the .env file that sits next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Module-wide logger and metrics collector shared by the examples in this file.
logger = BasicLogger(__name__)
metrics = BasicMetricsCollector()
# Corporate SSL certificate path, resolved relative to this file's position in
# the workspace. It is only used when the file exists (see get_ssl_cert_path).
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# Custom document schemas for different domain types
@dataclass
class LegalDocument(Document):
    """Legal document with case-specific fields.

    Used as the document type of the ``legal_documents`` index. Every custom
    field has a default, so none of them is required at construction.
    """

    case_number: str = ""
    court: str = ""
    jurisdiction: str = ""
    decision_date: Optional[datetime] = None
    case_type: str = ""
    source: str = ""
    page_number: Optional[int] = None
@dataclass
class ProductDocument(Document):
    """Product document for e-commerce.

    Used as the document type of the ``products`` index. Every custom field
    has a default; note that ``in_stock`` defaults to ``True``.
    """

    sku: str = ""
    category: str = ""
    price: float = 0.0
    in_stock: bool = True
    brand: str = ""
    rating: float = 0.0
    review_count: int = 0
@dataclass
class FinancialDocument(Document):
    """Financial report document.

    Used as the document type of the ``financial_reports`` index. Every custom
    field has a default, so none of them is required at construction.
    """

    document_type: str = ""  # 10-K, 10-Q, 8-K, etc.
    fiscal_year: int = 0
    quarter: str = ""  # Q1, Q2, Q3, Q4, FY
    company_ticker: str = ""
    sector: str = ""
    report_date: Optional[datetime] = None
    source: str = ""
    page_number: Optional[int] = None
@dataclass
class AIMLDocument(Document):
    """AI/ML knowledge base document with indexed custom fields.

    Used as the document type of the ``ai_ml_knowledge`` index.
    """

    topic: str = ""  # Specific topic like "deep_learning", "CNNs", "transformers"
    category: str = ""  # Broader category like "ML", "NLP", "CV"
# Returns corporate SSL certificate path if available, otherwise None.
def get_ssl_cert_path() -> Optional[str]:
    """Get the SSL certificate path if it exists.

    Returns:
        ``str(CORPORATE_CERT)`` when that file exists, otherwise ``None``.
        The return annotation is ``Optional[str]`` because the missing-file
        case really does return ``None``.
    """
    if CORPORATE_CERT.exists():
        return str(CORPORATE_CERT)
    return None
# Loads Azure AI Search credentials and configuration from environment variables.
def get_azure_search_config():
    """Collect the Azure AI Search settings from the environment.

    Returns:
        A dict with ``endpoint``, ``api_key`` and ``embedding_dimension``
        (int, default 1536), or ``None`` after logging a warning when the
        endpoint or the API key is missing.
    """
    # Parsed before the credential check, so a non-numeric value raises
    # ValueError regardless of whether the credentials are present.
    dimension = int(os.getenv('AZURE_AI_SEARCH_EMBEDDING_DIMENSION', '1536'))
    credentials = {
        'endpoint': os.getenv('AZURE_AI_SEARCH_ENDPOINT'),
        'api_key': os.getenv('AZURE_AI_SEARCH_API_KEY'),
    }
    if all(credentials.values()):
        return {**credentials, 'embedding_dimension': dimension}

    logger.warning("Azure Search credentials not found in .env file",
                   missing=["AZURE_AI_SEARCH_ENDPOINT", "AZURE_AI_SEARCH_API_KEY"])
    return None
# Initializes Azure OpenAI embeddings client with SSL support.
def get_embedder():
    """Create the Azure OpenAI embeddings client from environment variables.

    Returns:
        An ``AzureOpenAIEmbeddings`` instance, or ``None`` after logging a
        warning when a required variable is missing or when constructing the
        client raises.
    """
    required = {
        'endpoint': os.getenv('AZURE_OPENAI_ENDPOINT'),
        'api_key': os.getenv('AZURE_OPENAI_API_KEY'),
        'deployment_name': os.getenv('AZURE_OPENAI_EMBEDDING_MODEL'),
    }
    model_version = os.getenv('AZURE_OPENAI_EMBEDDING_MODEL_VERSION', '2024-02-01')
    if not all(required.values()):
        logger.warning("Azure OpenAI credentials not found in .env file")
        return None

    cert_path = get_ssl_cert_path()
    if cert_path:
        logger.info("Using SSL certificate", cert=CORPORATE_CERT.name)

    try:
        client = AzureOpenAIEmbeddings(
            **required,
            api_version=model_version,
            ssl_cert_path=cert_path,
        )
        logger.info("Initialized embeddings", deployment=required['deployment_name'])
        return client
    except Exception as e:
        # Best effort: callers treat None as "embedder unavailable".
        logger.warning("Failed to initialize embedder", error=str(e))
        return None
# Connects to an existing Azure AI Search index and returns vector store instance.
def get_vector_store(config, index_name, document_type=None, verbose=True):
    """Connect to Azure Search vector store with existing documents.

    Args:
        config: Azure Search configuration dict (``endpoint``, ``api_key``,
            ``embedding_dimension``)
        index_name: Name of the index to connect to
        document_type: Optional document type class (e.g., AIMLDocument, LegalDocument)
        verbose: Whether to log connection status

    Returns:
        AzureAISearchVectorStore instance, or None (after a warning is logged)
        if the connection fails or the index contains no documents
    """
    if verbose:
        logger.info("Connecting to Azure AI Search index", index=index_name)

    # Honour the documented contract: a failed connection yields None instead
    # of an exception, so callers can skip the example that needs this index.
    try:
        store = AzureAISearchVectorStore(
            endpoint=config['endpoint'],
            index_name=index_name,
            api_key=config['api_key'],
            embedding_dimension=config['embedding_dimension'],
            document_type=document_type,
        )
        # Check if index has documents
        doc_count = store.count()
    except Exception as e:
        logger.warning("Could not connect to index", index=index_name, error=str(e))
        return None

    if verbose:
        logger.info("Connected to index", index=index_name, documents=doc_count)
    if doc_count == 0:
        logger.warning("Index is empty - please run azure_ai_search_vector_store_example.py first",
                       index=index_name)
        return None
    return store
# NOTE: In production, choose the right basic retriever for your use case:
# 1. VectorRetriever: Best for semantic/conceptual queries ("machine learning algorithms")
# 2. KeywordRetriever: Best for exact term matching (product SKUs, case numbers, technical terms)
# 3. HybridRetriever: Recommended default - combines semantic + lexical for balanced results
# Scoring: Vector (0-1 cosine), Keyword (0-∞ BM25), Hybrid (~0-0.05 Azure Search fusion)
# Works across any document schema with indexed fields.
def example_basic_retrievers(config, embedder):
    """Demonstrate basic retrieval strategies across different document types.

    Runs one query per index, each with a different retriever: vector search
    on ``legal_documents``, keyword search on ``products`` and hybrid search
    on ``financial_reports``. An index for which ``get_vector_store`` returns
    a falsy value is skipped.

    Args:
        config: Azure Search configuration dict passed to ``get_vector_store``.
        embedder: Object exposing ``embed_text(text)``; used for the vector
            and hybrid queries only.
    """
    logger.info("Starting example", example="1: Basic Retrievers (Multiple Document Types)")

    logger.info("Searching Legal Documents")
    legal_store = get_vector_store(config, "legal_documents", LegalDocument, verbose=False)
    if not legal_store:
        logger.info("Skipping legal documents example")
    else:
        query_text = "antitrust violations"
        logger.info("Legal query", query=query_text)
        query_embedding = embedder.embed_text(query_text)
        vector_retriever = VectorRetriever(legal_store)
        results = vector_retriever.retrieve_embedding(embedding=query_embedding, top_k=2)
        for r in results:
            legal_doc = r.document
            logger.info("Legal result", rank=r.rank+1, score=round(r.score, 4),
                        case_number=legal_doc.case_number,
                        content_preview=r.document.content[:70])

    logger.info("Searching Product Catalog")
    product_store = get_vector_store(config, "products", ProductDocument, verbose=False)
    if not product_store:
        logger.info("Skipping products example")
    else:
        query_text = "professional camera"
        logger.info("Product query", query=query_text)
        # Keyword retrieval works on the raw text, so no embedding is
        # computed for this query.
        keyword_retriever = KeywordRetriever(product_store)
        results = keyword_retriever.retrieve_text(text=query_text, top_k=2)
        for r in results:
            product_doc = r.document
            logger.info("Product result", rank=r.rank+1, score=round(r.score, 4),
                        brand=product_doc.brand, sku=product_doc.sku, price=product_doc.price,
                        content_preview=r.document.content[:70])

    logger.info("Searching Financial Reports")
    financial_store = get_vector_store(config, "financial_reports", FinancialDocument, verbose=False)
    if not financial_store:
        logger.info("Skipping financial reports example")
    else:
        query_text = "technology sector revenue"
        logger.info("Financial query", query=query_text)
        query_embedding = embedder.embed_text(query_text)
        hybrid_retriever = HybridRetriever(financial_store)
        # The hybrid retriever receives both the text and its embedding.
        query = RetrievalQuery(text=query_text, embedding=query_embedding, top_k=2)
        results = hybrid_retriever.retrieve(query)
        for r in results:
            fin_doc = r.document
            logger.info("Financial result", rank=r.rank+1, score=round(r.score, 4),
                        ticker=fin_doc.company_ticker, doc_type=fin_doc.document_type,
                        quarter=fin_doc.quarter, fiscal_year=fin_doc.fiscal_year,
                        content_preview=r.document.content[:70])

    logger.info("Basic retrievers work across all document types")
# NOTE: In production, MMRRetriever addresses result redundancy:
# 1. Problem: Standard vector search returns similar/redundant documents
# 2. Solution: MMR balances relevance (similarity to query) with diversity (dissimilarity to selected docs)
# 3. Lambda parameter: λ=1.0 (pure relevance), λ=0.0 (pure diversity), λ=0.5 (balanced)
# 4. fetch_k parameter: Retrieve N candidates (20-50), then apply MMR to select top_k diverse results
# Use when: Search results look too similar or users need broader coverage of a topic.
def example_mmr_retriever(vector_store, embedder):
    """Compare plain vector search with MMR re-ranking for one query.

    The same query embedding is sent first through ``VectorRetriever`` and
    then through ``MMRRetriever`` (``lambda_param=0.5``, ``fetch_k=20``);
    the top five hits of each are logged with their ``topic`` attribute
    (``'unknown'`` when a document has none).

    Args:
        vector_store: Store handed to both retrievers.
        embedder: Object exposing ``embed_text(text)``.
    """
    logger.info("Starting example", example="2: MMR Retriever (Diverse Results)")

    query_text = "machine learning techniques"
    logger.info("Query", query=query_text)
    query_vector = embedder.embed_text(query_text)

    # Baseline: plain vector search.
    logger.info("Standard Vector Search")
    baseline_hits = VectorRetriever(vector_store).retrieve_embedding(
        embedding=query_vector,
        top_k=5,
    )
    for hit in baseline_hits:
        hit_topic = getattr(hit.document, 'topic', 'unknown')
        logger.info("Vector result", rank=hit.rank+1, topic=hit_topic,
                    content_preview=hit.document.content[:60])

    # Same query through the MMR retriever.
    logger.info("MMR Search (balanced diversity, lambda=0.5)")
    diverse_hits = MMRRetriever(
        vector_store=vector_store,
        lambda_param=0.5,
        fetch_k=20,
    ).retrieve_embedding(
        embedding=query_vector,
        top_k=5,
    )
    for hit in diverse_hits:
        hit_topic = getattr(hit.document, 'topic', 'unknown')
        logger.info("MMR result", rank=hit.rank+1, topic=hit_topic,
                    content_preview=hit.document.content[:60])

    logger.info("MMR produces more diverse results covering different topics")
# NOTE: In production, ParentDocumentRetriever architecture typically involves:
# 1. Storage: Separate child index (chunks with embeddings) + parent store (full docs, no vectors)
# 2. Chunking: Use RecursiveChunker or SemanticChunker to split parents into precise children
# 3. Ingestion: Store parent_id in child metadata, batch embed children only
# 4. Scale considerations: Fetch_k multiplier (3x), deduplication ratio monitoring
# This example uses in-memory stores for demonstration.
def example_parent_document_retriever(embedder):
    """Show child-chunk search that returns the matching parent documents.

    Two parent documents and five child chunks are embedded and placed in
    separate ``InMemoryVectorStore`` instances. Each child records its parent
    under the ``"parent_id"`` metadata key, which is the key given to
    ``ParentDocumentRetriever``.

    Args:
        embedder: Object exposing ``embed_text(text)``.
    """
    logger.info("Starting example", example="3: Parent Document Retriever")

    parent_texts = (
        (
            "parent_1",
            "Artificial Intelligence Overview:\n"
            "AI encompasses machine learning, deep learning, natural language processing,\n"
            "and computer vision. These technologies enable machines to perform tasks\n"
            "that typically require human intelligence.",
        ),
        (
            "parent_2",
            "Deep Learning Technologies:\n"
            "Deep learning includes convolutional neural networks for images,\n"
            "recurrent neural networks for sequences, and transformers for\n"
            "natural language understanding.",
        ),
    )
    parents = [
        Document(
            id=doc_id,
            content=text,
            embedding=embedder.embed_text(text),
            timestamp=datetime.now(),
        )
        for doc_id, text in parent_texts
    ]
    parent_store = InMemoryVectorStore()
    parent_store.add_documents(parents)

    chunk_specs = [
        ("parent_1", "AI encompasses machine learning and deep learning."),
        ("parent_1", "Natural language processing enables text understanding."),
        ("parent_1", "Computer vision allows image interpretation."),
        ("parent_2", "Convolutional neural networks process images."),
        ("parent_2", "Transformers revolutionized NLP with attention."),
    ]
    # The running position in the list gives every chunk a unique id suffix.
    children = [
        Document(
            id=f"{owner_id}_chunk_{position}",
            content=chunk_text,
            embedding=embedder.embed_text(chunk_text),
            timestamp=datetime.now(),
            metadata={"parent_id": owner_id},
        )
        for position, (owner_id, chunk_text) in enumerate(chunk_specs)
    ]
    child_store = InMemoryVectorStore()
    child_store.add_documents(children)
    logger.info("Setup", parent_docs=2, child_chunks=len(children))

    query_text = "transformers in NLP"
    logger.info("Query", query=query_text)
    query_vector = embedder.embed_text(query_text)

    parent_retriever = ParentDocumentRetriever(
        child_store=child_store,
        parent_store=parent_store,
        parent_id_key="parent_id",
    )
    hits = parent_retriever.retrieve_embedding(embedding=query_vector, top_k=2)

    logger.info("Results (full parent documents)")
    for hit in hits:
        logger.info("Parent result", rank=hit.rank+1, parent_id=hit.document.id,
                    score=round(hit.score, 4),
                    content_preview=hit.document.content[:150])
    logger.info("Retrieved full parents via precise child chunk matching")
# NOTE: In production, EnsembleRetriever is useful for:
# 1. Robustness: Combining vector (semantic) + keyword (lexical) + MMR (diversity)
# 2. Fusion strategies: RRF (rank-based), weighted average, or max score
# 3. Weights tuning: Start with equal weights, adjust based on A/B testing results
# 4. Performance: Run retrievers in parallel, merge results efficiently
# Recommended: Hybrid (vector+keyword) ensemble for most applications.
def example_ensemble_retriever(vector_store, embedder):
    """Fuse vector, keyword and MMR retrievers into one ensemble query.

    The three retrievers share ``vector_store`` and are combined by
    ``EnsembleRetriever`` with weights 0.5 / 0.3 / 0.2 and the ``"rrf"``
    fusion strategy. The top five fused hits are logged.

    Args:
        vector_store: Store searched by every member retriever.
        embedder: Object exposing ``embed_text(text)``.
    """
    logger.info("Starting example", example="4: Ensemble Retriever (Fusion)")

    query_text = "learning algorithms"
    logger.info("Query", query=query_text)
    query_vector = embedder.embed_text(query_text)

    # Member retrievers, listed in the same order as their weights below.
    members = [
        VectorRetriever(vector_store),
        KeywordRetriever(vector_store),
        MMRRetriever(vector_store, lambda_param=0.7, fetch_k=15),
    ]
    member_weights = [0.5, 0.3, 0.2]  # Favor vector, then keyword, then MMR

    fused_retriever = EnsembleRetriever(
        retrievers=members,
        weights=member_weights,
        fusion_strategy="rrf",  # Reciprocal Rank Fusion
    )
    request = RetrievalQuery(text=query_text, embedding=query_vector, top_k=5)
    hits = fused_retriever.retrieve(request)

    logger.info("Ensemble Results (RRF fusion of 3 retrievers)")
    for hit in hits:
        hit_topic = getattr(hit.document, 'topic', 'unknown')
        logger.info("Ensemble result", rank=hit.rank+1, score=round(hit.score, 4),
                    topic=hit_topic, content_preview=hit.document.content[:70])
    logger.info("Ensemble combines strengths of multiple retrieval strategies")
# NOTE: In production, HierarchicalRetriever requires:
# 1. Storage: Summary index (document-level embeddings) + chunk index (detailed chunks)
# 2. document_id_field must be stored in chunk metadata pointing to its parent summary doc id
# 3. Use case: Large collections (>10K docs) where scanning all chunks is expensive
# 4. Tuning: stage1_top_k (broad recall), stage2_top_k (precision), score weights
# Alternative: Use retriever-reranker pattern for smaller collections.
def example_hierarchical_retriever(embedder):
    """Run a two-stage (summary, then chunk) retrieval over in-memory stores.

    Three topic summaries go into one ``InMemoryVectorStore`` and eight
    chunks into another. Each chunk names its summary under the
    ``"document_id"`` metadata key, which is the ``document_id_field`` given
    to ``HierarchicalRetriever``.

    Args:
        embedder: Object exposing ``embed_text(text)``.
    """
    logger.info("Starting example", example="5: Hierarchical Retriever (Two-Stage)")
    logger.info("Scenario",
                stage_1="summary_store: Find relevant topic-level summaries",
                stage_2="chunk_store: Retrieve detailed chunks from matched topics")

    # --- Summary store: one document per topic group ---
    summary_specs = [
        ("ml_summary", "Machine learning encompasses supervised learning, unsupervised learning, and reinforcement learning algorithms."),
        ("dl_summary", "Deep learning uses multi-layer neural networks including CNNs, RNNs, and Transformers for complex pattern recognition."),
        ("nlp_summary", "Natural language processing enables machines to understand, generate, and translate human language using statistical and neural models."),
    ]
    summary_docs = []
    for summary_id, summary_text in summary_specs:
        summary_docs.append(
            Document(
                id=summary_id,
                content=summary_text,
                embedding=embedder.embed_text(summary_text),
                timestamp=datetime.now(),
            )
        )
    summary_store = InMemoryVectorStore()
    summary_store.add_documents(summary_docs)

    # --- Chunk store: multiple fine-grained chunks per topic, linked via metadata ---
    chunk_specs = [
        ("ml_chunk_1", "Supervised learning uses labelled training data to learn input-output mappings.", "ml_summary"),
        ("ml_chunk_2", "Unsupervised learning discovers hidden structure in unlabelled data via clustering and dimensionality reduction.", "ml_summary"),
        ("ml_chunk_3", "Reinforcement learning trains agents to maximise cumulative reward through environment interaction.", "ml_summary"),
        ("dl_chunk_1", "Convolutional neural networks (CNNs) apply shared filters to extract spatial features from images.", "dl_summary"),
        ("dl_chunk_2", "Transformers use self-attention mechanisms to model long-range dependencies in sequences.", "dl_summary"),
        ("dl_chunk_3", "Recurrent neural networks (RNNs) and LSTMs process sequential data by maintaining hidden state.", "dl_summary"),
        ("nlp_chunk_1", "Tokenisation splits raw text into sub-word units used as model inputs.", "nlp_summary"),
        ("nlp_chunk_2", "Sentiment analysis classifies the emotional tone of text as positive, negative, or neutral.", "nlp_summary"),
    ]
    chunk_docs = []
    for chunk_id, chunk_text, summary_id in chunk_specs:
        chunk_docs.append(
            Document(
                id=chunk_id,
                content=chunk_text,
                embedding=embedder.embed_text(chunk_text),
                timestamp=datetime.now(),
                metadata={"document_id": summary_id},  # links chunk → summary
            )
        )
    chunk_store = InMemoryVectorStore()
    chunk_store.add_documents(chunk_docs)
    logger.info("Setup", summary_docs=len(summary_docs), chunk_docs=len(chunk_docs))

    two_stage = HierarchicalRetriever(
        summary_retriever=VectorRetriever(summary_store),
        chunk_retriever=VectorRetriever(chunk_store),
        stage1_top_k=2,                   # identify top 2 topic groups
        stage2_top_k=2,                   # return up to 2 chunks per group
        document_id_field="document_id",  # metadata key in chunks
        combine_scores=True,
    )

    query_text = "machine learning algorithms"
    logger.info("Query", query=query_text)
    query_vector = embedder.embed_text(query_text)
    request = RetrievalQuery(text=query_text, embedding=query_vector, top_k=5)
    hits = two_stage.retrieve(request)

    logger.info("Hierarchical Results (two-stage search)")
    for hit in hits:
        group_id = hit.document.metadata.get('document_id', 'unknown')
        logger.info("Hierarchical result", rank=hit.rank, score=round(hit.score, 4),
                    group=group_id, content_preview=hit.document.content[:80])
    logger.info("Hierarchical retrieval efficient for large document collections")
# NOTE: In production, GraphRetriever depends on knowledge graph construction:
# 1. Automatic extraction: Use NER + relation extraction (spaCy, Hugging Face Transformers)
# 2. Graph databases: Neo4j, Amazon Neptune, Azure Cosmos DB (Gremlin API)
# 3. Domain ontologies: Import pre-built knowledge graphs (medical, scientific)
# 4. Knowledge graph APIs: Wikidata, DBpedia for general knowledge
# Combines graph traversal (BFS with decay) + vector similarity for rich context.
def example_graph_retriever(vector_store, embedder):
    """Demonstrate graph-based retrieval over a small hand-built knowledge graph.

    A directed NetworkX graph of seven entity relations is created, documents
    returned by a seed search of ``vector_store`` are mapped to graph
    entities by keywords found in their ``topic`` attribute, and the result
    is queried through ``GraphRetriever``. When NetworkX cannot be imported
    the example logs a warning and returns without doing anything else.

    Args:
        vector_store: Store used for the seed search and by the retriever.
        embedder: Object exposing ``embed_text(text)``; also handed to
            ``GraphRetriever``.
    """
    logger.info("Starting example", example="6: Graph Retriever (Entity Relationships)")
    logger.info("Scenario", note="Knowledge graph with entity relationships")

    # Guard only the optional import. An ImportError raised later (for
    # example inside the retriever) is a different problem and must not be
    # reported as "NetworkX not installed".
    try:
        import networkx as nx
    except ImportError:
        logger.warning("NetworkX not installed - skipping graph retriever example",
                       install_cmd="pip install networkx")
        return

    # (source entity, target entity, relation label, edge weight)
    relations = (
        ("Python", "Machine Learning", "used_in", 0.9),
        ("Machine Learning", "Deep Learning", "includes", 0.9),
        ("Deep Learning", "Neural Networks", "implements", 0.95),
        ("Neural Networks", "CNNs", "type_of", 0.85),
        ("Neural Networks", "Transformers", "type_of", 0.90),
        ("Machine Learning", "Reinforcement Learning", "includes", 0.80),
        ("NLP", "Transformers", "uses", 0.95),
    )
    graph = nx.DiGraph()
    for source, target, relation, weight in relations:
        graph.add_edge(source, target, relation=relation, weight=weight)

    # (keywords looked for in a document's topic, entities the document is
    # attached to when any keyword matches). A document may match several
    # rules.
    topic_rules = (
        (("introduction", "machine"), ("Machine Learning",)),
        (("deep_learning", "neural"), ("Deep Learning", "Neural Networks")),
        (("cnn", "convolutional"), ("CNNs",)),
        (("transformer",), ("Transformers", "NLP")),
        (("reinforcement",), ("Reinforcement Learning",)),
    )

    all_docs_query = RetrievalQuery(text="machine learning",
                                    embedding=embedder.embed_text("machine learning"),
                                    top_k=10)
    all_results = vector_store.search(query_embedding=all_docs_query.embedding, top_k=10)

    entity_docs = {}
    for result in all_results:
        # A missing or None topic is treated as an empty string so it simply
        # matches no rule.
        topic = (getattr(result.document, 'topic', '') or '').lower()
        for keywords, entities in topic_rules:
            if any(keyword in topic for keyword in keywords):
                for entity in entities:
                    entity_docs.setdefault(entity, []).append(result.document.id)

    graph_retriever = GraphRetriever(
        vector_store=vector_store,
        knowledge_graph=graph,
        entity_document_mapping=entity_docs,
        embedder=embedder,
        max_hops=2,
        combine_vector_scores=True,
        graph_weight=0.6,
        vector_weight=0.4,
    )

    query_text = "neural networks applications"
    logger.info("Graph query", query=query_text)
    logger.info("Graph traversal", path="Neural Networks -> CNNs, Transformers -> related docs")
    # No embedding is supplied here; the retriever was given the embedder.
    query = RetrievalQuery(text=query_text, top_k=5)
    results = graph_retriever.retrieve(query)

    logger.info("Graph-based Results (entity relationships)")
    for r in results:
        topic = getattr(r.document, 'topic', 'unknown')
        logger.info("Graph result", rank=r.rank, score=round(r.score, 4), topic=topic,
                    content_preview=r.document.content[:65])
    logger.info("Graph retrieval leverages entity relationships for better context")
# NOTE: In production, SQLRetriever is best for:
# 1. Structured/tabular data: Product catalogs, customer records, transactions, inventory
# 2. Text-to-SQL: Use LLM-based query generation (OpenAI GPT-4) instead of rule-based
# 3. Schema management: Keep SQLSchema definitions updated, include column descriptions
# 4. Security: Parameterized queries, read-only connections, query validation
# Consider: Combine with vector retrieval for hybrid structured+unstructured search.
def example_sql_retriever():
    """Demonstrate SQL-based structured data retrieval.

    Creates a throw-away SQLite database holding a six-row ``products``
    table, describes it with ``SQLSchema`` and sends three natural-language
    queries through ``SQLRetriever``. A failing query is logged and the next
    one is tried; any other failure is logged as a warning. The connection
    and the temporary database file are released whether or not the example
    succeeds.
    """
    logger.info("Starting example", example="7: SQL Retriever (Structured Data)")
    logger.info("Scenario", note="Querying structured product database")

    conn = None
    db_path = None
    try:
        import sqlite3
        import tempfile

        # Create temporary database. Only the file name is needed: the handle
        # is closed on leaving the ``with`` block and ``delete=False`` keeps
        # the file so sqlite3 can open it by path.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.db') as f:
            db_path = f.name
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Create products table
        create_table_sql = (
            "\n"
            "CREATE TABLE products (\n"
            "id INTEGER PRIMARY KEY,\n"
            "name TEXT,\n"
            "category TEXT,\n"
            "price REAL,\n"
            "description TEXT\n"
            ")\n"
        )
        cursor.execute(create_table_sql)

        # Insert sample data
        products = [
            (1, "Laptop Pro", "electronics", 1299.99, "High-performance laptop for professionals"),
            (2, "Wireless Mouse", "electronics", 29.99, "Ergonomic wireless mouse"),
            (3, "Office Chair", "furniture", 249.99, "Comfortable ergonomic office chair"),
            (4, "Desk Lamp", "furniture", 45.99, "LED desk lamp with adjustable brightness"),
            (5, "Smartphone", "electronics", 899.99, "Latest smartphone with advanced features"),
            (6, "Bookshelf", "furniture", 159.99, "Wooden bookshelf with 5 shelves"),
        ]
        cursor.executemany("INSERT INTO products VALUES (?, ?, ?, ?, ?)", products)
        conn.commit()

        # Define schema
        schema = SQLSchema(
            table_name="products",
            columns=[
                {"name": "id", "type": "INTEGER", "description": "Product ID"},
                {"name": "name", "type": "TEXT", "description": "Product name"},
                {"name": "category", "type": "TEXT", "description": "Product category"},
                {"name": "price", "type": "REAL", "description": "Price in USD"},
                {"name": "description", "type": "TEXT", "description": "Product description"},
            ],
            primary_key="id",
            description="E-commerce product catalog",
        )

        # Create SQL retriever
        sql_retriever = SQLRetriever(
            db_connection=conn,
            schema=schema,
            db_type="sqlite",
            content_columns=["name", "category", "price", "description"],
        )

        # Example queries
        queries = [
            "electronics under 100",
            "furniture products",
            "smartphone",
        ]
        for query_text in queries:
            logger.info("SQL query", query=query_text)
            query = RetrievalQuery(text=query_text, top_k=3)
            # One bad query must not stop the remaining ones.
            try:
                results = sql_retriever.retrieve(query)
                logger.info("SQL Results")
                for r in results:
                    row_data = r.document.metadata.get('row_data', {})
                    logger.info("SQL result", rank=r.rank,
                                name=row_data.get('name', 'N/A'),
                                price=row_data.get('price', 0),
                                category=row_data.get('category', 'N/A'))
            except Exception as e:
                logger.error("SQL query failed", error=str(e))

        logger.info("SQL retrieval enables natural language queries over structured data")
    except Exception as e:
        logger.warning("SQL example failed", error=str(e))
    finally:
        # Cleanup runs on every path. The connection is closed before the
        # file is removed, since an open handle can block deletion on some
        # platforms.
        if conn is not None:
            conn.close()
        if db_path is not None:
            try:
                os.unlink(db_path)
            except OSError as e:
                logger.warning("Failed to remove temporary database", path=db_path, error=str(e))
# Demonstrates federated search across multiple vector stores/indexes with score fusion.
# NOTE: In production, MultiIndexRetriever enables cross-domain/cross-source search:
# 1. Use cases: Search across multiple teams' indexes, multi-tenant applications
# 2. Separate indexes: Different document types, departments, data sources
# 3. Fusion strategy: RRF (Reciprocal Rank Fusion) recommended for balanced results
# 4. Tuning: Adjust source weights and boost_factors based on source quality/relevance
# Tracks retrieval_source in metadata for transparency and debugging.
def example_multi_index_retriever(config, embedder):
    """Run one federated query across several Azure AI Search indexes.

    Up to four indexes are opened through ``get_vector_store``; each one that
    is available is wrapped in a ``HybridRetriever`` and registered as a
    ``SourceConfig``. The sources are fused by ``MultiIndexRetriever`` with
    the ``"rrf"`` strategy. The embedding and retrieval steps each run in
    their own tracer span. The function returns early, after a warning, when
    no index is available.

    Args:
        config: Azure Search configuration dict passed to ``get_vector_store``.
        embedder: Object exposing ``embed_text(text)``.
    """
    logger.info("Starting example", example="8: Multi-Index Retriever (Federated Search)")
    logger.info("Scenario", note="Searching across multiple document type indexes")

    # (index name, document class, source label, weight, boost factor)
    index_specs = (
        ("legal_documents", LegalDocument, "Legal Documents", 1.0, 1.0),
        ("products", ProductDocument, "Product Catalog", 1.0, 1.0),
        ("financial_reports", FinancialDocument, "Financial Reports", 1.2, 1.1),
        ("ai_ml_knowledge", AIMLDocument, "AI/ML Knowledge", 1.0, 1.0),
    )

    # Connect to every index first, before any retriever is built.
    stores = [
        get_vector_store(config, index_name, doc_cls, verbose=False)
        for index_name, doc_cls, _label, _weight, _boost in index_specs
    ]

    # Wrap each available store; unavailable indexes are left out.
    sources = []
    for store, (_index_name, _doc_cls, label, weight, boost) in zip(stores, index_specs):
        if not store:
            continue
        sources.append(
            SourceConfig(label, HybridRetriever(store), weight=weight, boost_factor=boost)
        )

    if not sources:
        logger.warning("No indexes available for multi-index search")
        return

    logger.info("Connected to indexes", count=len(sources), names=[src.name for src in sources])
    federated = MultiIndexRetriever(sources=sources, fusion_strategy="rrf")

    query_text = "technology innovation"
    logger.info("Multi-index query", query=query_text, index_count=len(sources))

    # Extra log fields per document class; the first matching class wins.
    extra_builders = (
        (LegalDocument, lambda d: {"case": d.case_number, "court": d.court}),
        (ProductDocument, lambda d: {"brand": d.brand, "sku": d.sku, "price": d.price}),
        (FinancialDocument, lambda d: {"ticker": d.company_ticker, "report": d.document_type, "quarter": d.quarter}),
        (AIMLDocument, lambda d: {"topic": d.topic, "category": d.category}),
    )

    tracer = get_tracer()
    with tracer.trace("multi_index_retrieval_pipeline", input=query_text) as trace:
        # ── Step 1: Embed query ───────────────────────────────────────────────
        with trace.span("embedding", input=query_text) as span:
            query_vector = embedder.embed_text(query_text)
            logger.info("Query embedded", dimension=len(query_vector))
            span.set_output({"embedding_dimension": len(query_vector)})

        # ── Step 2: Federated retrieval across all indexes ────────────────────
        with trace.span("retrieval", input=query_text) as span:
            request = RetrievalQuery(text=query_text, embedding=query_vector, top_k=6)
            hits = federated.retrieve(request)
            logger.info("Multi-Index Results (federated search with RRF fusion)")
            for hit in hits:
                origin = hit.document.metadata.get('retrieval_source', 'unknown')
                extra = next(
                    (build(hit.document) for doc_cls, build in extra_builders
                     if isinstance(hit.document, doc_cls)),
                    {},
                )
                logger.info("Multi-index result", rank=hit.rank, score=round(hit.score, 4),
                            source=origin, content_preview=hit.document.content[:70], **extra)
            span.set_output({
                "results": len(hits),
                "top_score": round(hits[0].score, 4) if hits else 0,
                "sources_searched": len(sources),
            })

        trace.set_output({
            "query": query_text,
            "indexes_searched": len(sources),
            "results_returned": len(hits),
            "fusion_strategy": "rrf",
        })

    logger.info("Multi-index retrieval searches across multiple domains efficiently")
# Entry point: Initializes configuration and runs all 8 retrieval strategy examples.
def _run_timed_example(example, *args):
    """Run one example and record its duration and run count in ``metrics``.

    A blank log line is written first to separate the example's output from
    what came before. The duration is measured with ``time.perf_counter``,
    which is monotonic, rather than the adjustable wall clock. If the example
    raises, the exception propagates and nothing is recorded.
    """
    logger.info("")
    started = time.perf_counter()
    example(*args)
    elapsed_ms = (time.perf_counter() - started) * 1000
    metrics.histogram("retrieval.duration_ms", elapsed_ms)
    metrics.increment("retrieval.examples_run")


def main():
    """Run all retrieval examples.

    Loads the Azure Search configuration and the embedder, connects to the
    ``ai_ml_knowledge`` index, then runs the eight examples in order, timing
    each one. A summary and the collected timing metrics are logged at the
    end. The function returns early when the configuration, the embedder or
    the ``ai_ml_knowledge`` store is unavailable.
    """
    logger.info("RETRIEVAL STRATEGIES EXAMPLES",
                note="Demonstrating various retrieval patterns for RAG pipelines")

    config = get_azure_search_config()
    if not config:
        logger.error("Cannot run examples without Azure Search configuration")
        return

    embedder = get_embedder()
    if not embedder:
        logger.error("Cannot run examples without embedder configuration")
        return

    vector_store = get_vector_store(config, "ai_ml_knowledge", AIMLDocument)
    if not vector_store:
        return

    # Each entry is (example function, positional arguments), in run order.
    examples = (
        (example_basic_retrievers, (config, embedder)),
        (example_mmr_retriever, (vector_store, embedder)),
        (example_parent_document_retriever, (embedder,)),
        (example_ensemble_retriever, (vector_store, embedder)),
        (example_hierarchical_retriever, (embedder,)),
        (example_graph_retriever, (vector_store, embedder)),
        (example_sql_retriever, ()),
        (example_multi_index_retriever, (config, embedder)),
    )
    for example, args in examples:
        _run_timed_example(example, *args)

    logger.info("")
    logger.info("SUMMARY")
    logger.info("Basic Retrievers",
                description="VectorRetriever (semantic), KeywordRetriever (exact), HybridRetriever (balanced)",
                document_types="Legal, Product, Financial")
    logger.info("MMRRetriever",
                description="Balances relevance with diversity, reduces redundancy",
                parameter="lambda (0=diversity, 1=relevance)")
    logger.info("ParentDocumentRetriever",
                description="Search small chunks for precision, return full parents for context")
    logger.info("EnsembleRetriever",
                description="Combines multiple strategies with score fusion (RRF, weighted avg, max)")
    logger.info("HierarchicalRetriever",
                description="Two-stage: coarse summary search then fine chunk retrieval",
                use_case="Large collections (>10K docs)")
    logger.info("GraphRetriever",
                description="Knowledge graph traversal combined with vector similarity")
    logger.info("SQLRetriever",
                description="Natural language to SQL, structured/tabular data")
    logger.info("MultiIndexRetriever",
                description="Federated multi-source search with RRF/weighted fusion")
    logger.info("Production Tips",
                tip_1="Use HybridRetriever as default",
                tip_2="Add MMR when results seem redundant",
                tip_3="Use ParentDocument for long-form content",
                tip_4="Use Ensemble when single strategy isn't sufficient",
                tip_5="Use Hierarchical for very large collections",
                tip_6="Use Graph when relationships matter",
                tip_7="Use SQL for structured/tabular data",
                tip_8="Use MultiIndex for cross-domain search")
    logger.info("Data Sources",
                indexes=["legal_documents (LegalDocument)", "products (ProductDocument)",
                         "financial_reports (FinancialDocument)", "ai_ml_knowledge (AIMLDocument)"],
                created_by="azure_ai_search_vector_store_example.py")

    perf = metrics.get_metrics()
    durations = perf["histograms"].get("retrieval.duration_ms", [])
    logger.info("Performance Metrics",
                examples_run=perf["counters"].get("retrieval.examples_run", 0),
                total_duration_ms=round(sum(durations), 1),
                # max(..., 1) avoids a division by zero when nothing was timed.
                avg_duration_ms=round(sum(durations) / max(len(durations), 1), 1))


# Run the examples only when the file is executed as a script.
if __name__ == "__main__":
    main()
soap_connector_example.py
"""
Example: SOAP / WSDL Connector (SoapConnector)
Demonstrates how a developer subclasses ``SoapConnector`` from the library to
connect to a real SOAP web service (PolicyHub in this case) and produce
Document objects ready for the RAG pipeline.
Library / developer split shown here
--------------------------------------
LIBRARY (gmf_forge_ai_data.connectors.SoapConnector):
- Manages zeep Client lifecycle
- Injects credentials into every call automatically
- Serialises zeep objects → plain dicts
- Provides extract_text() for PDF / DOCX / plain text
- Provides _make_doc_id() for stable document IDs
- Has NO knowledge of PolicyHub field names or folder structure
DEVELOPER (PolicyHubConnector defined below in this file):
- Wraps individual API methods as one-liners
- Implements load() — traversal, field mapping, metadata shape
- Decides which fields go into Document.metadata
- Decides what text to extract and how
Required .env variables:
MITRATECH_API_USERNAME
MITRATECH_API_PASSWORD
MITRATECH_ADMIN_WSDL_URL
MITRATECH_MYLIBRARY_WSDL_URL (optional — for MyLibrary API examples)
Optional:
AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY / AZURE_OPENAI_EMBEDDING_MODEL
— When set, the loaded documents are also embedded and stored in an
in-memory vector store so you can run a similarity query at the end.
Pipeline flow demonstrated:
PolicyHubConnector.load()
→ RecursiveChunker
→ AzureOpenAIEmbeddings (if credentials available)
→ InMemoryVectorStore
→ VectorRetriever.retrieve()
Run:
cd packages/data-layer/examples
python soap_connector_example.py
"""
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
from dotenv import load_dotenv
# ── Environment ───────────────────────────────────────────────────────────────
# Load settings from the ".env" file that sits next to this script.
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
# Directory four levels above this file.
WORKSPACE_ROOT = Path(__file__).parent.parent.parent.parent
# Path to "certs/gmf_and_public_cas.pem" under WORKSPACE_ROOT.
# NOTE(review): presumably the corporate CA bundle used for TLS verification; confirm with its users.
CORPORATE_CERT = WORKSPACE_ROOT / "certs" / "gmf_and_public_cas.pem"
# ── Library imports ───────────────────────────────────────────────────────────
from gmf_forge_ai_shared_core.observability import BasicLogger
from gmf_forge_ai_shared_core.observability.tracing import get_tracer
from gmf_forge_ai_data.connectors import SoapConnector
from gmf_forge_ai_data.vector_stores import Document
logger = BasicLogger(__name__)
# =============================================================================
# DEVELOPER CODE — PolicyHub-specific subclass
# =============================================================================
# Everything below this line is developer code.
# The developer knows PolicyHub's field names (FolderId, LibraryPath,
# DocumentId, AuthorName, …) and decides what to put in Document.metadata.
# The library (SoapConnector) knows none of this.
# =============================================================================
#: Fully-qualified name of the Credentials complex type in PolicyHub's WSDL.
#: The library uses this to build a typed object and inject it automatically.
#: Written as ``{namespace-URI}LocalName``; the two adjacent string literals
#: below are concatenated into that single qualified name.
_POLICYHUB_CREDENTIALS_QNAME = (
"{http://schemas.datacontract.org/2004/07/"
"HitecLabs.PolicyHub.Api.Wcf.Security}Credentials"
)
class PolicyHubAdminConnector(SoapConnector):
"""
Developer connector for the PolicyHub Admin API.
Wraps every PolicyHub Admin API operation as a one-line method.
The ``load()`` method traverses the folder tree, retrieves document
content, and returns Document objects. Developers can customise
which fields end up in ``Document.metadata`` by editing ``load()``.
Usage::
connector = PolicyHubAdminConnector(
wsdl_url = os.getenv("MITRATECH_ADMIN_WSDL_URL"),
username = os.getenv("MITRATECH_API_USERNAME"),
password = os.getenv("MITRATECH_API_PASSWORD"),
ssl_verify = False, # set True in production with a valid cert
)
docs = connector.load()
# docs → List[Document] — pass to a chunker, embedder, vector store
"""
def __init__(self, folder_ids: Optional[List[str]] = None, **kwargs):
    """
    Create the connector and fix the credentials type to PolicyHub's.

    Args:
        folder_ids: Specific folder IDs to load. ``None`` (the default)
                    means the whole library tree is traversed from the
                    root.
        **kwargs:   Passed through unchanged to SoapConnector.__init__().
    """
    credentials_qname = _POLICYHUB_CREDENTIALS_QNAME
    super().__init__(credentials_type_qname=credentials_qname, **kwargs)
    # None → load all
    self.folder_ids = folder_ids
# ── Admin API method wrappers (one line each) ─────────────────────────────
def get_library(self) -> Dict[str, Any]:
    """Fetch the complete library tree, starting at the root."""
    library_tree = self.call("GetLibrary")
    return library_tree
def get_documents_by_folder(self, folder_id: str) -> Dict[str, Any]:
    """Fetch every document stored in the folder ``folder_id``."""
    folder_documents = self.call("GetDocumentsByFolder", folderId=folder_id)
    return folder_documents
def get_document(self, document_id: str) -> Dict[str, Any]:
"""Return metadata for a single document."""
return self.call("GetDocument", documentId=document_id)
def get_document_data(self, document_id: str) -> Dict[str, Any]:
"""Return the file content (base64-encoded) of a document."""
return self.call("GetDocumentData", documentId=document_id)
def get_document_revisions(self, document_id: str) -> Dict[str, Any]:
"""Return all revisions of a document."""
return self.call("GetDocumentRevisions", documentId=document_id)
def get_document_revision_data(self, document_revision_id: str) -> Dict[str, Any]:
"""Return the file content of a specific document revision."""
return self.call("GetDocumentRevisionData", documentRevisionId=document_revision_id)
def get_compliance_status(self, user_id: str) -> Dict[str, Any]:
"""Return compliance status for a specific user."""
return self.call("GetComplianceStatus", userId=user_id)
def count_documents(self) -> Dict[str, int]:
"""
Traverse the full folder tree and count documents per folder without
fetching any file content. Returns a dict with totals:
{"folders": N, "folders_with_docs": N, "documents": N}
"""
library = self.get_library()
all_folders = list(self._traverse_folders(library))
total_docs = 0
folders_with_docs = 0
for folder in all_folders:
try:
result = self.get_documents_by_folder(folder["FolderId"])
except Exception:
continue
if isinstance(result, list):
count = len(result)
elif isinstance(result, dict):
raw = result.get("PolicyHubDocument")
count = len(raw) if isinstance(raw, list) else (1 if raw else 0)
else:
count = 0
if count:
folders_with_docs += 1
total_docs += count
return {
"folders": len(all_folders),
"folders_with_docs": folders_with_docs,
"documents": total_docs,
}
# ── load() — developer owns this entirely ────────────────────────────────
def load(self, max_docs: Optional[int] = None) -> List[Document]:
"""
Load documents from PolicyHub and return them as Document objects.
Developer decides:
- Which folders to traverse (``self.folder_ids`` or full tree)
- Which fields go into Document.metadata
- What the stable document ID is
- How to name the file for text extraction dispatch
Parameters
----------
max_docs:
Cap the total number of documents returned. Useful for sampling
during development. ``None`` (default) loads everything — use
this only from a production pipeline (e.g. basic-rag-app), not
from an example script.
Returns
-------
List[Document]
One Document per successfully retrieved file. Documents whose
content cannot be fetched or is empty are skipped with a warning.
"""
library = self.get_library()
all_folders = list(self._traverse_folders(library))
# Optionally filter to specific folder IDs
if self.folder_ids:
all_folders = [
f for f in all_folders if f["FolderId"] in self.folder_ids
]
documents: List[Document] = []
for folder in all_folders:
folder_id = folder["FolderId"]
folder_path = folder.get("LibraryPath", folder.get("Name", ""))
try:
result = self.get_documents_by_folder(folder_id)
except Exception as exc:
# The library root node and certain container folders are not
# queryable via GetDocumentsByFolder — skip them silently.
logger.debug(
"Skipping folder (not queryable for documents)",
folder_id=folder_id,
folder_path=folder_path,
error=str(exc),
)
continue
# zeep may serialize the response as a plain list (no wrapper
# element) or as a dict with a "PolicyHubDocument" key.
if isinstance(result, list):
raw_docs = result
elif isinstance(result, dict):
raw_docs = result.get("PolicyHubDocument")
else:
raw_docs = None
if not raw_docs:
continue
if not isinstance(raw_docs, list):
raw_docs = [raw_docs]
for doc in raw_docs:
if max_docs is not None and len(documents) >= max_docs:
break
doc_id = doc.get("DocumentId")
doc_title = doc.get("Name", "Untitled")
mime_type = ""
try:
data = self.get_document_data(doc_id)
raw_bytes = data.get("Data") if data else None
mime_type = (data.get("MimeType") or "") if data else ""
revision_id = str(data.get("RevisionId", "")) if data else ""
version = str(data.get("Version", "")) if data else ""
language = (data.get("Language") or "") if data else ""
locale = (data.get("Locale") or "") if data else ""
except Exception as exc:
logger.warning(
"Could not fetch document data",
document_id=doc_id,
title=doc_title,
error=str(exc),
)
continue
ext = _mime_to_ext(mime_type)
filename = f"{doc_title}{ext}"
content = self.extract_text(raw_bytes or b"", filename)
if not content.strip():
logger.warning(
"Skipping document with no extractable text",
document_id=doc_id,
title=doc_title,
)
continue
# ── Developer decides the metadata shape ──────────────────
# Build a document link from the base service URL
base_url = self.wsdl_url.split("/PolicyHubAPI/")[0]
document_link = f"{base_url}/PolicyHub/Document/{doc_id}"
documents.append(Document(
id = self._make_doc_id(str(doc_id)),
content = content.strip(),
timestamp = _parse_date(doc.get("LastModifiedDateUtc")),
metadata = {
# Required fields
"document_id": str(doc_id),
"document_name": doc_title,
"documentlink": document_link,
"language": language,
"locale": locale,
"revisionid": revision_id,
"source": self.wsdl_url,
"upload_date": str(doc.get("PublishedDateUtc", "")),
"version": version,
# Additional provenance
"folder_id": str(folder_id),
"folder_path": folder_path,
"description": doc.get("Description", ""),
"author": doc.get("AuthorName", ""),
"owner_id": str(doc.get("AuthorUserId", "")),
"reference": str(doc.get("Reference", "")),
"document_type": doc.get("DocumentType", ""),
"file_name": filename,
"mime_type": mime_type,
"extension": ext,
"created_at": str(doc.get("CreationDateUtc", "")),
"modified_at": str(doc.get("LastModifiedDateUtc", "")),
"in_review": str(doc.get("InReview", False)),
"workflow_status":doc.get("WorkflowProcessStatus", ""),
},
))
if max_docs is not None and len(documents) >= max_docs:
break
logger.info("PolicyHub load complete", total_documents=len(documents))
return documents
# ── Private helpers (developer) ───────────────────────────────────────────
def _traverse_folders(self, node: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
"""Recursively yield every folder in the library tree."""
yield node
children = (node.get("Children") or {}).get("PolicyHubFolder", [])
if not isinstance(children, list):
children = [children]
for child in children:
if isinstance(child, dict):
yield from self._traverse_folders(child)
# =============================================================================
# DEVELOPER CODE — PolicyHub MyLibrary connector
# =============================================================================
# Identical structure to PolicyHubAdminConnector but targets the MyLibrary
# WSDL (MyLibraryApi.svc?wsdl), which exposes the full production library
# accessible to end-users — this is the 600+ document corpus.
#
# The MyLibrary API uses the same Credentials type, parameters and response
# field names as the Admin API; most operation names carry a "My" prefix
# (GetMyLibrary, GetMyDocument, GetMyDocumentData, …).
# =============================================================================
class PolicyHubMyLibraryConnector(SoapConnector):
    """
    Developer connector for the PolicyHub MyLibrary API.

    Targets ``MyLibraryApi.svc?wsdl`` which exposes the full production
    policy library visible to authenticated end-users.  Use this connector
    (not the Admin connector) to ingest the 600+ real policy documents.

    Usage::

        connector = PolicyHubMyLibraryConnector(
            wsdl_url   = os.getenv("MITRATECH_MYLIBRARY_WSDL_URL"),
            username   = os.getenv("MITRATECH_API_USERNAME"),
            password   = os.getenv("MITRATECH_API_PASSWORD"),
            ssl_verify = False,
        )
        docs = connector.load(max_docs=5)   # remove cap in basic-rag-app
    """

    def __init__(self, folder_ids: Optional[List[str]] = None, **kwargs):
        """
        Args:
            folder_ids: Optional list of folder IDs to restrict ``load()`` to.
                        ``None`` (default) traverses the whole library tree.
                        IDs are compared by their string form.
            **kwargs:   Forwarded to SoapConnector.__init__().
        """
        super().__init__(
            credentials_type_qname=_POLICYHUB_CREDENTIALS_QNAME,
            **kwargs,
        )
        self.folder_ids = folder_ids

    # ── MyLibrary API method wrappers (one line each) ─────────────────────────

    def list_operations(self) -> List[str]:
        """Return all operation names exposed by this WSDL (useful for debugging).

        Reads private attributes of the underlying client, so any failure is
        logged at debug level and reported as an empty list.
        """
        self.connect()
        try:
            ops = self._client.service._proxy._binding._operations
            return sorted(ops.keys())
        except Exception as exc:
            logger.debug("Could not list WSDL operations", error=str(exc))
            return []

    def get_library(self) -> Dict[str, Any]:
        """Return the library tree via ``GetMyLibrary``."""
        return self.call("GetMyLibrary")

    def get_documents_by_folder(self, folder_id: str) -> Dict[str, Any]:
        """Return all documents in a specific folder."""
        # NOTE(review): unlike the other wrappers this operation has no "My"
        # prefix — confirm against the WSDL (see list_operations()).
        return self.call("GetDocumentsByFolder", folderId=folder_id)

    def get_document(self, document_id: str) -> Dict[str, Any]:
        """Return metadata for a single document via ``GetMyDocument``."""
        return self.call("GetMyDocument", documentId=document_id)

    def get_document_data(self, document_id: str) -> Dict[str, Any]:
        """Return the file content of a document via ``GetMyDocumentData``."""
        return self.call("GetMyDocumentData", documentId=document_id)

    def get_document_revisions(self, document_id: str) -> Dict[str, Any]:
        """Return all revisions of a document via ``GetMyDocumentRevisions``."""
        return self.call("GetMyDocumentRevisions", documentId=document_id)

    def get_document_revision_data(self, document_revision_id: str) -> Dict[str, Any]:
        """Return the file content of a specific document revision."""
        return self.call("GetMyDocumentRevisionData", documentRevisionId=document_revision_id)

    # ── load() ────────────────────────────────────────────────────────────────

    def load(self, max_docs: Optional[int] = None) -> List[Document]:
        """
        Load documents from the MyLibrary API.

        Parameters
        ----------
        max_docs:
            Cap total documents returned.  ``None`` loads everything — use
            only from a production pipeline (e.g. basic-rag-app).

        Returns
        -------
        List[Document]
            One Document per successfully retrieved file.  Documents whose
            content cannot be fetched or is empty are skipped with a warning.
        """
        library = self.get_library()
        all_folders = list(self._traverse_folders(library))

        # Compare string forms so a non-string FolderId in the response still
        # matches the caller's List[str].
        if self.folder_ids:
            wanted = {str(fid) for fid in self.folder_ids}
            all_folders = [f for f in all_folders if str(f["FolderId"]) in wanted]

        documents: List[Document] = []
        for folder in all_folders:
            # Check the cap before issuing another SOAP call for a folder
            # whose documents would be discarded anyway.
            if max_docs is not None and len(documents) >= max_docs:
                break

            folder_id = folder["FolderId"]
            # ``or`` chain rather than nested .get() defaults: a key that is
            # present with a None value would otherwise yield None.
            folder_path = folder.get("LibraryPath") or folder.get("Name") or ""
            try:
                result = self.get_documents_by_folder(folder_id)
            except Exception as exc:
                logger.debug(
                    "Skipping folder (not queryable for documents)",
                    folder_id=folder_id,
                    folder_path=folder_path,
                    error=str(exc),
                )
                continue

            for doc in self._unwrap_documents(result):
                if max_docs is not None and len(documents) >= max_docs:
                    break
                document = self._build_document(doc, folder_id, folder_path)
                if document is not None:
                    documents.append(document)

        logger.info("PolicyHub MyLibrary load complete", total_documents=len(documents))
        return documents

    # ── Private helpers ───────────────────────────────────────────────────────

    def _build_document(
        self,
        doc: Dict[str, Any],
        folder_id: Any,
        folder_path: str,
    ) -> Optional[Document]:
        """Fetch one document's content and wrap it in a Document.

        Returns ``None`` (after logging a warning) when the content cannot be
        fetched or yields no extractable text.
        """
        doc_id = doc.get("DocumentId")
        doc_title = doc.get("Name") or "Untitled"
        try:
            data = self.get_document_data(doc_id) or {}
            raw_bytes = data.get("Data")
            mime_type = data.get("MimeType") or ""
        except Exception as exc:
            logger.warning(
                "Could not fetch document data",
                document_id=doc_id,
                title=doc_title,
                error=str(exc),
            )
            return None

        ext = _mime_to_ext(mime_type)
        filename = f"{doc_title}{ext}"
        content = self.extract_text(raw_bytes or b"", filename)
        if not content.strip():
            logger.warning(
                "Skipping document with no extractable text",
                document_id=doc_id,
                title=doc_title,
            )
            return None

        return Document(
            id=self._make_doc_id(str(doc_id)),
            content=content.strip(),
            timestamp=_parse_date(doc.get("LastModifiedDateUtc")),
            metadata={
                "source": self.wsdl_url,
                "document_id": str(doc_id),
                "folder_id": str(folder_id),
                "folder_path": folder_path,
                "title": doc_title,
                "description": self._field(doc, "Description"),
                "author": self._field(doc, "AuthorName"),
                "owner_id": str(self._field(doc, "AuthorUserId")),
                "reference": str(self._field(doc, "Reference")),
                "document_type": self._field(doc, "DocumentType"),
                "file_name": filename,
                "mime_type": mime_type,
                "extension": ext,
                "created_at": str(self._field(doc, "CreationDateUtc")),
                "modified_at": str(self._field(doc, "LastModifiedDateUtc")),
                "published_at": str(self._field(doc, "PublishedDateUtc")),
                "in_review": str(self._field(doc, "InReview", False)),
                "workflow_status": self._field(doc, "WorkflowProcessStatus"),
            },
        )

    @staticmethod
    def _field(record: Dict[str, Any], key: str, default: Any = "") -> Any:
        """Return ``record[key]``, or *default* when the key is missing or None.

        ``dict.get(key, default)`` only applies the default when the key is
        absent.  A key that is present with a ``None`` value would otherwise
        end up as the literal string ``"None"`` once passed through ``str()``.
        """
        value = record.get(key)
        return default if value is None else value

    @staticmethod
    def _unwrap_documents(result: Any) -> List[Dict[str, Any]]:
        """Normalise a GetDocumentsByFolder response to a list of documents.

        Accepts a plain list, or a dict whose "PolicyHubDocument" key holds
        either a list or a single document.  Anything else yields an empty
        list.
        """
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            raw = result.get("PolicyHubDocument")
            if not raw:
                return []
            return raw if isinstance(raw, list) else [raw]
        return []

    def _traverse_folders(self, node: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """Recursively yield every folder in the library tree (parent first)."""
        yield node
        wrapper = node.get("Children") or {}
        if isinstance(wrapper, list):
            # Tolerate an unwrapped child list, mirroring _unwrap_documents().
            children = wrapper
        else:
            children = wrapper.get("PolicyHubFolder", [])
            if not isinstance(children, list):
                children = [children]
        for child in children:
            if isinstance(child, dict):
                yield from self._traverse_folders(child)
# ── Helpers ────────────────────────────────────────────────────────────────────
def _mime_to_ext(mime_type: str) -> str:
"""Map a MIME type string to a file extension for extract_text dispatch."""
mime = (mime_type or "").lower()
if "pdf" in mime:
return ".pdf"
if "word" in mime or "document" in mime or "officedocument" in mime:
return ".docx"
if "text" in mime:
return ".txt"
return ".bin"
def _parse_date(value: Any) -> Optional[datetime]:
"""Parse a date from the SOAP response — may be str, datetime, or None."""
if value is None:
return None
if isinstance(value, datetime):
return value
try:
return datetime.fromisoformat(str(value))
except Exception:
return None
# =============================================================================
# Example runner
# =============================================================================
def main() -> None:
    """Run the PolicyHub SOAP connector examples.

    Falls back to the dry demo when the ``MITRATECH_*`` variables are not all
    set.  Otherwise it counts the library's documents, loads a five-document
    sample, calls ``GetLibrary`` directly through a second connector, and
    finally hands the sample to the optional RAG pipeline demo.
    """
    # Use the corporate CA bundle only when it is present on disk.
    ssl_cert = None
    if CORPORATE_CERT.exists():
        ssl_cert = str(CORPORATE_CERT)

    admin_wsdl = os.getenv("MITRATECH_ADMIN_WSDL_URL")
    username = os.getenv("MITRATECH_API_USERNAME")
    password = os.getenv("MITRATECH_API_PASSWORD")

    if not (admin_wsdl and username and password):
        logger.warning(
            "SOAP credentials not configured — running dry demo",
            missing="MITRATECH_ADMIN_WSDL_URL / MITRATECH_API_USERNAME / MITRATECH_API_PASSWORD",
        )
        _run_dry_demo()
        return

    tracer = get_tracer()
    sample_size = 5

    # ── EXAMPLE 1: load sample documents ─────────────────────────────────
    logger.info("Example 1 — PolicyHub: sample first 5 documents (full ingestion lives in basic-rag-app)")
    connector = PolicyHubAdminConnector(
        wsdl_url=admin_wsdl,
        username=username,
        password=password,
        ssl_verify=False,
        ssl_cert_path=ssl_cert,
    )

    with tracer.trace("policyhub_soap_connector", input=admin_wsdl) as trace:
        with connector:
            with trace.span("count_documents", input=admin_wsdl) as span:
                counts = connector.count_documents()
                span.set_output(counts)

            logger.info(
                "PolicyHub library stats",
                total_folders=counts["folders"],
                folders_with_docs=counts["folders_with_docs"],
                total_documents=counts["documents"],
            )

            with trace.span("load_sample", input={"max_docs": sample_size}) as span:
                docs = connector.load(max_docs=sample_size)
                span.set_output({"loaded": len(docs)})

            total = len(docs)
            for position, sampled in enumerate(docs, 1):
                meta = sampled.metadata
                logger.info(
                    "Sampled document",
                    index=f"{position}/{total}",
                    id=sampled.id,
                    document_name=meta.get("document_name"),
                    folder_path=meta.get("folder_path"),
                    document_id=meta.get("document_id"),
                    revisionid=meta.get("revisionid") or "none",
                    version=meta.get("version") or "none",
                    language=meta.get("language") or "none",
                    upload_date=meta.get("upload_date") or "none",
                    chars=len(sampled.content),
                )

        # ── EXAMPLE 2: call a single method directly ──────────────────────
        logger.info("Example 2 — call GetLibrary directly")
        direct_connector = PolicyHubAdminConnector(
            wsdl_url=admin_wsdl,
            username=username,
            password=password,
            ssl_verify=False,
        )
        with direct_connector as conn:
            with trace.span("get_library", input=admin_wsdl) as span:
                library = conn.get_library()
                span.set_output({
                    "name": library.get("Name"),
                    "folder_id": library.get("FolderId"),
                    "path": library.get("LibraryPath"),
                })
            logger.info(
                "Library info",
                name=library.get("Name"),
                folder_id=library.get("FolderId"),
                path=library.get("LibraryPath"),
            )

        trace.set_output({
            "total_documents": counts["documents"],
            "sample_loaded": len(docs),
        })

    # ── EXAMPLE 3: RAG pipeline ───────────────────────────────────────────
    _maybe_run_rag_pipeline(docs, ssl_cert)
def _run_dry_demo() -> None:
    """Log an outline of the connector pattern; needs no live SOAP service."""
    subclass_hint = "subclass SoapConnector, add one-line method wrappers, implement load()"
    logger.info("Dry demo — SoapConnector class API (no live service needed)")
    logger.info("Developer subclass pattern", hint=subclass_hint)
def _maybe_run_rag_pipeline(docs: List[Document], ssl_cert: Optional[str]) -> None:
    """If Azure OpenAI is configured, embed + store + retrieve the loaded docs.

    Args:
        docs:     Documents to chunk, embed and index.  When empty, the demo
                  is skipped with a warning.
        ssl_cert: Optional CA bundle path forwarded to the embeddings client.

    The demo is skipped with a warning when any of ``AZURE_OPENAI_ENDPOINT``,
    ``AZURE_OPENAI_API_KEY`` or ``AZURE_OPENAI_EMBEDDING_MODEL`` is unset.
    """
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    deployment = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL")

    if not all([endpoint, api_key, deployment]):
        logger.warning("Azure OpenAI not configured — skipping RAG pipeline demo")
        return
    # Report an empty sample separately so the log does not blame the Azure
    # OpenAI configuration for it.
    if not docs:
        logger.warning("No documents loaded — skipping RAG pipeline demo")
        return

    logger.info("Example 3 — embed + store + retrieve loaded documents")

    # Imported here so they are only needed when the demo actually runs.
    from gmf_forge_ai_data.chunkers import RecursiveChunker
    from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings, BatchEmbeddings
    from gmf_forge_ai_data.vector_stores import InMemoryVectorStore, Document as VDoc
    from gmf_forge_ai_data.retrieval import VectorRetriever, RetrievalQuery

    tracer = get_tracer()
    chunker = RecursiveChunker(chunk_size=512, chunk_overlap=50)
    embedder = AzureOpenAIEmbeddings(
        endpoint=endpoint,
        api_key=api_key,
        deployment_name=deployment,
        ssl_cert_path=ssl_cert,
    )
    batch_embedder = BatchEmbeddings(provider=embedder, batch_size=50)

    with tracer.trace("policyhub_rag_pipeline", input={"docs": len(docs)}) as trace:
        with trace.span("chunking", input={"docs": len(docs)}) as span:
            chunks = []
            for doc in docs:
                chunks.extend(chunker.chunk(doc.content, metadata=doc.metadata))
            span.set_output({"chunks": len(chunks)})
        logger.info("Chunked documents", docs=len(docs), chunks=len(chunks))

        with trace.span("embedding", input={"chunks": len(chunks)}) as span:
            texts = [c.text for c in chunks]
            embeddings = batch_embedder.embed_batch(texts)
            # Fall back to 1536 when nothing was embedded.
            dim = len(embeddings[0]) if embeddings else 1536
            span.set_output({"dimension": dim, "embedded": len(embeddings)})

        with trace.span("indexing", input={"chunks": len(chunks)}) as span:
            store = InMemoryVectorStore(embedding_dimension=dim)
            # Index by position so a missing embedding raises IndexError
            # instead of silently dropping chunks.
            store.add_documents(
                [
                    VDoc(
                        id=f"chunk_{i}",
                        content=chunk.text,
                        embedding=embeddings[i],
                        metadata=chunk.metadata,
                    )
                    for i, chunk in enumerate(chunks)
                ],
                generate_embeddings=False,
            )
            span.set_output({"stored": store.count()})
        logger.info("Indexed chunks", stored=store.count(), store="InMemoryVectorStore")

        query = "What is the compliance policy?"
        with trace.span("retrieval", input=query) as span:
            query_embedding = embedder.embed_text(query)
            retriever = VectorRetriever(vector_store=store)
            results = retriever.retrieve(
                RetrievalQuery(text=query, embedding=query_embedding, top_k=3)
            )
            span.set_output({
                "results": len(results),
                "top_score": results[0].score if results else 0,
            })

        trace.set_output({"query": query, "results": len(results)})

    logger.info("RAG retrieval complete", query=query, results=len(results))
    for r in results:
        preview = r.document.content[:120].replace("\n", " ")
        logger.info("Result", rank=r.rank, score=round(r.score, 4), preview=preview)
# Script entry point: run the examples only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
.env.example
# =============================================================================
# Azure OpenAI
# Used by: all examples (embeddings, indexing, retrieval, query, context,
# connectors, mongo_cosmos_vector_store)
# =============================================================================
AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com/"
AZURE_OPENAI_API_KEY="your-azure-openai-api-key"
AZURE_OPENAI_DEPLOYMENT="your-gpt-deployment-name"
AZURE_OPENAI_API_VERSION="2025-01-01-preview"
AZURE_OPENAI_EMBEDDING_MODEL="your-embedding-deployment-name"
AZURE_OPENAI_EMBEDDING_MODEL_VERSION="2023-05-15"
# =============================================================================
# Azure AI Search
# Used by: index_documents_example.py, retrieval_example.py, query_example.py,
#          context_example.py, connectors_example.py,
#          azure_ai_search_vector_store_example.py
# =============================================================================
AZURE_AI_SEARCH_ENDPOINT="https://your-search-service.search.windows.net"
SEARCH_SERVICE_NAME="your-search-service"
AZURE_AI_SEARCH_API_KEY="your-azure-ai-search-api-key"
AZURE_AI_SEARCH_EMBEDDING_DIMENSION=1536
# =============================================================================
# Anthropic (optional — not used by default examples)
# =============================================================================
ANTHROPIC_API_KEY=your-anthropic-api-key
# =============================================================================
# Application Settings
# =============================================================================
LOG_LEVEL=INFO
ENVIRONMENT=development
# =============================================================================
# Managed Identity Authentication (optional)
# Set to true to use DefaultAzureCredential instead of API keys.
# Each service requires its own scope when building a token provider:
# Azure OpenAI / Cognitive Services : https://cognitiveservices.azure.com/.default
# Azure AI Search : https://search.azure.com/.default
# When false (default), API keys above are used.
# =============================================================================
AZURE_USE_MANAGED_IDENTITY=false
# =============================================================================
# SharePoint Connector (optional — connectors_example.py Example 2)
# Requires an Azure AD app registration with Graph API application permissions:
# Sites.Read.All Files.Read.All (admin-consented)
# =============================================================================
SHAREPOINT_TENANT_ID=your-tenant-id
SHAREPOINT_CLIENT_ID=your-sharepoint-app-client-id
SHAREPOINT_CLIENT_SECRET=your-sharepoint-app-client-secret
# SHAREPOINT_SITE_ID format: <hostname>,<site-collection-id>,<web-id>
SHAREPOINT_SITE_ID="your-tenant.sharepoint.com,site-id-guid,web-id-guid"
SHAREPOINT_FOLDER_PATH=/
# =============================================================================
# Azure Blob Storage Connector (optional — connectors_example.py Example 3)
# =============================================================================
AZURE_STORAGE_ACCOUNT_NAME="your-storage-account-name"
AZURE_STORAGE_ACCESS_KEY="your-storage-account-access-key"
AZURE_BLOB_CONTAINER_NAME="your-blob-container-name"
AZURE_STORAGE_PREFIX=
# =============================================================================
# Azure Cosmos DB for MongoDB vCore (optional — mongo_cosmos_vector_store_example.py)
# =============================================================================
AZURE_COSMOS_ENDPOINT=https://your-cosmos-account.documents.azure.com:443/
AZURE_COSMOS_KEY=your-cosmos-db-key
AZURE_COSMOS_DATABASE=rag_db
# =============================================================================
# MongoDB Atlas Vector Search (optional — mongo_cosmos_vector_store_example.py)
# MONGODB_VECTOR_INDEX must match the Atlas Vector Search index name created
# via Atlas UI / CLI on the collection field "embedding"
# =============================================================================
MONGODB_CONNECTION_STRING=mongodb+srv://your-username:your-password@your-cluster.mongodb.net/
MONGODB_DATABASE=rag_db
MONGODB_COLLECTION=tech_articles
MONGODB_VECTOR_INDEX=vector_index
# =============================================================================
# SOAP / WSDL Connector (optional — soap_connector_example.py)
# =============================================================================
MITRATECH_API_USERNAME="your-api-username"
MITRATECH_API_PASSWORD="your-api-password"
MITRATECH_ADMIN_WSDL_URL=https://your-instance.policyhub.com/PolicyHubAPI/AdminApi.svc?wsdl
MITRATECH_MYLIBRARY_WSDL_URL=https://your-instance.policyhub.com/PolicyHubAPI/MyLibraryApi.svc?wsdl