gmf_forge_ai_data.embeddings
Embeddings module for the data-layer package.
This module provides embedding generation capabilities for RAG applications, with support for Azure OpenAI and efficient batch processing.
Example: Basic usage with Azure OpenAI:
>>> from gmf_forge_ai_data.embeddings import AzureOpenAIEmbeddings
>>> embedder = AzureOpenAIEmbeddings(
... endpoint="https://your-resource.openai.azure.com",
... api_key="your-api-key",
... deployment_name="text-embedding-3-large"
... )
>>> vector = embedder.embed_text("Hello world")
Batch processing for large collections:
>>> from gmf_forge_ai_data.embeddings import BatchEmbeddings
>>> batch_embedder = BatchEmbeddings(provider=embedder, batch_size=100)
>>> vectors = batch_embedder.embed_batch(large_text_list)
from .base_embeddings import EmbeddingProvider
from .azure_openai_embeddings import AzureOpenAIEmbeddings
from .batch_embeddings import BatchEmbeddings

__all__ = [
    "EmbeddingProvider",
    "AzureOpenAIEmbeddings",
    "BatchEmbeddings",
]
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.

    All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should
    inherit from this class and implement its abstract methods.
    """

    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is empty or None
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        This method should be more efficient than calling embed_text() repeatedly
        as it can leverage the provider's batch API capabilities.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is empty or contains invalid entries
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        pass

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            Integer dimension of the embedding space
            (e.g., 3072 for text-embedding-3-large)
        """
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get the name/identifier of the embedding model being used.

        Returns:
            String identifier for the model (e.g., "text-embedding-3-large")
        """
        pass

    def validate_text(self, text: str) -> None:
        """
        Validate input text before embedding.

        Args:
            text: The text to validate

        Raises:
            ValueError: If text is None, empty, or not a string
        """
        if text is None:
            raise ValueError("Text cannot be None")
        if not isinstance(text, str):
            raise ValueError(f"Text must be a string, got {type(text)}")
        if not text.strip():
            raise ValueError("Text cannot be empty or whitespace only")

    def validate_batch(self, texts: List[str]) -> None:
        """
        Validate a batch of texts before embedding.

        Args:
            texts: List of texts to validate

        Raises:
            ValueError: If texts is invalid or contains invalid entries
        """
        if texts is None:
            raise ValueError("Texts list cannot be None")
        if not isinstance(texts, list):
            raise ValueError(f"Texts must be a list, got {type(texts)}")
        if len(texts) == 0:
            raise ValueError("Texts list cannot be empty")

        for i, text in enumerate(texts):
            try:
                self.validate_text(text)
            except ValueError as e:
                # Chain the original error so the traceback keeps the root cause.
                raise ValueError(f"Invalid text at index {i}: {e}") from e
Abstract base class for embedding providers.
All embedding implementations (Azure OpenAI, OpenAI, Cohere, etc.) should inherit from this class and implement its abstract methods.
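As a minimal sketch of what "implement its abstract methods" involves, the block below defines a hypothetical offline provider. To stay self-contained it declares `EmbeddingProviderSketch`, a stand-in ABC with the same four abstract methods; real code would subclass `EmbeddingProvider` from `gmf_forge_ai_data.embeddings` instead. `HashEmbeddings` and its SHA-256 scheme are illustrative assumptions (useful for unit tests), not a meaningful semantic embedding.

```python
import hashlib
import math
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProviderSketch(ABC):
    """Stand-in for EmbeddingProvider with the same four abstract methods."""

    @abstractmethod
    def embed_text(self, text: str) -> List[float]: ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]: ...

    @abstractmethod
    def get_embedding_dimension(self) -> int: ...

    @abstractmethod
    def get_model_name(self) -> str: ...


class HashEmbeddings(EmbeddingProviderSketch):
    """Hypothetical offline provider: deterministic unit vectors derived from SHA-256."""

    def __init__(self, dimension: int = 8) -> None:
        self.dimension = dimension

    def embed_text(self, text: str) -> List[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Map digest bytes to [-1, 1] (cycling if dimension > 32), then L2-normalize.
        raw = [digest[i % len(digest)] / 127.5 - 1.0 for i in range(self.dimension)]
        norm = math.sqrt(sum(x * x for x in raw)) or 1.0
        return [x / norm for x in raw]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        # No remote batch endpoint here, so a loop is the simplest correct batch.
        return [self.embed_text(t) for t in texts]

    def get_embedding_dimension(self) -> int:
        return self.dimension

    def get_model_name(self) -> str:
        return "hash-sha256"


embedder = HashEmbeddings(dimension=8)
vectors = embedder.embed_batch(["Hello world", "Hello world", "Goodbye"])
print(len(vectors), len(vectors[0]))  # 3 8
```

Leaving any of the four methods unimplemented makes the subclass abstract, so instantiating it raises `TypeError`.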
@abstractmethod
def embed_text(self, text: str) -> List[float]:
Generate embeddings for a single text string.
Args: text: The input text to embed
Returns: A list of floats representing the embedding vector
Raises:
    ValueError: If text is empty or None
    Exception: Provider-specific errors (rate limits, auth, etc.)
@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
Generate embeddings for multiple texts in a batch.
Implementations should make this more efficient than calling embed_text() repeatedly, for example by sending many texts in a single provider request.
Args: texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
    ValueError: If texts is empty or contains invalid entries
    Exception: Provider-specific errors (rate limits, auth, etc.)
@abstractmethod
def get_embedding_dimension(self) -> int:
Get the dimensionality of the embedding vectors.
Returns: Integer dimension of the embedding space (e.g., 3072 for text-embedding-3-large)
@abstractmethod
def get_model_name(self) -> str:
Get the name/identifier of the embedding model being used.
Returns: String identifier for the model (e.g., "text-embedding-3-large")
def validate_text(self, text: str) -> None:
Validate input text before embedding.
Args: text: The text to validate
Raises: ValueError: If text is None, empty, or not a string
def validate_batch(self, texts: List[str]) -> None:
Validate a batch of texts before embedding.
Args: texts: List of texts to validate
Raises: ValueError: If texts is None, not a list, or empty, or if any entry fails validate_text (the message then includes the entry's index)
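The two validators can be exercised in isolation. The sketch below copies their checks into free functions (so it runs without the package) and adds a hypothetical `first_error` helper to show which message each bad input produces. Note that only a real `list` is accepted: tuples and generators are rejected before any entry is inspected.

```python
from typing import List


def validate_text(text) -> None:
    # Same three checks, in the same order, as EmbeddingProvider.validate_text.
    if text is None:
        raise ValueError("Text cannot be None")
    if not isinstance(text, str):
        raise ValueError(f"Text must be a string, got {type(text)}")
    if not text.strip():
        raise ValueError("Text cannot be empty or whitespace only")


def validate_batch(texts: List[str]) -> None:
    if texts is None:
        raise ValueError("Texts list cannot be None")
    if not isinstance(texts, list):
        raise ValueError(f"Texts must be a list, got {type(texts)}")
    if len(texts) == 0:
        raise ValueError("Texts list cannot be empty")
    for i, text in enumerate(texts):
        try:
            validate_text(text)
        except ValueError as e:
            # Prefix the index so the caller can locate the bad entry.
            raise ValueError(f"Invalid text at index {i}: {e}") from e


def first_error(func, arg) -> str:
    """Hypothetical helper: the ValueError message raised by func(arg), or "" if it passes."""
    try:
        func(arg)
    except ValueError as e:
        return str(e)
    return ""


print(first_error(validate_text, "   "))             # Text cannot be empty or whitespace only
print(first_error(validate_batch, ["ok", "", "x"]))  # Invalid text at index 1: Text cannot be empty or whitespace only
print(first_error(validate_batch, ("a", "b")))       # Texts must be a list, got <class 'tuple'>
```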
import os
import time
from typing import Callable, List, Optional

import httpx
from openai import AzureOpenAI
from openai.types import CreateEmbeddingResponse

from .base_embeddings import EmbeddingProvider


class AzureOpenAIEmbeddings(EmbeddingProvider):
    """
    Azure OpenAI embedding provider using text-embedding-3-large.

    This implementation provides:
    - SSL certificate handling for internal Azure endpoints
    - Automatic retry logic with exponential backoff
    - Optional logging via shared-core BasicLogger
    - Simple developer-friendly API

    Example (API key):
        >>> embedder = AzureOpenAIEmbeddings(
        ...     endpoint="https://your-resource.openai.azure.com",
        ...     api_key="your-api-key",
        ...     deployment_name="text-embedding-3-large"
        ... )
        >>> vector = embedder.embed_text("Hello world")
        >>> vectors = embedder.embed_batch(["Text 1", "Text 2"])

    Example (managed identity / token provider):
        >>> from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        >>> token_provider = get_bearer_token_provider(
        ...     DefaultAzureCredential(),
        ...     "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
        ... )
        >>> embedder = AzureOpenAIEmbeddings(
        ...     endpoint="https://your-resource.openai.azure.com",
        ...     deployment_name="text-embedding-3-large",
        ...     token_provider=token_provider
        ... )
    """

    def __init__(
        self,
        endpoint: str,
        deployment_name: str,
        api_key: Optional[str] = None,
        token_provider: Optional[Callable[[], str]] = None,
        model: str = "text-embedding-3-large",
        api_version: str = "2024-02-01",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        ssl_cert_path: Optional[str] = None,
        logger: Optional['BasicLogger'] = None,
    ):
        """
        Initialize Azure OpenAI embeddings provider.

        Exactly one of ``api_key`` or ``token_provider`` must be supplied.

        Args:
            endpoint: Azure OpenAI endpoint URL
            deployment_name: Name of the deployment in Azure
            api_key: Azure OpenAI API key. Use for local development or
                when managed identity is not available.
            token_provider: Zero-argument callable that returns a bearer token
                string. Use for managed identity / workload identity scenarios.
                The callable must request the **Cognitive Services** scope::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://cognitiveservices.azure.com/.default"
                    )

                Note: this scope is different from Azure AI Search
                (``https://search.azure.com/.default``); each service
                requires its own token_provider.
            model: Model identifier (default: text-embedding-3-large)
            api_version: Azure OpenAI API version (default: 2024-02-01)
            max_retries: Total number of attempts, including the first call (default: 3)
            retry_delay: Initial delay between retries in seconds, doubled after
                each failure (default: 1.0)
            ssl_cert_path: Optional path to SSL certificate bundle for corporate networks.
                If None, uses default SSL verification.
            logger: Optional BasicLogger instance for observability

        Raises:
            ValueError: If not exactly one credential is supplied, or max_retries < 1
        """
        # Enforce the documented contract: exactly one credential source.
        if bool(api_key) == bool(token_provider):
            raise ValueError(
                "Supply exactly one of api_key or token_provider to AzureOpenAIEmbeddings."
            )
        # With zero attempts, _call_with_retry would have nothing to re-raise.
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")
        self.endpoint = endpoint
        self.deployment_name = deployment_name
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.logger = logger
        self.ssl_cert_path = ssl_cert_path

        # Log initialization details
        if self.logger:
            self.logger.info(
                "Initializing AzureOpenAIEmbeddings",
                endpoint=endpoint,
                deployment=deployment_name,
                model=model,
                ssl_cert_provided=ssl_cert_path is not None
            )

        # Handle SSL certificate if provided (matches shared-core pattern)
        http_client = None
        if ssl_cert_path and os.path.exists(ssl_cert_path):
            try:
                # Use httpx Client with verify parameter (same as shared-core AzureOpenAIProvider)
                http_client = httpx.Client(verify=str(ssl_cert_path), timeout=30.0)

                if self.logger:
                    self.logger.info(f"✓ Using SSL certificate: {ssl_cert_path}")
            except Exception as e:
                if self.logger:
                    self.logger.warning(f"Failed to configure SSL certificate: {e}")
        elif ssl_cert_path:
            if self.logger:
                self.logger.warning(f"SSL certificate path provided but file not found: {ssl_cert_path}")
        else:
            if self.logger:
                self.logger.debug("No SSL certificate provided, using default SSL verification")

        # Initialize Azure OpenAI client
        if token_provider:
            self.client = AzureOpenAI(
                azure_ad_token_provider=token_provider,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )
        else:
            self.client = AzureOpenAI(
                api_key=api_key,
                api_version=api_version,
                azure_endpoint=endpoint,
                http_client=http_client,
            )

        # Dimension is detected lazily from the first API response
        self._dimension: Optional[int] = None

        if self.logger:
            self.logger.info(
                f"✓ Initialized AzureOpenAIEmbeddings with model={model}, "
                f"deployment={deployment_name}"
            )

    def _call_with_retry(self, func, *args, **kwargs):
        """
        Execute a function with exponential backoff retry logic.

        Args:
            func: Function to call
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Result of the function call

        Raises:
            Exception: If all retries are exhausted
        """
        last_exception = None
        delay = self.retry_delay

        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                error_type = type(e).__name__
                error_msg = str(e)

                if attempt < self.max_retries - 1:
                    if self.logger:
                        self.logger.warning(
                            f"Attempt {attempt + 1} failed",
                            error_type=error_type,
                            error_message=error_msg,
                            retry_in=f"{delay}s"
                        )
                    time.sleep(delay)
                    delay *= 2  # Exponential backoff
                else:
                    if self.logger:
                        self.logger.error(
                            f"All {self.max_retries} attempts failed",
                            error_type=error_type,
                            error_message=error_msg
                        )

        raise last_exception

    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text string.

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector

        Raises:
            ValueError: If text is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_text(text)

        if self.logger:
            self.logger.debug(f"Embedding single text (length={len(text)})")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=text,
                model=self.deployment_name,
            )
            return response.data[0].embedding

        embedding = self._call_with_retry(_embed)

        if self._dimension is None:
            self._dimension = len(embedding)

        if self.logger:
            self.logger.debug(f"Generated embedding with dimension={len(embedding)}")

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        Sends all texts in a single embeddings request (not the asynchronous
        Azure OpenAI Batch API) and returns the vectors in input order.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is invalid
            Exception: Azure OpenAI API errors
        """
        self.validate_batch(texts)

        if self.logger:
            self.logger.info(f"Embedding batch of {len(texts)} texts")

        def _embed():
            response: CreateEmbeddingResponse = self.client.embeddings.create(
                input=texts,
                model=self.deployment_name,
            )
            # Ensure results are in the correct order
            sorted_data = sorted(response.data, key=lambda x: x.index)
            return [item.embedding for item in sorted_data]

        embeddings = self._call_with_retry(_embed)

        if self._dimension is None and embeddings:
            self._dimension = len(embeddings[0])

        if self.logger:
            self.logger.info(
                f"Generated {len(embeddings)} embeddings successfully"
            )

        return embeddings

    def get_embedding_dimension(self) -> Optional[int]:
        """
        Get the dimensionality of the embedding vectors.

        Returns the actual dimension observed from the first API response.
        Returns None if no embedding has been generated yet.
        """
        return self._dimension

    def get_model_name(self) -> str:
        """
        Get the name of the embedding model.

        Returns:
            The model identifier
        """
        return self.model
Azure OpenAI embedding provider using text-embedding-3-large.
This implementation provides:
- Custom SSL certificate bundle support (ssl_cert_path) for internal Azure endpoints on corporate networks
- Automatic retry logic with exponential backoff
- Optional logging via shared-core BasicLogger
- Simple developer-friendly API
Example (API key):
>>> embedder = AzureOpenAIEmbeddings(
...     endpoint="https://your-resource.openai.azure.com",
...     api_key="your-api-key",
...     deployment_name="text-embedding-3-large"
... )
>>> vector = embedder.embed_text("Hello world")
>>> vectors = embedder.embed_batch(["Text 1", "Text 2"])
Example (managed identity / token provider):
>>> from azure.identity import DefaultAzureCredential, get_bearer_token_provider
>>> token_provider = get_bearer_token_provider(
...     DefaultAzureCredential(),
...     "https://cognitiveservices.azure.com/.default"  # Cognitive Services scope
... )
>>> embedder = AzureOpenAIEmbeddings(
...     endpoint="https://your-resource.openai.azure.com",
...     deployment_name="text-embedding-3-large",
...     token_provider=token_provider
... )
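The two examples use different credentials, and the constructor accepts exactly one of them. The sketch below shows that selection rule in isolation: it maps the chosen credential to the keyword the `openai` `AzureOpenAI` client receives (`api_key` or `azure_ad_token_provider`, both taken from the source). The helper name `auth_kwargs` and the fake token provider are hypothetical and stand in for `get_bearer_token_provider(...)` so the block runs offline.

```python
from typing import Any, Callable, Dict, Optional


def auth_kwargs(
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: pick the AzureOpenAI auth keyword for exactly one credential."""
    if bool(api_key) == bool(token_provider):
        # Both missing or both supplied: ambiguous, so refuse.
        raise ValueError("Supply exactly one of api_key or token_provider")
    if token_provider:
        return {"azure_ad_token_provider": token_provider}
    return {"api_key": api_key}


def fake_token_provider() -> str:
    # Stand-in for a real bearer-token callable; returns a dummy string.
    return "dummy-token"


print(sorted(auth_kwargs(api_key="k")))                        # ['api_key']
print(sorted(auth_kwargs(token_provider=fake_token_provider)))  # ['azure_ad_token_provider']
```

Rejecting the two-credential case avoids silently preferring one source over the other, which can be confusing when an API key lingers in an environment that is meant to use managed identity.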
def __init__(
    self,
    endpoint: str,
    deployment_name: str,
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
    model: str = "text-embedding-3-large",
    api_version: str = "2024-02-01",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    ssl_cert_path: Optional[str] = None,
    logger: Optional['BasicLogger'] = None,
):
Initialize Azure OpenAI embeddings provider.
Exactly one of api_key or token_provider must be supplied.
Args:
    endpoint: Azure OpenAI endpoint URL
    deployment_name: Name of the deployment in Azure; requests are sent to this deployment, so it determines which model actually produces the embeddings
    api_key: Azure OpenAI API key. Use for local development or when managed identity is not available.
    token_provider: Zero-argument callable that returns a bearer token string. Use for managed identity / workload identity scenarios. The callable must request the Cognitive Services scope:

        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
        token_provider = get_bearer_token_provider(
            DefaultAzureCredential(),
            "https://cognitiveservices.azure.com/.default"
        )

        Note: this scope is different from the Azure AI Search scope (https://search.azure.com/.default); each service requires its own token_provider.
    model: Model name reported by get_model_name() and in log messages (default: text-embedding-3-large); keep it consistent with the model behind deployment_name
    api_version: Azure OpenAI API version (default: 2024-02-01)
    max_retries: Total number of attempts per request, including the first call (default: 3)
    retry_delay: Delay before the first retry in seconds, doubled after each further failure (default: 1.0)
    ssl_cert_path: Optional path to an SSL certificate bundle for corporate networks. If None, or if the file does not exist, default SSL verification is used.
    logger: Optional BasicLogger instance for observability
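The interaction of max_retries and retry_delay is easiest to see in isolation. The sketch below mirrors the retry loop in the source (max_retries total attempts, delay doubling after each failure, last error re-raised) but takes an injectable `sleep` so the schedule can be observed without waiting; that parameter and the name `call_with_retry` are assumptions for illustration. With the defaults, a call that keeps failing is attempted 3 times with sleeps of 1.0 s and 2.0 s in between.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def call_with_retry(
    func: Callable[[], T],
    max_retries: int = 3,
    retry_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Sketch: up to max_retries attempts; the delay doubles after each failure."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    delay = retry_delay
    for attempt in range(max_retries):
        try:
            return func()
        except Exception:
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the last error
            sleep(delay)
            delay *= 2  # exponential backoff: 1.0, 2.0, 4.0, ...
    raise AssertionError("unreachable")


# Simulate an endpoint that fails twice (e.g. rate-limited) and then succeeds.
failures = iter([RuntimeError("429"), RuntimeError("429")])


def flaky() -> str:
    err = next(failures, None)
    if err is not None:
        raise err
    return "ok"


slept: List[float] = []
print(call_with_retry(flaky, max_retries=3, retry_delay=1.0, sleep=slept.append))  # ok
print(slept)  # [1.0, 2.0]
```

The worst-case wait before giving up is retry_delay * (2 ** (max_retries - 1) - 1), i.e. 3.0 s with the defaults.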
def embed_text(self, text: str) -> List[float]:
Generate embeddings for a single text string.
Args: text: The input text to embed
Returns: A list of floats representing the embedding vector
Raises:
    ValueError: If text is invalid
    Exception: Azure OpenAI API errors
def embed_batch(self, texts: List[str]) -> List[List[float]]:
Generate embeddings for multiple texts in a batch.
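All texts are sent in a single embeddings request (one HTTP call, not the asynchronous Azure OpenAI Batch API), which is more efficient than one call per text. The returned vectors are sorted by each result's index, so output order matches input order. Very large lists can exceed the service's per-request limits; wrap the provider in BatchEmbeddings to split them.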
Args: texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
    ValueError: If texts is invalid
    Exception: Azure OpenAI API errors
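The ordering step in embed_batch relies on each result carrying the position of its input. The sketch below reproduces it with `EmbeddingItem`, a hypothetical stand-in dataclass exposing the same two fields (`index`, `embedding`) as the entries of `response.data` in the openai SDK, and a response deliberately returned out of order.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EmbeddingItem:
    """Hypothetical stand-in for one entry of response.data."""
    index: int
    embedding: List[float]


def ordered_embeddings(data: List[EmbeddingItem]) -> List[List[float]]:
    # Sort by the input position the service reported, as embed_batch does,
    # so output position i always corresponds to input position i.
    return [item.embedding for item in sorted(data, key=lambda item: item.index)]


texts = ["alpha", "beta", "gamma"]
# Pretend the service returned the items out of order.
data = [
    EmbeddingItem(index=2, embedding=[0.3]),
    EmbeddingItem(index=0, embedding=[0.1]),
    EmbeddingItem(index=1, embedding=[0.2]),
]
for text, vector in zip(texts, ordered_embeddings(data)):
    print(text, vector)  # alpha [0.1], then beta [0.2], then gamma [0.3]
```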
def get_embedding_dimension(self) -> Optional[int]:
Get the dimensionality of the embedding vectors.
Returns the actual dimension observed from the first API response. Returns None if no embedding has been generated yet.
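Because the dimension is only known after a response arrives, code that must size a vector index up front can embed one probe text first. The sketch below models that lazy pattern with `LazyDimension`, a hypothetical wrapper around a fake embedding function that returns 3072-element vectors (the default size for text-embedding-3-large).

```python
from typing import Callable, List, Optional


class LazyDimension:
    """Hypothetical wrapper: remembers the vector length of the first result."""

    def __init__(self, embed: Callable[[str], List[float]]) -> None:
        self._embed = embed
        self._dimension: Optional[int] = None

    def embed_text(self, text: str) -> List[float]:
        vector = self._embed(text)
        if self._dimension is None:
            self._dimension = len(vector)  # recorded once, from the first response
        return vector

    def get_embedding_dimension(self) -> Optional[int]:
        return self._dimension


def fake_embed(text: str) -> List[float]:
    # Stand-in for a real API call.
    return [0.0] * 3072


provider = LazyDimension(fake_embed)
print(provider.get_embedding_dimension())  # None
provider.embed_text("probe")
print(provider.get_embedding_dimension())  # 3072
```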
from typing import Callable, List, Optional

from .base_embeddings import EmbeddingProvider


class BatchEmbeddings(EmbeddingProvider):
    """
    Wrapper for efficient batch processing of embeddings.

    This class wraps any EmbeddingProvider and adds:
    - Automatic batching with configurable batch size
    - Progress tracking for large document collections
    - Per-batch retries, delegated to the wrapped provider
    - Bounded request size (at most batch_size texts per provider call)

    Example:
        >>> base_embedder = AzureOpenAIEmbeddings(...)
        >>> batch_embedder = BatchEmbeddings(
        ...     provider=base_embedder,
        ...     batch_size=100,
        ...     show_progress=True
        ... )
        >>> # Process 10,000 documents efficiently
        >>> embeddings = batch_embedder.embed_batch(large_text_list)
    """

    def __init__(
        self,
        provider: EmbeddingProvider,
        batch_size: int = 100,
        show_progress: bool = True,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        logger: Optional['BasicLogger'] = None,
    ):
        """
        Initialize batch embedding wrapper.

        Args:
            provider: The underlying EmbeddingProvider to wrap
            batch_size: Number of texts to process per batch (default: 100)
            show_progress: Whether to log per-batch progress (default: True);
                has no effect unless a logger is supplied
            progress_callback: Optional callback function(current, total) for
                custom progress tracking, called after each batch completes
            logger: Optional BasicLogger instance for observability

        Raises:
            ValueError: If batch_size is not positive
        """
        # Same check as set_batch_size, so the constructor cannot bypass it.
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_size}")

        self.provider = provider
        self.batch_size = batch_size
        self.show_progress = show_progress
        self.progress_callback = progress_callback
        self.logger = logger

        if self.logger:
            self.logger.info(
                f"Initialized BatchEmbeddings wrapper with batch_size={batch_size}"
            )

    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text (delegates to wrapped provider).

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector
        """
        return self.provider.embed_text(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts with automatic batching.

        This method splits large text lists into smaller batches and processes
        them sequentially to avoid hitting provider limits and manage memory.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is invalid
            Exception: Provider-specific errors
        """
        self.validate_batch(texts)

        total_texts = len(texts)

        if self.logger:
            self.logger.info(
                f"Processing {total_texts} texts in batches of {self.batch_size}"
            )

        # If the input fits in one batch, just use the provider directly
        if total_texts <= self.batch_size:
            if self.logger:
                self.logger.debug("Input fits in a single batch, processing in one call")
            embeddings = self.provider.embed_batch(texts)
            if self.progress_callback:
                self.progress_callback(total_texts, total_texts)
            return embeddings

        # Process in batches
        all_embeddings = []
        num_batches = (total_texts + self.batch_size - 1) // self.batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min(start_idx + self.batch_size, total_texts)
            batch_texts = texts[start_idx:end_idx]

            if self.show_progress and self.logger:
                progress = (batch_idx + 1) / num_batches * 100
                self.logger.info(
                    f"Processing batch {batch_idx + 1}/{num_batches}",
                    progress_pct=round(progress, 1),
                    texts_done=end_idx,
                    texts_total=total_texts,
                )

            if self.logger:
                self.logger.debug(
                    f"Processing batch {batch_idx + 1}/{num_batches}: "
                    f"texts[{start_idx}:{end_idx}]"
                )

            # Embed the batch
            batch_embeddings = self.provider.embed_batch(batch_texts)
            all_embeddings.extend(batch_embeddings)

            # Report progress only after the batch has actually been embedded
            if self.progress_callback:
                self.progress_callback(end_idx, total_texts)

        if self.show_progress and self.logger:
            self.logger.info(f"Completed embedding {total_texts} texts")

        if self.logger:
            self.logger.info(
                f"Successfully processed all {total_texts} texts in {num_batches} batches"
            )

        return all_embeddings

    def get_embedding_dimension(self) -> Optional[int]:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            The dimension from the wrapped provider
        """
        return self.provider.get_embedding_dimension()

    def get_model_name(self) -> str:
        """
        Get the name of the embedding model.

        Returns:
            The model name from the wrapped provider
        """
        return self.provider.get_model_name()

    def set_batch_size(self, batch_size: int) -> None:
        """
        Update the batch size for processing.

        Args:
            batch_size: New batch size to use

        Raises:
            ValueError: If batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_size}")

        self.batch_size = batch_size

        if self.logger:
            self.logger.info(f"Updated batch size to {batch_size}")

    def get_batch_size(self) -> int:
        """
        Get the current batch size.

        Returns:
            Current batch size setting
        """
        return self.batch_size
Wrapper for efficient batch processing of embeddings.
This class wraps any EmbeddingProvider and adds:
- Automatic batching with configurable batch size
- Progress tracking for large document collections
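- Per-batch retries, delegated to the wrapped provider (BatchEmbeddings itself does not retry)
- Bounded request size: each call to the wrapped provider receives at most batch_size texts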
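The batching arithmetic is plain slicing: the number of batches is the ceiling of len(texts) / batch_size, computed with integer division as in embed_batch, and the last batch holds the remainder. The generator below is a hypothetical helper (not part of the package) that yields the same (start, end) windows.

```python
from typing import Iterator, Sequence, Tuple


def iter_batches(texts: Sequence[str], batch_size: int) -> Iterator[Tuple[int, int, Sequence[str]]]:
    """Hypothetical helper: yield (start, end, chunk) slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError(f"Batch size must be positive, got {batch_size}")
    for start in range(0, len(texts), batch_size):
        end = min(start + batch_size, len(texts))
        yield start, end, texts[start:end]


texts = [f"doc-{i}" for i in range(250)]
num_batches = (len(texts) + 100 - 1) // 100  # ceiling division
print(num_batches)  # 3
print([(s, e) for s, e, _ in iter_batches(texts, 100)])  # [(0, 100), (100, 200), (200, 250)]
```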
Example:
>>> base_embedder = AzureOpenAIEmbeddings(...)
>>> batch_embedder = BatchEmbeddings(
...     provider=base_embedder,
...     batch_size=100,
...     show_progress=True
... )
>>> # Process 10,000 documents efficiently
>>> embeddings = batch_embedder.embed_batch(large_text_list)
def __init__(
    self,
    provider: EmbeddingProvider,
    batch_size: int = 100,
    show_progress: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    logger: Optional['BasicLogger'] = None,
):
Initialize batch embedding wrapper.
Args:
- provider: The underlying EmbeddingProvider to wrap
- batch_size: Number of texts to process per batch (default: 100)
- show_progress: Whether to log per-batch progress messages; has no effect unless a logger is supplied (default: True)
- progress_callback: Optional callback function(current, total) for custom progress tracking
- logger: Optional BasicLogger instance for observability
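progress_callback receives two integers, the number of texts covered so far and the total. As a minimal sketch, the hypothetical helper `make_percent_tracker` (not part of the package) builds a callback with that `(current, total)` signature that records completion percentages; the loop simulates the values a 250-text run with batch_size=100 would pass.

```python
from typing import Callable, List, Tuple


def make_percent_tracker() -> Tuple[Callable[[int, int], None], List[float]]:
    """Hypothetical helper: return (callback, history); the callback logs percent done."""
    history: List[float] = []

    def on_progress(current: int, total: int) -> None:
        # Guard against division by zero for an empty input
        pct = 100.0 * current / total if total else 100.0
        history.append(round(pct, 1))

    return on_progress, history


callback, history = make_percent_tracker()
# Simulate the (current, total) pairs for 250 texts in batches of 100
for done in (100, 200, 250):
    callback(done, 250)
print(history)  # [40.0, 80.0, 100.0]
```

The callback could then be supplied as `progress_callback=callback` when constructing BatchEmbeddings.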
```python
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embeddings for a single text (delegates to wrapped provider).

        Args:
            text: The input text to embed

        Returns:
            A list of floats representing the embedding vector
        """
        return self.provider.embed_text(text)
```
Generate embeddings for a single text (delegates to wrapped provider).
Args:
- text: The input text to embed
Returns: A list of floats representing the embedding vector
```python
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts with automatic batching.

        This method splits large text lists into batches of at most batch_size
        texts and processes them sequentially, to avoid hitting provider
        request limits and to bound the size of each request.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If texts is empty or contains invalid entries
            Exception: Provider-specific errors (rate limits, auth, etc.)
        """
        self.validate_batch(texts)

        total_texts = len(texts)

        if self.logger:
            self.logger.info(
                f"Processing {total_texts} texts in batches of {self.batch_size}"
            )

        # If the whole input fits in one batch, call the provider directly
        if total_texts <= self.batch_size:
            if self.logger:
                self.logger.debug("Input fits in a single batch, processing in one call")
            embeddings = self.provider.embed_batch(texts)
            if self.progress_callback:
                self.progress_callback(total_texts, total_texts)
            return embeddings

        # Process in batches
        all_embeddings: List[List[float]] = []
        # Ceiling division: number of batches needed to cover every text
        num_batches = (total_texts + self.batch_size - 1) // self.batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min(start_idx + self.batch_size, total_texts)
            batch_texts = texts[start_idx:end_idx]

            if self.logger:
                self.logger.debug(
                    f"Processing batch {batch_idx + 1}/{num_batches}: "
                    f"texts[{start_idx}:{end_idx}]"
                )

            # Embed the batch
            batch_embeddings = self.provider.embed_batch(batch_texts)
            all_embeddings.extend(batch_embeddings)

            # Report progress only after the batch has actually been embedded
            if self.show_progress and self.logger:
                progress = (batch_idx + 1) / num_batches * 100
                self.logger.info(
                    f"Completed batch {batch_idx + 1}/{num_batches}",
                    progress_pct=round(progress, 1),
                    texts_done=end_idx,
                    texts_total=total_texts,
                )

            if self.progress_callback:
                self.progress_callback(end_idx, total_texts)

        if self.show_progress and self.logger:
            self.logger.info(f"Completed embedding {total_texts} texts")

        if self.logger:
            self.logger.info(
                f"Successfully processed all {total_texts} texts in {num_batches} batches"
            )

        return all_embeddings
```
Generate embeddings for multiple texts with automatic batching.
This method splits large text lists into batches of at most batch_size texts and processes them sequentially, to avoid hitting provider request limits and to bound the size of each request.
Args:
- texts: List of text strings to embed
Returns: List of embedding vectors, one for each input text
Raises:
- ValueError: If texts is empty or contains invalid entries
- Exception: Provider-specific errors (rate limits, auth, etc.)
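The loop relies on ceiling division, num_batches = (total_texts + batch_size - 1) // batch_size, and slices texts[start_idx:end_idx] with end_idx clamped to total_texts. The hypothetical helper `batch_bounds` below (written for illustration, not part of the package) reproduces that arithmetic so the slice boundaries can be checked by hand.

```python
from typing import List, Tuple


def batch_bounds(total_texts: int, batch_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: the (start_idx, end_idx) slice bounds for each batch."""
    if batch_size <= 0:
        raise ValueError(f"Batch size must be positive, got {batch_size}")
    # Ceiling division: the last batch may be only partially filled
    num_batches = (total_texts + batch_size - 1) // batch_size
    bounds = []
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_texts)
        bounds.append((start_idx, end_idx))
    return bounds


print(batch_bounds(250, 100))  # [(0, 100), (100, 200), (200, 250)]
print(batch_bounds(200, 100))  # [(0, 100), (100, 200)]
```

An exact multiple of batch_size produces no empty trailing batch, because adding batch_size - 1 before the floor division only rounds up a genuine remainder.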
```python
    def get_embedding_dimension(self) -> int:
        """
        Get the dimensionality of the embedding vectors.

        Returns:
            The dimension from the wrapped provider
        """
        return self.provider.get_embedding_dimension()
```
Get the dimensionality of the embedding vectors.
Returns: The dimension from the wrapped provider
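Because get_embedding_dimension() simply forwards to the wrapped provider, its value can serve as a sanity check on the vectors that come back, for example before writing them to a vector index sized for that dimension. A minimal sketch; `check_dimensions` is a hypothetical helper, not part of the package.

```python
from typing import List


def check_dimensions(vectors: List[List[float]], expected_dim: int) -> None:
    """Hypothetical helper: raise ValueError if any vector has the wrong length."""
    for i, vec in enumerate(vectors):
        if len(vec) != expected_dim:
            raise ValueError(
                f"Vector {i} has dimension {len(vec)}, expected {expected_dim}"
            )


# In practice expected_dim would come from batch_embedder.get_embedding_dimension()
check_dimensions([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], expected_dim=3)  # passes silently
```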
```python
    def get_model_name(self) -> str:
        """
        Get the name of the embedding model.

        Returns:
            The model name from the wrapped provider
        """
        return self.provider.get_model_name()
```
Get the name of the embedding model.
Returns: The model name from the wrapped provider
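Vectors produced by different embedding models are generally not comparable, so callers may want to store the get_model_name() value next to each vector and refuse to mix models in one index. A sketch under that assumption; `EmbeddingRecord` and `ensure_single_model` are hypothetical names, not part of the package.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EmbeddingRecord:
    """Hypothetical record pairing a vector with the model that produced it."""
    text: str
    vector: List[float]
    model_name: str


def ensure_single_model(records: List[EmbeddingRecord]) -> str:
    """Hypothetical check: return the shared model name, or raise if models are mixed."""
    names = {r.model_name for r in records}
    if len(names) != 1:
        raise ValueError(f"Expected exactly one model, found {sorted(names)}")
    return names.pop()


records = [
    EmbeddingRecord("a", [0.1, 0.2], "text-embedding-3-large"),
    EmbeddingRecord("b", [0.3, 0.4], "text-embedding-3-large"),
]
print(ensure_single_model(records))  # text-embedding-3-large
```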
```python
    def set_batch_size(self, batch_size: int) -> None:
        """
        Update the batch size for processing.

        Args:
            batch_size: New batch size to use

        Raises:
            ValueError: If batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_size}")

        self.batch_size = batch_size

        if self.logger:
            self.logger.info(f"Updated batch size to {batch_size}")
```
Update the batch size for processing.
Args:
- batch_size: New batch size to use
Raises:
- ValueError: If batch_size is not positive
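set_batch_size() validates its argument before changing any state, so a rejected call leaves the previous setting in place. A minimal sketch of that contract; `BatchSizeSetting` is a hypothetical stand-in holding only the batch-size logic, not the real BatchEmbeddings class.

```python
class BatchSizeSetting:
    """Hypothetical stand-in holding only the batch-size logic."""

    def __init__(self, batch_size: int = 100) -> None:
        self.batch_size = batch_size

    def set_batch_size(self, batch_size: int) -> None:
        # Validate first so an invalid value never overwrites a valid one
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_size}")
        self.batch_size = batch_size

    def get_batch_size(self) -> int:
        return self.batch_size


setting = BatchSizeSetting()
setting.set_batch_size(64)
try:
    setting.set_batch_size(0)
except ValueError as exc:
    print(exc)  # Batch size must be positive, got 0
print(setting.get_batch_size())  # 64
```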