gmf_forge_ai_shared_core.llm_gateway.providers
LLM Providers.
"""LLM Providers."""

from gmf_forge_ai_shared_core.llm_gateway.providers.base_provider import (
    BaseProvider,
    CompletionResponse,
    ModelInfo,
)
from gmf_forge_ai_shared_core.llm_gateway.providers.azure_openai_provider import AzureOpenAIProvider
from gmf_forge_ai_shared_core.llm_gateway.providers.ollama_provider import OllamaProvider

__all__ = [
    "BaseProvider",
    "CompletionResponse",
    "ModelInfo",
    "AzureOpenAIProvider",
    "OllamaProvider",
]
# Imports required by this excerpt; CompletionResponse is defined earlier in the same module.
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Optional


class BaseProvider(ABC):
    """
    Abstract base class for LLM providers.

    All providers (Azure OpenAI, OpenAI, Anthropic, etc.) must implement this interface.

    Note: Model registration is handled by LLMProviderRegistry, not by individual providers.
    Providers focus on LLM operations (complete, stream, validate).
    """

    def __init__(self, name: str):
        """
        Initialize the provider.

        Args:
            name: Unique identifier for this provider
        """
        self.name = name

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> CompletionResponse:
        """
        Generate a completion.

        Args:
            prompt: The prompt to complete
            model: Model name
            temperature: Sampling temperature (the valid range is provider-specific, e.g. 0-2 for OpenAI models)
            max_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Returns:
            CompletionResponse object
        """
        pass

    @abstractmethod
    async def stream_complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> AsyncIterator[str]:
        """
        Stream a completion.

        Args:
            prompt: The prompt to complete
            model: Model name
            temperature: Sampling temperature (the valid range is provider-specific, e.g. 0-2 for OpenAI models)
            max_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Yields:
            Chunks of the completion
        """
        pass

    @abstractmethod
    async def validate_credentials(self) -> bool:
        """
        Validate that the provider credentials are correct.

        Returns:
            True if credentials are valid, False otherwise
        """
        pass
Abstract base class for LLM providers.
All providers (Azure OpenAI, OpenAI, Anthropic, etc.) must implement this interface.
Note: Model registration is handled by LLMProviderRegistry, not by individual providers. Providers focus on LLM operations (complete, stream, validate).
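To make the contract concrete, here is a minimal sketch of a custom provider. `EchoProvider` is hypothetical (it is not part of the package), and `BaseProvider` / `CompletionResponse` are re-declared as local stand-ins that mirror the signatures documented here, so the sketch runs without `gmf_forge_ai_shared_core` installed. Note that `stream_complete` is implemented as an async generator, so callers iterate it with `async for` rather than awaiting it.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Optional


# Local stand-ins mirroring the documented interface (assumption: same signatures).
@dataclass
class CompletionResponse:
    content: str
    model: str
    usage: Dict[str, int]
    metadata: Dict[str, Any]


class BaseProvider(ABC):
    def __init__(self, name: str):
        self.name = name

    @abstractmethod
    async def complete(self, prompt: str, model: Optional[str] = None,
                       temperature: float = 0.7, max_tokens: Optional[int] = None,
                       **kwargs: Any) -> CompletionResponse: ...

    @abstractmethod
    async def stream_complete(self, prompt: str, model: Optional[str] = None,
                              temperature: float = 0.7, max_tokens: Optional[int] = None,
                              **kwargs: Any) -> AsyncIterator[str]: ...

    @abstractmethod
    async def validate_credentials(self) -> bool: ...


class EchoProvider(BaseProvider):
    """Hypothetical offline provider that echoes the prompt back (handy in tests)."""

    def __init__(self) -> None:
        super().__init__(name="echo")

    async def complete(self, prompt: str, model: Optional[str] = None,
                       temperature: float = 0.7, max_tokens: Optional[int] = None,
                       **kwargs: Any) -> CompletionResponse:
        prompt_words = prompt.split()
        # Treat each whitespace-separated word as one "token" for this sketch.
        words = prompt_words if max_tokens is None else prompt_words[:max_tokens]
        return CompletionResponse(
            content=" ".join(words),
            model=model or "echo-1",
            usage={
                "prompt_tokens": len(prompt_words),
                "completion_tokens": len(words),
                "total_tokens": len(prompt_words) + len(words),
            },
            metadata={"provider": self.name},
        )

    async def stream_complete(self, prompt: str, model: Optional[str] = None,
                              temperature: float = 0.7, max_tokens: Optional[int] = None,
                              **kwargs: Any) -> AsyncIterator[str]:
        # An async generator: calling it returns an iterator, no await needed.
        response = await self.complete(prompt, model, temperature, max_tokens)
        for word in response.content.split():
            yield word

    async def validate_credentials(self) -> bool:
        return True  # nothing to authenticate against


async def demo():
    provider = EchoProvider()
    response = await provider.complete("hello from the gateway", max_tokens=2)
    chunks = [chunk async for chunk in provider.stream_complete("a b c")]
    return response, chunks, await provider.validate_credentials()


response, chunks, ok = asyncio.run(demo())
print(response.content, response.usage["total_tokens"], chunks, ok)
```

Because every provider returns the same `CompletionResponse` shape, calling code can swap `EchoProvider` for a real provider without changes.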
Initialize the provider.
Args:
name: Unique identifier for this provider
Generate a completion.
Args:
prompt: The prompt to complete
model: Model name
temperature: Sampling temperature (the valid range is provider-specific, e.g. 0-2 for OpenAI models)
max_tokens: Maximum tokens to generate
**kwargs: Provider-specific parameters
Returns: CompletionResponse object
Stream a completion.
Args:
prompt: The prompt to complete
model: Model name
temperature: Sampling temperature (the valid range is provider-specific, e.g. 0-2 for OpenAI models)
max_tokens: Maximum tokens to generate
**kwargs: Provider-specific parameters
Yields: Chunks of the completion
Validate that the provider credentials are correct.
Returns: True if credentials are valid, False otherwise
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class CompletionResponse:
    """Standard completion response."""
    content: str
    model: str
    usage: Dict[str, int]
    metadata: Dict[str, Any]
Standard completion response.
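Both providers below fill `usage` with the keys `prompt_tokens`, `completion_tokens` and `total_tokens`, which makes usage easy to aggregate across calls. The sketch below re-declares the dataclass locally and uses a hypothetical `sum_usage` helper (not part of the package) built on `collections.Counter`; the model names are illustrative.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, Iterable


@dataclass
class CompletionResponse:  # local stand-in mirroring the dataclass above
    content: str
    model: str
    usage: Dict[str, int]
    metadata: Dict[str, Any]


def sum_usage(responses: Iterable[CompletionResponse]) -> Dict[str, int]:
    """Hypothetical helper: add up token counts across several responses."""
    totals: Counter = Counter()
    for response in responses:
        totals.update(response.usage)  # Counter.update adds counts key by key
    return dict(totals)


responses = [
    CompletionResponse("Hi", "my-deployment",
                       {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6}, {}),
    CompletionResponse("Bye", "llama3",
                       {"prompt_tokens": 7, "completion_tokens": 2, "total_tokens": 9}, {}),
]
print(sum_usage(responses))
```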
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ModelInfo:
    """Information about a model available from a provider."""
    name: str
    provider: str
    capabilities: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
Information about a model available from a provider.
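`ModelInfo` uses `field(default_factory=dict)` rather than a shared `{}` default, so each instance gets its own dictionaries. The short sketch below (a local copy of the dataclass; the `streaming` capability key and the metadata value are illustrative, since the source defines no schema) shows that mutating one instance does not affect another.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ModelInfo:  # local stand-in mirroring the dataclass above
    name: str
    provider: str
    capabilities: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)


llama = ModelInfo(name="llama3", provider="ollama", capabilities={"streaming": True})
bare = ModelInfo(name="mistral", provider="ollama")
other = ModelInfo(name="phi3", provider="ollama")

# default_factory builds a fresh dict per instance, so this change stays local.
bare.metadata["note"] = "pulled locally"
print(bare.metadata, other.metadata)
```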
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Optional, Union

from openai import AsyncAzureOpenAI

from gmf_forge_ai_shared_core.llm_gateway.providers.base_provider import (
    BaseProvider,
    CompletionResponse,
)


class AzureOpenAIProvider(BaseProvider):
    """
    Azure OpenAI provider implementation.

    This is the PRIMARY provider for the platform, offering enterprise-grade
    security, compliance, and integration with Azure services.
    """

    def __init__(
        self,
        endpoint: str,
        deployment_name: str,
        api_key: Optional[str] = None,
        api_version: str = "2024-02-15-preview",
        token_provider: Optional[Callable[[], str]] = None,
        http_client: Optional[Any] = None,
        ssl_cert_path: Optional[Union[str, Path]] = None,
    ):
        """
        Initialize Azure OpenAI provider.

        Exactly one of ``api_key`` or ``token_provider`` must be supplied.

        Args:
            endpoint: Azure OpenAI endpoint URL
            deployment_name: Deployment name for the model
            api_key: Azure OpenAI API key (API key authentication)
            api_version: API version to use
            token_provider: A zero-argument callable that returns a bearer token
                string. Build one with ``get_bearer_token_provider`` from
                ``azure-identity`` and any credential of your choice, e.g.::

                    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
                    token_provider = get_bearer_token_provider(
                        DefaultAzureCredential(),
                        "https://cognitiveservices.azure.com/.default",
                    )

            http_client: Optional httpx.AsyncClient for custom SSL/proxy settings
            ssl_cert_path: Optional path to SSL certificate file (alternative to http_client)
        """
        super().__init__(name="azure-openai")

        # Enforce the documented "exactly one" rule instead of silently
        # preferring token_provider when both are supplied.
        if bool(api_key) == bool(token_provider):
            raise ValueError(
                "Exactly one of 'api_key' or 'token_provider' must be provided to AzureOpenAIProvider."
            )

        self.endpoint = endpoint
        self.deployment_name = deployment_name

        # Handle SSL certificate if provided
        if ssl_cert_path and not http_client:
            try:
                import httpx
            except ImportError as e:
                raise ImportError(
                    "httpx is required for SSL certificate configuration. "
                    "Install with: pip install httpx"
                ) from e
            http_client = httpx.AsyncClient(verify=str(ssl_cert_path), timeout=30.0)

        if token_provider:
            self.client = AsyncAzureOpenAI(
                azure_endpoint=endpoint,
                azure_ad_token_provider=token_provider,
                api_version=api_version,
                http_client=http_client,
            )
        else:
            self.client = AsyncAzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
                http_client=http_client,
            )

    async def complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> CompletionResponse:
        """Generate a completion using Azure OpenAI."""

        # Use deployment name if model not specified
        deployment = model or self.deployment_name

        response = await self.client.chat.completions.create(
            model=deployment,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        choice = response.choices[0]
        return CompletionResponse(
            # message.content is None when the model replies with a tool call
            content=choice.message.content or "",
            model=deployment,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            metadata={
                "finish_reason": choice.finish_reason,
                "id": response.id,
            }
        )

    async def stream_complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> AsyncIterator[str]:
        """Stream a completion using Azure OpenAI."""

        deployment = model or self.deployment_name

        stream = await self.client.chat.completions.create(
            model=deployment,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
            **kwargs
        )

        async for chunk in stream:
            # Azure can emit chunks with an empty `choices` list (e.g. prompt
            # filter results), so check before indexing.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def validate_credentials(self) -> bool:
        """Validate Azure OpenAI credentials."""
        try:
            # Try a minimal completion to validate credentials
            await self.complete(prompt="test", max_tokens=1)
            return True
        except Exception:
            return False
Azure OpenAI provider implementation.
This is the PRIMARY provider for the platform, offering enterprise-grade security, compliance, and integration with Azure services.
Initialize Azure OpenAI provider.
Exactly one of api_key or token_provider must be supplied.
Args:
endpoint: Azure OpenAI endpoint URL
deployment_name: Deployment name for the model
api_key: Azure OpenAI API key (API key authentication)
api_version: API version to use
token_provider: A zero-argument callable that returns a bearer token
string. Build one with get_bearer_token_provider from
azure-identity and any credential of your choice, e.g.:
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(
DefaultAzureCredential(),
"https://cognitiveservices.azure.com/.default",
)
http_client: Optional httpx.AsyncClient for custom SSL/proxy settings
ssl_cert_path: Optional path to SSL certificate file (alternative to http_client)
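The two authentication modes map onto two different `AsyncAzureOpenAI` keyword arguments (`api_key` or `azure_ad_token_provider`). The sketch below isolates that choice in a hypothetical `resolve_auth_kwargs` helper that enforces the "exactly one" rule stated above; `static_token` is a hypothetical stand-in for a `get_bearer_token_provider(...)` callable so the example runs offline without `azure-identity` or `openai`.

```python
from typing import Any, Callable, Dict, Optional


def resolve_auth_kwargs(
    api_key: Optional[str] = None,
    token_provider: Optional[Callable[[], str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: pick the AsyncAzureOpenAI auth keyword argument."""
    # bool(...) == bool(...) is True when both or neither are supplied.
    if bool(api_key) == bool(token_provider):
        raise ValueError("Supply exactly one of 'api_key' or 'token_provider'.")
    if token_provider:
        # Entra ID bearer tokens are fetched lazily by calling the provider.
        return {"azure_ad_token_provider": token_provider}
    return {"api_key": api_key}


def static_token() -> str:
    # Offline stand-in for a real bearer token provider.
    return "fake-bearer-token"


print(resolve_auth_kwargs(api_key="k"))
print(resolve_auth_kwargs(token_provider=static_token)["azure_ad_token_provider"]())
```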
Generate a completion using Azure OpenAI.
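The mapping from a chat-completions response to `CompletionResponse` can be checked offline by factoring it into a function and feeding it a fake object. In this sketch, `to_completion_response` is a hypothetical helper, `CompletionResponse` is a local stand-in, and `SimpleNamespace` imitates only the attributes of the `openai` response that `complete` actually reads.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Dict


@dataclass
class CompletionResponse:  # local stand-in mirroring the documented dataclass
    content: str
    model: str
    usage: Dict[str, int]
    metadata: Dict[str, Any]


def to_completion_response(response: Any, deployment: str) -> CompletionResponse:
    """Hypothetical helper applying the same field mapping as complete()."""
    choice = response.choices[0]
    return CompletionResponse(
        # message.content can be None (e.g. for tool calls); fall back to "".
        content=choice.message.content or "",
        model=deployment,
        usage={
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
        },
        metadata={"finish_reason": choice.finish_reason, "id": response.id},
    )


# Fake object shaped like a ChatCompletion (only the attributes used above).
fake = SimpleNamespace(
    id="chatcmpl-123",
    choices=[SimpleNamespace(message=SimpleNamespace(content="Retrieval-augmented generation."),
                             finish_reason="stop")],
    usage=SimpleNamespace(prompt_tokens=9, completion_tokens=4, total_tokens=13),
)
result = to_completion_response(fake, deployment="my-deployment")
print(result.content, result.usage["total_tokens"], result.metadata["finish_reason"])
```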
Stream a completion using Azure OpenAI.
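When streaming, Azure OpenAI may send chunks whose `choices` list is empty, and deltas whose `content` is `None`; both must be skipped before yielding text. The sketch below reproduces that filtering in a standalone async generator (`iter_text` is a hypothetical name) and drives it with a fake stream built from `SimpleNamespace` objects, so no network access or `openai` install is needed.

```python
import asyncio
from types import SimpleNamespace
from typing import Any, AsyncIterator


async def iter_text(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    """Yield only the non-empty text deltas from a chat-completions stream."""
    async for chunk in stream:
        # Guard against empty `choices` before indexing, then skip empty deltas.
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content


def _chunk(content):
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=content))])


async def fake_stream():
    # Offline stand-in for the object returned by create(..., stream=True).
    yield SimpleNamespace(choices=[])  # e.g. a filter-results-only chunk
    yield _chunk("Hel")
    yield _chunk(None)  # e.g. a role-only or final chunk
    yield _chunk("lo")


async def collect() -> list:
    return [text async for text in iter_text(fake_stream())]


print("".join(asyncio.run(collect())))
```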
Validate Azure OpenAI credentials.
import json
from typing import Any, AsyncIterator, Optional

import httpx

from gmf_forge_ai_shared_core.llm_gateway.providers.base_provider import (
    BaseProvider,
    CompletionResponse,
)


class OllamaProvider(BaseProvider):
    """
    Ollama provider for local LLM development and testing.

    Ollama runs open-source models locally (llama3, mistral, phi3, etc.).
    Perfect for development, testing, and environments without cloud access.

    Install Ollama: https://ollama.ai/download

    Example:
        >>> provider = OllamaProvider(
        ...     base_url="http://localhost:11434",
        ...     model="llama3"
        ... )
        >>> response = await provider.complete("What is RAG?")
    """

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "llama3",
        timeout: float = 30.0
    ):
        """
        Initialize Ollama provider.

        Args:
            base_url: Ollama server URL (default: http://localhost:11434)
            model: Model name (e.g., "llama3", "mistral", "phi3")
            timeout: Request timeout in seconds
        """
        super().__init__(name="ollama")
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout = timeout
        self._client = httpx.AsyncClient(timeout=self.timeout)

    async def complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> CompletionResponse:
        """
        Generate a completion using Ollama.

        Args:
            prompt: The prompt to complete
            model: Model name (overrides instance default)
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum tokens to generate (sent as num_predict)
            **kwargs: Additional Ollama model options, merged into "options"

        Returns:
            CompletionResponse object

        Raises:
            ConnectionError: If the Ollama server cannot be reached
            ValueError: If the model has not been pulled (HTTP 404)
            httpx.HTTPStatusError: If Ollama returns any other HTTP error
        """
        model_name = model or self.model

        # Build request payload
        payload = {
            "model": model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
            }
        }

        # Add max_tokens if specified (Ollama calls it num_predict)
        if max_tokens is not None:
            payload["options"]["num_predict"] = max_tokens

        # Add any additional options
        if kwargs:
            payload["options"].update(kwargs)

        try:
            # Call Ollama API
            response = await self._client.post(
                f"{self.base_url}/api/generate",
                json=payload
            )
            response.raise_for_status()

            result = response.json()

            # Extract response content
            content = result.get("response", "")

            # Token counts as reported by Ollama (default to 0 if omitted)
            prompt_tokens = result.get("prompt_eval_count", 0)
            completion_tokens = result.get("eval_count", 0)

            return CompletionResponse(
                content=content,
                model=model_name,
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                },
                metadata={
                    "provider": "ollama",
                    # Ollama reports durations in nanoseconds
                    "total_duration": result.get("total_duration"),
                    "load_duration": result.get("load_duration"),
                    "eval_duration": result.get("eval_duration")
                }
            )

        except httpx.ConnectError as e:
            raise ConnectionError(
                f"Cannot connect to Ollama at {self.base_url}. "
                "Is Ollama running? Install and run: ollama serve"
            ) from e
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ValueError(
                    f"Model '{model_name}' not found. "
                    f"Pull it first with: ollama pull {model_name}"
                ) from e
            raise

    async def stream_complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any
    ) -> AsyncIterator[str]:
        """
        Stream a completion using Ollama.

        Args:
            prompt: The prompt to complete
            model: Model name (overrides instance default)
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum tokens to generate (sent as num_predict)
            **kwargs: Additional Ollama model options, merged into "options"

        Yields:
            Chunks of the completion

        Raises:
            ConnectionError: If the Ollama server cannot be reached
            ValueError: If the model has not been pulled (HTTP 404)
            httpx.HTTPStatusError: If Ollama returns any other HTTP error
        """
        model_name = model or self.model

        # Build request payload
        payload = {
            "model": model_name,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": temperature,
            }
        }

        # Add max_tokens if specified
        if max_tokens is not None:
            payload["options"]["num_predict"] = max_tokens

        # Add any additional options
        if kwargs:
            payload["options"].update(kwargs)

        try:
            # Stream newline-delimited JSON objects from the Ollama API
            async with self._client.stream(
                "POST",
                f"{self.base_url}/api/generate",
                json=payload
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if line.strip():
                        chunk = json.loads(line)
                        if "response" in chunk:
                            yield chunk["response"]

                        # Check if done
                        if chunk.get("done", False):
                            break

        except httpx.ConnectError as e:
            raise ConnectionError(
                f"Cannot connect to Ollama at {self.base_url}. "
                "Is Ollama running? Install and run: ollama serve"
            ) from e
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ValueError(
                    f"Model '{model_name}' not found. "
                    f"Pull it first with: ollama pull {model_name}"
                ) from e
            raise

    async def validate_credentials(self) -> bool:
        """
        Validate connection to Ollama server.

        Returns:
            True if GET /api/tags returns HTTP 200, False otherwise
        """
        try:
            response = await self._client.get(f"{self.base_url}/api/tags")
            return response.status_code == 200
        except Exception:
            return False

    async def close(self):
        """Close the HTTP client."""
        await self._client.aclose()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
Ollama provider for local LLM development and testing.
Ollama runs open-source models locally (llama3, mistral, phi3, etc.). Perfect for development, testing, and environments without cloud access.
Install Ollama: https://ollama.ai/download
Example:
>>> provider = OllamaProvider(
...     base_url="http://localhost:11434",
...     model="llama3"
... )
>>> response = await provider.complete("What is RAG?")
Initialize Ollama provider.
Args:
base_url: Ollama server URL (default: http://localhost:11434)
model: Model name (e.g., "llama3", "mistral", "phi3")
timeout: Request timeout in seconds
Generate a completion using Ollama.
Args:
prompt: The prompt to complete
model: Model name (overrides instance default)
temperature: Sampling temperature (0-2)
max_tokens: Maximum tokens to generate (sent to Ollama as num_predict)
**kwargs: Additional Ollama model options, merged into the request's "options" object
Returns: CompletionResponse object
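Raises:
ConnectionError: If the Ollama server cannot be reached
ValueError: If the requested model has not been pulled (Ollama returns HTTP 404)
httpx.HTTPStatusError: If Ollama returns any other HTTP error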
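The request body sent to `/api/generate` can be reproduced with a small pure function, which is handy for unit-testing argument handling without a running server. `build_generate_payload` is a hypothetical helper that mirrors the payload logic in `complete` and `stream_complete`; `top_k` is used as an example of an extra Ollama option passed through `**kwargs`.

```python
from typing import Any, Dict, Optional


def build_generate_payload(
    prompt: str,
    model: str,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    stream: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper rebuilding the /api/generate body used by OllamaProvider."""
    options: Dict[str, Any] = {"temperature": temperature}
    if max_tokens is not None:
        options["num_predict"] = max_tokens  # Ollama's name for a token limit
    options.update(kwargs)  # every extra keyword becomes a model option
    return {"model": model, "prompt": prompt, "stream": stream, "options": options}


payload = build_generate_payload("What is RAG?", "llama3", max_tokens=64, top_k=40)
print(payload)
```

Because all extra keywords land inside `options`, top-level request fields of the Ollama API (such as `system` or `keep_alive`) cannot be passed through `**kwargs` with the current implementation.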
Stream a completion using Ollama.
Args:
prompt: The prompt to complete
model: Model name (overrides instance default)
temperature: Sampling temperature (0-2)
max_tokens: Maximum tokens to generate (sent to Ollama as num_predict)
**kwargs: Additional Ollama model options, merged into the request's "options" object
Yields: Chunks of the completion
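Raises:
ConnectionError: If the Ollama server cannot be reached
ValueError: If the requested model has not been pulled (Ollama returns HTTP 404)
httpx.HTTPStatusError: If Ollama returns any other HTTP error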
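With `"stream": True`, Ollama returns one JSON object per line, and the last object has `"done": true`. The sketch below isolates the line-parsing loop from `stream_complete` into a hypothetical `iter_ollama_text` generator and feeds it fake lines in place of `response.aiter_lines()`. Raising on an `"error"` key is an added assumption about in-stream error objects; the original code does not do this.

```python
import asyncio
import json
from typing import AsyncIterator


async def iter_ollama_text(lines: AsyncIterator[str]) -> AsyncIterator[str]:
    """Turn Ollama's newline-delimited JSON stream into text chunks."""
    async for line in lines:
        if not line.strip():
            continue  # skip blank lines
        chunk = json.loads(line)
        if "error" in chunk:
            # Assumption: surface an in-stream error object as an exception.
            raise RuntimeError(chunk["error"])
        if chunk.get("response"):
            yield chunk["response"]
        if chunk.get("done", False):
            break  # the final object carries statistics, not more text


async def fake_lines():
    # Offline stand-in for httpx's response.aiter_lines().
    for line in ['{"response": "Hel", "done": false}', "",
                 '{"response": "lo", "done": false}',
                 '{"response": "", "done": true, "eval_count": 2}']:
        yield line


async def collect() -> list:
    return [text async for text in iter_ollama_text(fake_lines())]


print("".join(asyncio.run(collect())))
```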
Validate connection to Ollama server.
Returns: True if the server answers GET /api/tags with HTTP 200, False otherwise
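Since every provider exposes `validate_credentials()` as a coroutine returning `bool`, several providers can be probed concurrently at startup. The sketch below is a hypothetical `check_providers` helper (not part of the package) that wraps each probe in `asyncio.wait_for` so one hung endpoint cannot stall the rest; `FakeProvider` stands in for real providers so the example runs offline.

```python
import asyncio
from typing import Dict, Iterable, Protocol


class SupportsValidation(Protocol):
    name: str

    async def validate_credentials(self) -> bool: ...


async def check_providers(providers: Iterable[SupportsValidation],
                          timeout: float = 5.0) -> Dict[str, bool]:
    """Hypothetical helper: probe several providers at once with a per-call timeout."""
    items = list(providers)

    async def probe(provider: SupportsValidation) -> bool:
        try:
            return await asyncio.wait_for(provider.validate_credentials(), timeout)
        except asyncio.TimeoutError:
            return False  # treat a hung endpoint like an unreachable one

    results = await asyncio.gather(*(probe(p) for p in items))
    return {p.name: ok for p, ok in zip(items, results)}


class FakeProvider:
    """Offline stand-in exposing the same name / validate_credentials surface."""

    def __init__(self, name: str, ok: bool, delay: float = 0.0):
        self.name, self._ok, self._delay = name, ok, delay

    async def validate_credentials(self) -> bool:
        await asyncio.sleep(self._delay)
        return self._ok


status = asyncio.run(check_providers(
    [FakeProvider("ollama", True), FakeProvider("azure-openai", False),
     FakeProvider("slow", True, delay=1.0)],
    timeout=0.1,
))
print(status)
```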