gmf_forge_ai_orchestration
GMF Forge AI Orchestration — Agents, workflows, and multi-agent systems.
1"""GMF Forge AI Orchestration — Agents, workflows, and multi-agent systems.""" 2 3from gmf_forge_ai_orchestration.version import __version__ 4 5# State 6from gmf_forge_ai_orchestration.state import ( 7 BaseStateStore, 8 InMemoryStateStore, 9 RedisStateStore, 10 StateStoreFactory, 11 ConversationState, 12 ConversationMessage, 13 Checkpoint, 14 CheckpointManager, 15 Blackboard, 16 BlackboardEntry, 17) 18 19# Behaviors 20from gmf_forge_ai_orchestration.behaviors import ( 21 BaseBehavior, 22 BehaviorContext, 23 RetryBehavior, 24 GuardrailBehavior, 25 GuardrailRule, 26 GuardrailViolationError, 27 HumanInLoopBehavior, 28 HumanApprovalRequired, 29 PendingApproval, 30 CircuitBreakerBehavior, 31 CircuitState, 32 CircuitOpenError, 33 RateLimitBehavior, 34 RateLimitExceededError, 35 AuditBehavior, 36 AgentDiscoveryBehavior, 37) 38 39# Agents 40from gmf_forge_ai_orchestration.agents import ( 41 AgentStep, 42 AgentResult, 43 BaseAgent, 44 ReActAgent, 45 PlanExecuteAgent, 46 ReflexionAgent, 47 ChainOfThoughtAgent, 48) 49 50# Routing 51from gmf_forge_ai_orchestration.routing import ( 52 BaseRouter, 53 RoutingRequest, 54 RoutingDecision, 55 LLMRouter, 56 SemanticRouter, 57 RuleBasedRouter, 58 LoadBalancingRouter, 59) 60 61# Workflows 62from gmf_forge_ai_orchestration.workflows import ( 63 BaseWorkflow, 64 WorkflowNode, 65 WorkflowEdge, 66 WorkflowResult, 67 DAGWorkflow, 68 StateMachineWorkflow, 69 EventDrivenWorkflow, 70 WorkflowEvent, 71) 72 73# Multi-agent 74from gmf_forge_ai_orchestration.multi_agent import ( 75 BaseOrchestrator, 76 OrchestratorResult, 77 SupervisorOrchestrator, 78 PipelineOrchestrator, 79 DebateOrchestrator, 80 SwarmOrchestrator, 81) 82 83__all__ = [ 84 "__version__", 85 # State 86 "BaseStateStore", "InMemoryStateStore", "RedisStateStore", "StateStoreFactory", 87 "ConversationState", "ConversationMessage", "Checkpoint", "CheckpointManager", 88 "Blackboard", "BlackboardEntry", 89 # Behaviors 90 "BaseBehavior", "BehaviorContext", "RetryBehavior", 91 
"GuardrailBehavior", "GuardrailRule", "GuardrailViolationError", 92 "HumanInLoopBehavior", "HumanApprovalRequired", "PendingApproval", 93 "CircuitBreakerBehavior", "CircuitState", "CircuitOpenError", 94 "RateLimitBehavior", "RateLimitExceededError", "AuditBehavior", 95 "AgentDiscoveryBehavior", 96 # Agents 97 "AgentStep", "AgentResult", "BaseAgent", 98 "ReActAgent", "PlanExecuteAgent", "ReflexionAgent", "ChainOfThoughtAgent", 99 # Routing 100 "BaseRouter", "RoutingRequest", "RoutingDecision", 101 "LLMRouter", "SemanticRouter", "RuleBasedRouter", "LoadBalancingRouter", 102 # Workflows 103 "BaseWorkflow", "WorkflowNode", "WorkflowEdge", "WorkflowResult", 104 "DAGWorkflow", "StateMachineWorkflow", "EventDrivenWorkflow", "WorkflowEvent", 105 # Multi-agent 106 "BaseOrchestrator", "OrchestratorResult", 107 "SupervisorOrchestrator", "PipelineOrchestrator", 108 "DebateOrchestrator", "SwarmOrchestrator", 109]
class BaseStateStore(ABC):
    """
    Asynchronous key-value store contract.

    InMemoryStateStore and RedisStateStore both implement these five
    coroutines, so code written against this interface can use either one.
    """

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """
        Fetch the value stored under ``key``.

        Returns:
            The stored value, or ``None`` when the key is missing or expired.
        """

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """
        Store ``value`` under ``key``, replacing any previous value.

        Args:
            key: Storage key.
            value: JSON-serialisable value.
            ttl: Lifetime in seconds; ``None`` keeps the entry until it is
                deleted.
        """

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Remove ``key``; deleting a key that does not exist is a no-op."""

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Report whether ``key`` is present and has not expired."""

    @abstractmethod
    async def clear(self) -> None:
        """Remove every key managed by this store instance."""
Abstract key-value state store.
Both InMemoryStateStore and RedisStateStore implement this interface, making them interchangeable in all orchestration components.
@abstractmethod
async def get(self, key: str) -> Optional[Any]:
    """
    Fetch the value stored under ``key``.

    Returns:
        The stored value, or ``None`` when the key is missing or expired.
    """
Retrieve a value by key. Returns None if not found or expired.
@abstractmethod
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
    """
    Store ``value`` under ``key``, replacing any previous value.

    Args:
        key: Storage key.
        value: JSON-serialisable value.
        ttl: Lifetime in seconds; ``None`` keeps the entry until it is deleted.
    """
Store a value.
Args: key: Storage key. value: JSON-serialisable value. ttl: Time-to-live in seconds. None means no expiry.
@abstractmethod
async def delete(self, key: str) -> None:
    """Remove ``key``; deleting a key that does not exist is a no-op."""
Remove a key. No-op if the key does not exist.
class InMemoryStateStore(BaseStateStore):
    """
    Dictionary-backed key-value store with optional per-key TTL.

    Every public coroutine runs under one ``asyncio.Lock``. Expired entries
    are evicted lazily: a key is only dropped when it is next looked up via
    :meth:`get` or :meth:`exists`.

    Intended for single-process deployments and testing; for multi-process
    or persistent state, use RedisStateStore instead.
    """

    def __init__(self) -> None:
        # key -> (value, monotonic deadline, or None when the entry never expires)
        self._store: Dict[str, Tuple[Any, Optional[float]]] = {}
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_expired(self, expires_at: Optional[float]) -> bool:
        """Return True when ``expires_at`` is set and already in the past."""
        return expires_at is not None and time.monotonic() > expires_at

    def _clean_key(self, key: str) -> None:
        """Drop ``key`` from the store if its deadline has passed."""
        record = self._store.get(key)
        if not record:
            return
        if self._is_expired(record[1]):
            del self._store[key]

    # ------------------------------------------------------------------
    # BaseStateStore implementation
    # ------------------------------------------------------------------

    async def get(self, key: str) -> Optional[Any]:
        """Return the live value for ``key``, or None if missing or expired."""
        async with self._lock:
            self._clean_key(key)
            record = self._store.get(key)
            if record:
                return record[0]
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` under ``key``; ``ttl`` is a lifetime in seconds."""
        # The deadline uses the monotonic clock, matching _is_expired().
        if ttl is None:
            deadline = None
        else:
            deadline = time.monotonic() + ttl
        async with self._lock:
            self._store[key] = (value, deadline)

    async def delete(self, key: str) -> None:
        """Remove ``key``; a missing key is ignored."""
        async with self._lock:
            self._store.pop(key, None)

    async def exists(self, key: str) -> bool:
        """Return True when ``key`` is stored and has not expired."""
        async with self._lock:
            self._clean_key(key)
            return key in self._store

    async def clear(self) -> None:
        """Remove every entry from the store."""
        async with self._lock:
            self._store.clear()
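In-memory key-value store with optional TTL. Access is serialised with an asyncio.Lock, which coordinates coroutines on one event loop but does not make the store safe to share across threads.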
Suitable for single-process deployments and testing. For multi-process or persistent state, use RedisStateStore instead.
async def get(self, key: str) -> Optional[Any]:
    """
    Return the live value for ``key``.

    The key is first passed to ``_clean_key`` so an expired entry is evicted
    before the lookup; ``None`` is returned when nothing is stored.
    """
    async with self._lock:
        self._clean_key(key)
        record = self._store.get(key)
        if record:
            return record[0]
        return None
Retrieve a value by key. Returns None if not found or expired.
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
    """
    Store ``value`` under ``key``.

    Args:
        key: Storage key.
        value: Value to keep; it is stored as-is, without copying.
        ttl: Lifetime in seconds measured on the monotonic clock; ``None``
            means the entry never expires.
    """
    if ttl is None:
        deadline = None
    else:
        deadline = time.monotonic() + ttl
    async with self._lock:
        self._store[key] = (value, deadline)
Store a value.
Args: key: Storage key. value: JSON-serialisable value. ttl: Time-to-live in seconds. None means no expiry.
async def delete(self, key: str) -> None:
    """Remove ``key`` while holding the store lock; a missing key is ignored."""
    async with self._lock:
        self._store.pop(key, None)
Remove a key. No-op if the key does not exist.
class RedisStateStore(BaseStateStore):
    """
    Redis-backed key-value store.

    Uses ``redis.asyncio`` for non-blocking I/O. Values are JSON-serialised
    before storage and deserialised on retrieval.

    Args:
        url: Redis connection URL (e.g. ``"redis://localhost:6379"``).
        key_prefix: Prefix applied to every key to avoid collisions when
            sharing a Redis instance across services. Defaults to
            ``"gmf_forge_ai:"``.
        **redis_kwargs: Forwarded to ``redis.asyncio.from_url``.
            ``decode_responses`` defaults to ``True`` and may be overridden
            here.

    Raises:
        ImportError: If the ``redis`` package is not installed.

    Example::

        store = RedisStateStore(url="redis://localhost:6379")
        await store.set("session:abc", {"key": "value"}, ttl=3600)
        data = await store.get("session:abc")
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        key_prefix: str = "gmf_forge_ai:",
        **redis_kwargs: Any,
    ) -> None:
        # Imported lazily so the package stays usable without redis installed.
        try:
            import redis.asyncio as aioredis  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "redis package is required for RedisStateStore. "
                "Install it with: pip install redis>=5.0.0"
            ) from exc

        # Passing decode_responses both explicitly and through **redis_kwargs
        # would raise TypeError, so it is applied only as a default.
        redis_kwargs.setdefault("decode_responses", True)
        self._client = aioredis.from_url(url, **redis_kwargs)
        self._prefix = key_prefix

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _k(self, key: str) -> str:
        """Return ``key`` with this store's prefix prepended."""
        return f"{self._prefix}{key}"

    @staticmethod
    def _serialize(value: Any) -> str:
        """Encode ``value`` as JSON; non-JSON types fall back to ``str(value)``."""
        return json.dumps(value, default=str)

    @staticmethod
    def _deserialize(raw: str) -> Any:
        """Decode a JSON payload (``json.loads`` also accepts UTF-8 bytes)."""
        return json.loads(raw)

    # ------------------------------------------------------------------
    # BaseStateStore implementation
    # ------------------------------------------------------------------

    async def get(self, key: str) -> Optional[Any]:
        """Return the decoded value for ``key``, or None when Redis has no such key."""
        raw = await self._client.get(self._k(key))
        if raw is None:
            return None
        return self._deserialize(raw)

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` as JSON; with a ``ttl`` (seconds) the key is written via SETEX."""
        serialized = self._serialize(value)
        if ttl is not None:
            await self._client.setex(self._k(key), ttl, serialized)
        else:
            await self._client.set(self._k(key), serialized)

    async def delete(self, key: str) -> None:
        """Delete the prefixed key."""
        await self._client.delete(self._k(key))

    async def exists(self, key: str) -> bool:
        """Return True when the prefixed key exists in Redis."""
        result = await self._client.exists(self._k(key))
        return bool(result)

    async def clear(self) -> None:
        """Delete all keys matching this store's prefix."""
        # NOTE(review): the prefix is used verbatim in a glob pattern; a prefix
        # containing glob metacharacters would match more than intended.
        pattern = f"{self._prefix}*"
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
            if keys:
                await self._client.delete(*keys)
            # SCAN reports a zero cursor once the iteration is complete.
            if cursor == 0:
                break

    async def close(self) -> None:
        """Close the underlying Redis connection pool."""
        await self._client.aclose()
Redis-backed key-value store.
Uses redis.asyncio for non-blocking I/O. Values are JSON-serialised before storage and deserialised on retrieval.
Args:
url: Redis connection URL (e.g. "redis://localhost:6379").
key_prefix: Optional prefix applied to every key to avoid collisions
when sharing a Redis instance across services. Defaults to
"gmf_forge_ai:".
decode_responses: The Redis client is created with decode_responses=True by default, so stored values are read back as str.
Example::
store = RedisStateStore(url="redis://localhost:6379")
await store.set("session:abc", {"key": "value"}, ttl=3600)
data = await store.get("session:abc")
def __init__(
    self,
    url: str = "redis://localhost:6379",
    key_prefix: str = "gmf_forge_ai:",
    **redis_kwargs: Any,
) -> None:
    """
    Create the Redis client used by this store.

    Args:
        url: Redis connection URL.
        key_prefix: Prefix prepended to every key written by this store.
        **redis_kwargs: Forwarded to ``redis.asyncio.from_url``.
            ``decode_responses`` defaults to ``True`` and may be overridden.

    Raises:
        ImportError: If the ``redis`` package is not installed.
    """
    # Imported lazily so the package stays usable without redis installed.
    try:
        import redis.asyncio as aioredis  # type: ignore[import]
    except ImportError as exc:
        raise ImportError(
            "redis package is required for RedisStateStore. "
            "Install it with: pip install redis>=5.0.0"
        ) from exc

    # Passing decode_responses both explicitly and through **redis_kwargs
    # would raise TypeError, so it is applied only as a default.
    redis_kwargs.setdefault("decode_responses", True)
    self._client = aioredis.from_url(url, **redis_kwargs)
    self._prefix = key_prefix
async def get(self, key: str) -> Optional[Any]:
    """
    Read and decode the value stored under the prefixed ``key``.

    Returns:
        The deserialised value, or ``None`` when the client returns no data.
    """
    payload = await self._client.get(self._k(key))
    return None if payload is None else self._deserialize(payload)
Retrieve a value by key. Returns None if not found or expired.
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
    """
    Serialise ``value`` and write it under the prefixed ``key``.

    Args:
        key: Storage key (the store prefix is added automatically).
        value: JSON-serialisable value.
        ttl: Lifetime in seconds; when given the write goes through
            ``setex``, otherwise through a plain ``set``.
    """
    payload = self._serialize(value)
    if ttl is None:
        await self._client.set(self._k(key), payload)
        return
    await self._client.setex(self._k(key), ttl, payload)
Store a value.
Args: key: Storage key. value: JSON-serialisable value. ttl: Time-to-live in seconds. None means no expiry.
async def exists(self, key: str) -> bool:
    """Return True when the client reports the prefixed ``key`` as present."""
    return bool(await self._client.exists(self._k(key)))
Return True if the key exists and has not expired.
async def clear(self) -> None:
    """Delete all keys matching this store's prefix."""
    match_all = f"{self._prefix}*"
    position = 0
    finished = False
    while not finished:
        # Each SCAN call returns the next cursor plus one batch of matches.
        position, batch = await self._client.scan(position, match=match_all, count=100)
        if batch:
            await self._client.delete(*batch)
        # A zero cursor marks the end of the iteration.
        finished = position == 0
Delete all keys matching this store's prefix.
class StateStoreFactory:
    """
    Builds state store instances from a backend name.

    Example::

        # In-memory (default, no extra deps)
        store = StateStoreFactory.create("memory")

        # Redis
        store = StateStoreFactory.create("redis", url="redis://localhost:6379")
    """

    @staticmethod
    def create(
        backend: Literal["memory", "redis"] = "memory",
        **kwargs: Any,
    ) -> BaseStateStore:
        """
        Build the store selected by ``backend``.

        Args:
            backend: ``"memory"`` or ``"redis"``.
            **kwargs: Passed to ``RedisStateStore`` (``url``, ``key_prefix``
                and any extra ``redis.asyncio.from_url`` arguments). The
                in-memory store is constructed without arguments, so kwargs
                are not used for ``"memory"``.

        Returns:
            A :class:`BaseStateStore` instance.

        Raises:
            ValueError: If an unknown backend name is provided.
        """
        # Backend modules are imported on demand so the Redis dependency is
        # only needed when the Redis backend is actually requested.
        if backend == "redis":
            from gmf_forge_ai_orchestration.state.redis_store import RedisStateStore
            return RedisStateStore(**kwargs)

        if backend != "memory":
            raise ValueError(
                f"Unknown state store backend: {backend!r}. "
                "Supported values: 'memory', 'redis'."
            )

        from gmf_forge_ai_orchestration.state.memory_store import InMemoryStateStore
        return InMemoryStateStore()
Creates configured state store instances.
Example::
# In-memory (default, no extra deps)
store = StateStoreFactory.create("memory")
# Redis
store = StateStoreFactory.create("redis", url="redis://localhost:6379")
@staticmethod
def create(
    backend: Literal["memory", "redis"] = "memory",
    **kwargs: Any,
) -> BaseStateStore:
    """
    Build the store selected by ``backend``.

    Args:
        backend: ``"memory"`` or ``"redis"``.
        **kwargs: Passed to ``RedisStateStore`` (``url``, ``key_prefix`` and
            any extra ``redis.asyncio.from_url`` arguments). The in-memory
            store is constructed without arguments, so kwargs are not used
            for ``"memory"``.

    Returns:
        A :class:`BaseStateStore` instance.

    Raises:
        ValueError: If an unknown backend name is provided.
    """
    # Backend modules are imported on demand so the Redis dependency is only
    # needed when the Redis backend is actually requested.
    if backend == "redis":
        from gmf_forge_ai_orchestration.state.redis_store import RedisStateStore
        return RedisStateStore(**kwargs)

    if backend != "memory":
        raise ValueError(
            f"Unknown state store backend: {backend!r}. "
            "Supported values: 'memory', 'redis'."
        )

    from gmf_forge_ai_orchestration.state.memory_store import InMemoryStateStore
    return InMemoryStateStore()
Instantiate a state store.
Args:
backend: "memory" or "redis".
**kwargs: Forwarded to the store constructor.
For Redis: url, key_prefix, and any
additional kwargs accepted by redis.asyncio.from_url.
Returns:
A BaseStateStore instance.
Raises: ValueError: If an unknown backend name is provided.
@dataclass
class ConversationState:
    """Session-level container: ordered messages, agent metadata and timestamps."""

    # Identifier of the conversation session.
    session_id: str
    # Messages in the order they were added.
    messages: List[ConversationMessage] = field(default_factory=list)
    # Free-form data attached to the session.
    agent_metadata: Dict[str, Any] = field(default_factory=dict)
    # Both timestamps default to the current UTC time (timezone-aware).
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))

    def add_message(self, role: str, content: str, **metadata: Any) -> None:
        """Append a new message and bump ``updated_at`` to the current UTC time."""
        message = ConversationMessage(role=role, content=content, metadata=metadata)
        self.messages.append(message)
        self.updated_at = datetime.now(tz=timezone.utc)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form with every datetime rendered as an ISO-8601 string."""
        serialised_messages = []
        for message in self.messages:
            serialised_messages.append(
                {
                    "role": message.role,
                    "content": message.content,
                    "timestamp": message.timestamp.isoformat(),
                    "metadata": message.metadata,
                }
            )
        return {
            "session_id": self.session_id,
            "messages": serialised_messages,
            "agent_metadata": self.agent_metadata,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationState":
        """
        Rebuild a state object from the mapping produced by :meth:`to_dict`.

        ``messages``, ``agent_metadata`` and per-message ``metadata`` are
        optional; the remaining keys are required.
        """
        restored = []
        for item in data.get("messages", []):
            restored.append(
                ConversationMessage(
                    role=item["role"],
                    content=item["content"],
                    timestamp=datetime.fromisoformat(item["timestamp"]),
                    metadata=item.get("metadata", {}),
                )
            )
        return cls(
            session_id=data["session_id"],
            messages=restored,
            agent_metadata=data.get("agent_metadata", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
        )
Full state of an agent conversation session.
def add_message(self, role: str, content: str, **metadata: Any) -> None:
    """
    Append a new message and bump ``updated_at`` to the current UTC time.

    Args:
        role: Role label stored on the message.
        content: Message text.
        **metadata: Extra keyword arguments, stored as the message metadata dict.
    """
    message = ConversationMessage(role=role, content=content, metadata=metadata)
    self.messages.append(message)
    self.updated_at = datetime.now(tz=timezone.utc)
Append a message and update the timestamp.
def to_dict(self) -> Dict[str, Any]:
    """
    Return a plain-dict form of the conversation.

    Datetimes are rendered with ``isoformat()``; the metadata dicts are
    included by reference, not copied.
    """
    serialised_messages = []
    for message in self.messages:
        serialised_messages.append(
            {
                "role": message.role,
                "content": message.content,
                "timestamp": message.timestamp.isoformat(),
                "metadata": message.metadata,
            }
        )
    return {
        "session_id": self.session_id,
        "messages": serialised_messages,
        "agent_metadata": self.agent_metadata,
        "created_at": self.created_at.isoformat(),
        "updated_at": self.updated_at.isoformat(),
    }
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConversationState":
    """
    Rebuild a state object from a mapping shaped like ``to_dict()`` output.

    ``messages``, ``agent_metadata`` and per-message ``metadata`` are
    optional; ``session_id``, ``created_at`` and ``updated_at`` are required,
    and timestamps are parsed with ``datetime.fromisoformat``.
    """
    restored = []
    for item in data.get("messages", []):
        restored.append(
            ConversationMessage(
                role=item["role"],
                content=item["content"],
                timestamp=datetime.fromisoformat(item["timestamp"]),
                metadata=item.get("metadata", {}),
            )
        )
    return cls(
        session_id=data["session_id"],
        messages=restored,
        agent_metadata=data.get("agent_metadata", {}),
        created_at=datetime.fromisoformat(data["created_at"]),
        updated_at=datetime.fromisoformat(data["updated_at"]),
    )
@dataclass
class ConversationMessage:
    """One message of a conversation, stamped with its creation time."""

    # Role label of the speaker; this class does not restrict its values.
    role: str
    # Message text.
    content: str
    # Creation time; defaults to the current UTC time (timezone-aware).
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    # Free-form extra data; every instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
A single message in a conversation.
@dataclass
class Checkpoint:
    """Snapshot of an agent's execution state taken at one point in time."""

    # Unique identifier of this snapshot.
    checkpoint_id: str
    # Identifier of the agent the snapshot belongs to.
    agent_id: str
    # The captured state payload.
    state: Dict[str, Any]
    # Creation time; defaults to the current UTC time (timezone-aware).
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    # Free-form extra data; every instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form with the timestamp as an ISO-8601 string."""
        payload: Dict[str, Any] = {}
        payload["checkpoint_id"] = self.checkpoint_id
        payload["agent_id"] = self.agent_id
        payload["state"] = self.state
        payload["timestamp"] = self.timestamp.isoformat()
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """Rebuild a checkpoint from ``to_dict()`` output; ``metadata`` is optional."""
        taken_at = datetime.fromisoformat(data["timestamp"])
        extra = data.get("metadata", {})
        return cls(
            checkpoint_id=data["checkpoint_id"],
            agent_id=data["agent_id"],
            state=data["state"],
            timestamp=taken_at,
            metadata=extra,
        )
A point-in-time snapshot of an agent's execution state.
class CheckpointManager:
    """
    Saves and restores agent execution state checkpoints.

    Works with any :class:`BaseStateStore` backend (in-memory or Redis).

    Key layout
    ----------
    Three kinds of keys are written; their prefixes are the module-level
    ``_CHECKPOINT_DATA_PREFIX``, ``_CHECKPOINT_INDEX_PREFIX`` and
    ``_CHECKPOINT_EXEC_PREFIX`` constants::

        <data prefix>{checkpoint_id}   -> Checkpoint.to_dict() payload
        <index prefix>{agent_id}       -> [checkpoint_id, ...] in save order
        <exec prefix>{execution_id}    -> [checkpoint_id, ...] in save order
                                          (only when save() gets an execution_id)

    TTL behaviour
    -------------
    ``save()`` writes all three keys with the same TTL (the ``ttl`` argument,
    falling back to ``default_ttl``). The two index keys are rewritten on
    every save, which restarts their TTL, so an index can outlive older
    checkpoint data; the listing methods skip IDs whose data is gone.

    Resuming an execution
    ---------------------
    ::

        checkpoint = await manager.load_latest_for_execution(execution_id)
        if checkpoint:
            state = checkpoint.state
            # hand the state back to the agent to continue from

    NOTE(review): the original design notes say agents call ``save()`` after
    each step with a cumulative snapshot, so the latest checkpoint alone is
    enough to resume. That is a caller convention and is not enforced here.

    NOTE(review): index updates are read-modify-write sequences and are not
    atomic across concurrent writers.
    """

    def __init__(self, store: BaseStateStore, default_ttl: Optional[int] = None) -> None:
        self._store = store
        self._default_ttl = default_ttl

    @property
    def default_ttl(self) -> Optional[int]:
        """TTL (seconds) used for checkpoint keys when save() gets no explicit ttl."""
        return self._default_ttl

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _append_to_index(
        self, index_key: str, checkpoint_id: str, ttl: Optional[int]
    ) -> None:
        """Append ``checkpoint_id`` to the list stored at ``index_key``."""
        ids: List[str] = await self._store.get(index_key) or []
        ids.append(checkpoint_id)
        await self._store.set(index_key, ids, ttl=ttl)

    async def _remove_from_index(self, index_key: str, checkpoint_id: str) -> None:
        """Remove ``checkpoint_id`` from the list stored at ``index_key``."""
        ids: List[str] = await self._store.get(index_key) or []
        if checkpoint_id not in ids:
            # Nothing to change: avoid creating or re-timing the index key.
            return
        remaining = [cid for cid in ids if cid != checkpoint_id]
        if remaining:
            # The store interface cannot report a key's remaining TTL, so the
            # rewrite uses default_ttl instead of silently dropping the expiry.
            await self._store.set(index_key, remaining, ttl=self._default_ttl)
        else:
            await self._store.delete(index_key)

    async def _load_many(self, checkpoint_ids: List[str]) -> List[Checkpoint]:
        """Load checkpoints in the given order, skipping IDs with no data."""
        loaded = [await self.load(ckpt_id) for ckpt_id in checkpoint_ids]
        return [ckpt for ckpt in loaded if ckpt is not None]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def save(
        self,
        agent_id: str,
        state: dict,
        execution_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        ttl: Optional[int] = None,
    ) -> str:
        """
        Persist an agent state snapshot.

        Args:
            agent_id: Agent the snapshot belongs to.
            state: State payload to store.
            execution_id: When truthy, the checkpoint is also recorded in
                that execution's index.
            metadata: Optional extra data stored with the checkpoint.
            ttl: Lifetime in seconds for the keys written by this call;
                falls back to ``default_ttl`` when ``None``.

        Returns:
            The generated checkpoint_id.
        """
        checkpoint_id = str(uuid.uuid4())
        checkpoint = Checkpoint(
            checkpoint_id=checkpoint_id,
            agent_id=agent_id,
            state=state,
            metadata=metadata or {},
        )
        effective_ttl = ttl if ttl is not None else self._default_ttl

        # Data first, then the indexes that point at it.
        await self._store.set(
            f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}",
            checkpoint.to_dict(),
            ttl=effective_ttl,
        )
        await self._append_to_index(
            f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}", checkpoint_id, effective_ttl
        )
        if execution_id:
            await self._append_to_index(
                f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}", checkpoint_id, effective_ttl
            )
        return checkpoint_id

    async def load(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Load a checkpoint by ID.

        Returns:
            The :class:`Checkpoint` or ``None`` if not found / expired.
        """
        raw = await self._store.get(f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}")
        if raw is None:
            return None
        return Checkpoint.from_dict(raw)

    async def list(self, agent_id: str) -> List[Checkpoint]:
        """
        Return all checkpoints for an agent, oldest first.

        Checkpoints whose data is no longer in the store are silently skipped.
        """
        checkpoint_ids: List[str] = (
            await self._store.get(f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}") or []
        )
        return await self._load_many(checkpoint_ids)

    async def list_by_execution(self, execution_id: str) -> List[Checkpoint]:
        """Return all checkpoints for an execution, oldest first."""
        checkpoint_ids: List[str] = (
            await self._store.get(f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}") or []
        )
        return await self._load_many(checkpoint_ids)

    async def load_latest_for_execution(self, execution_id: str) -> Optional[Checkpoint]:
        """
        Return the most recently saved checkpoint of an execution that can
        still be loaded, or ``None`` when there is none.
        """
        checkpoint_ids: List[str] = (
            await self._store.get(f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}") or []
        )
        # Walk newest-to-oldest and stop at the first hit instead of loading
        # every checkpoint of the execution.
        for ckpt_id in reversed(checkpoint_ids):
            checkpoint = await self.load(ckpt_id)
            if checkpoint is not None:
                return checkpoint
        return None

    async def delete(
        self,
        checkpoint_id: str,
        agent_id: str,
        execution_id: Optional[str] = None,
    ) -> None:
        """
        Remove a specific checkpoint and its index entries.

        Args:
            checkpoint_id: Checkpoint to remove.
            agent_id: Agent whose index lists the checkpoint.
            execution_id: When given, the checkpoint is also removed from
                that execution's index.
        """
        await self._store.delete(f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}")
        await self._remove_from_index(
            f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}", checkpoint_id
        )
        if execution_id:
            await self._remove_from_index(
                f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}", checkpoint_id
            )
Saves and restores agent execution state checkpoints.
Works with any BaseStateStore backend (in-memory or Redis).
Redis key structure
Three key types are written per agent/execution::
Redis
│
├── __ckpt_index__{agent_id} (TTL: default_ttl, refreshed on each write)
│ └── [ "ckpt-uuid-A", "ckpt-uuid-B", "ckpt-uuid-C", ... ]
│ │ │
│ │ execution-id-1 │ execution-id-2
│ ▼ ▼
├── __ckpt_exec__{execution_id_1} __ckpt_exec__{execution_id_2}
│ └── [ "ckpt-uuid-A", └── [ "ckpt-uuid-D",
│ "ckpt-uuid-B", "ckpt-uuid-E",
│ "ckpt-uuid-C" ] "ckpt-uuid-F" ]
│ │ │
│ step 0, 1, 2 step 0, 1, 2
│ ▼ ▼
├── __ckpt_data__{ckpt-uuid-A} (TTL) __ckpt_data__{ckpt-uuid-D} (TTL)
│ step_number: 0 step_number: 0
│ steps: [ action_0 ] steps: [ action_0 ]
│
├── __ckpt_data__{ckpt-uuid-B} (TTL) __ckpt_data__{ckpt-uuid-E} (TTL)
│ step_number: 1 step_number: 1
│ steps: [ action_0, action_1 ] steps: [ action_0, action_1 ]
│
└── __ckpt_data__{ckpt-uuid-C} (TTL) __ckpt_data__{ckpt-uuid-F} (TTL)
step_number: 2 step_number: 2
steps: [ action_0, action_1, action_2 ] steps: [ ... ]
Each __ckpt_data__ key is a cumulative snapshot — step N contains all steps
0..N, so only the last entry in the exec index is needed to fully resume.
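All three key types are written with the same TTL (the ttl passed to save(), or default_ttl). The index and exec keys
are rewritten on every save, which restarts their TTL, so they can outlive older checkpoint data; list() and list_by_execution() skip IDs whose data has expired.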
Resuming an execution
::
# 1. Look up all checkpoint IDs for the execution
exec_key → __ckpt_exec__{execution_id} → [ id_0, id_1, id_2 ]
# 2. Load the last (highest step_number) checkpoint
last_id = checkpoint_ids[-1]
checkpoint = await manager.load(last_id)
# 3. The checkpoint.state contains the full steps list — resume from there
steps_so_far = checkpoint.state["steps"]
Listing all checkpoints for an agent
::
agent_key → __ckpt_index__{agent_id} → [ id_0, id_1, ... ]
checkpoints = await manager.list(agent_id)
API usage
save() is called internally by the agent after each step — developers
should not call it directly. The developer-facing methods are:
load(checkpoint_id)
Load a single checkpoint by ID. Returns None if expired or not found.
list(agent_id)
All checkpoints for an agent across all executions, oldest first.
Expired entries are silently skipped.
list_by_execution(execution_id)
All checkpoints for a specific execution, oldest first.
load_latest_for_execution(execution_id)
The most recent (highest step_number) checkpoint for an execution —
the primary entry point for resuming a task::
checkpoint = await manager.load_latest_for_execution(execution_id)
if checkpoint:
steps_so_far = checkpoint.state["steps"]
# hand steps_so_far back to the agent to continue from
@property
def default_ttl(self) -> Optional[int]:
    """Fallback TTL in seconds, used when save() is called without an explicit ttl."""
    return self._default_ttl
TTL applied to checkpoint data keys and to the agent/execution index keys when no explicit ttl is passed to save().
async def save(
    self,
    agent_id: str,
    state: dict,
    execution_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    ttl: Optional[int] = None,
) -> str:
    """
    Persist an agent state snapshot.

    Args:
        agent_id: Agent the snapshot belongs to.
        state: State payload to store.
        execution_id: When truthy, the checkpoint is also recorded in that
            execution's index.
        metadata: Optional extra data stored with the checkpoint.
        ttl: Lifetime in seconds for the keys written by this call; falls
            back to the manager's default TTL when ``None``.

    Returns:
        The generated checkpoint_id.
    """
    checkpoint_id = str(uuid.uuid4())
    lifetime = self._default_ttl if ttl is None else ttl

    async def _record_in(index_key: str) -> None:
        # Read-modify-write of one ID list; the write restarts the key's TTL.
        ids: List[str] = await self._store.get(index_key) or []
        ids.append(checkpoint_id)
        await self._store.set(index_key, ids, ttl=lifetime)

    snapshot = Checkpoint(
        checkpoint_id=checkpoint_id,
        agent_id=agent_id,
        state=state,
        metadata=metadata or {},
    )

    # Data first, then the indexes that point at it.
    await self._store.set(
        f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}", snapshot.to_dict(), ttl=lifetime
    )
    await _record_in(f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}")
    if execution_id:
        await _record_in(f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}")

    return checkpoint_id
Persist an agent state snapshot.
Returns: The generated checkpoint_id.
async def load(self, checkpoint_id: str) -> Optional[Checkpoint]:
    """
    Load a checkpoint by ID.

    Returns:
        The :class:`Checkpoint`, or ``None`` when the store has no data for
        the ID (never saved, deleted, or expired).
    """
    payload = await self._store.get(f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}")
    return None if payload is None else Checkpoint.from_dict(payload)
Load a checkpoint by ID.
Returns:
The Checkpoint or None if not found / expired.
async def list(self, agent_id: str) -> List[Checkpoint]:
    """
    Return all checkpoints for an agent, oldest first.

    IDs in the agent index whose data can no longer be loaded are silently
    skipped.
    """
    stored_ids: List[str] = (
        await self._store.get(f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}") or []
    )
    # Loads run one after another, in index order.
    loaded = [await self.load(stored_id) for stored_id in stored_ids]
    return [item for item in loaded if item is not None]
Return all checkpoints for an agent, oldest first.
Checkpoints that have been evicted (TTL expired) are silently skipped.
async def list_by_execution(self, execution_id: str) -> List[Checkpoint]:
    """
    Return all checkpoints for an execution, oldest first.

    IDs in the execution index whose data can no longer be loaded are skipped.
    """
    stored_ids: List[str] = (
        await self._store.get(f"{_CHECKPOINT_EXEC_PREFIX}{execution_id}") or []
    )
    # Loads run one after another, in index order.
    loaded = [await self.load(stored_id) for stored_id in stored_ids]
    return [item for item in loaded if item is not None]
Return all checkpoints for an execution, oldest first.
async def load_latest_for_execution(self, execution_id: str) -> Optional[Checkpoint]:
    """
    Return the most recently indexed checkpoint of an execution.

    Returns:
        The last item of :meth:`list_by_execution`, or ``None`` when that
        list is empty.
    """
    history = await self.list_by_execution(execution_id)
    return history[-1] if history else None
Return the latest checkpoint for an execution, if present.
async def delete(self, checkpoint_id: str, agent_id: str) -> None:
    """
    Remove a checkpoint's data and its entry in the agent index.

    The data key is always deleted. The agent index is only touched when
    it actually lists ``checkpoint_id``, so deleting an unknown id no
    longer creates an empty index record or rewrites the key. When the
    removed id was the last one, the index key itself is deleted instead
    of storing an empty list (readers already treat a missing index as
    empty).

    Args:
        checkpoint_id: Id of the checkpoint to remove.
        agent_id: Agent whose index references the checkpoint.
    """
    await self._store.delete(f"{_CHECKPOINT_DATA_PREFIX}{checkpoint_id}")

    index_key = f"{_CHECKPOINT_INDEX_PREFIX}{agent_id}"
    index: List[str] = await self._store.get(index_key) or []
    if checkpoint_id not in index:
        return

    remaining = [cid for cid in index if cid != checkpoint_id]
    if not remaining:
        await self._store.delete(index_key)
        return

    # NOTE(review): save() writes the index with a TTL but this rewrite
    # passes none, so the store's default applies — confirm that is intended.
    await self._store.set(index_key, remaining)
Remove a specific checkpoint and its index entry.
class Blackboard:
    """
    Shared read/write board for multi-agent communication.

    Agents post results, facts, and partial outputs here so that other
    agents in the same orchestration run can build on them. The whole
    board is kept as one dict under a single store key derived from the
    namespace, so every operation loads (and, for mutations, rewrites)
    that dict. Backed by any :class:`BaseStateStore`.

    Example::

        board = Blackboard(store, namespace="run-abc")
        await board.write("search_results", docs, author="search_agent")
        results = await board.read("search_results")
        all_entries = await board.list_entries()
    """

    def __init__(self, store: BaseStateStore, namespace: str = "default") -> None:
        self._store = store
        self._namespace = namespace

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _board_key(self) -> str:
        """Return the store key that holds this namespace's board."""
        return f"{_BLACKBOARD_KEY}{self._namespace}"

    async def _load(self) -> Dict[str, Dict[str, Any]]:
        """Load the full board dict from the store (empty dict if absent)."""
        return await self._store.get(self._board_key()) or {}

    async def _save(self, board: Dict[str, Dict[str, Any]]) -> None:
        """Write the full board dict back to the store."""
        await self._store.set(self._board_key(), board)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def write(
        self,
        key: str,
        value: Any,
        author: str,
        **metadata: Any,
    ) -> None:
        """Write or overwrite a value on the blackboard.

        Extra keyword arguments are stored as the entry's metadata.
        """
        board = await self._load()
        entry = BlackboardEntry(key=key, value=value, author=author, metadata=dict(metadata))
        board[key] = entry.to_dict()
        await self._save(board)

    async def read(self, key: str) -> Optional[Any]:
        """Read a value by key. Returns the raw value (not the full entry)."""
        board = await self._load()
        entry_dict = board.get(key)
        if entry_dict is None:
            return None
        return BlackboardEntry.from_dict(entry_dict).value

    async def read_entry(self, key: str) -> Optional[BlackboardEntry]:
        """Read the full :class:`BlackboardEntry` including author and timestamp."""
        board = await self._load()
        entry_dict = board.get(key)
        return BlackboardEntry.from_dict(entry_dict) if entry_dict else None

    async def list_entries(self) -> Dict[str, BlackboardEntry]:
        """Return all entries on the board, keyed by their key name."""
        board = await self._load()
        return {k: BlackboardEntry.from_dict(v) for k, v in board.items()}

    async def delete(self, key: str) -> None:
        """Remove one entry from the blackboard.

        Does nothing when the key is absent, so deleting from an empty or
        missing board no longer writes an empty board record to the store.
        """
        board = await self._load()
        if key not in board:
            return
        del board[key]
        await self._save(board)

    async def clear(self) -> None:
        """Wipe the entire blackboard namespace."""
        await self._store.delete(self._board_key())

    async def keys(self) -> List[str]:
        """Return the list of keys currently on the board."""
        board = await self._load()
        return list(board.keys())
Shared read/write board for multi-agent communication.
Agents can post results, facts, and partial outputs here so other agents
in the same orchestration run can build on them. Backed by any
BaseStateStore.
Example::
board = Blackboard(store, namespace="run-abc")
await board.write("search_results", docs, author="search_agent")
results = await board.read("search_results")
all_entries = await board.list_entries()
async def write(
    self,
    key: str,
    value: Any,
    author: str,
    **metadata: Any,
) -> None:
    """
    Post a value under ``key``, replacing any entry already stored there.

    The board is loaded, the new entry's dict form is placed at ``key``,
    and the whole board is saved again. Extra keyword arguments become the
    entry's metadata.
    """
    current = await self._load()
    current[key] = BlackboardEntry(
        key=key,
        value=value,
        author=author,
        metadata=dict(metadata),
    ).to_dict()
    await self._save(current)
Write or overwrite a value on the blackboard.
async def read(self, key: str) -> Optional[Any]:
    """
    Look up the value stored under ``key``.

    Returns:
        The entry's raw value only (not the entry object), or ``None`` when
        the board has no such key.
    """
    stored = (await self._load()).get(key)
    return None if stored is None else BlackboardEntry.from_dict(stored).value
Read a value by key. Returns the raw value (not the full entry).
async def read_entry(self, key: str) -> Optional[BlackboardEntry]:
    """
    Look up the complete entry stored under ``key``.

    Returns:
        The :class:`BlackboardEntry` (value plus author, timestamp and
        metadata), or ``None`` when the stored entry is missing or falsy.
    """
    stored = (await self._load()).get(key)
    if not stored:
        return None
    return BlackboardEntry.from_dict(stored)
Read the full BlackboardEntry including author and timestamp.
async def list_entries(self) -> Dict[str, BlackboardEntry]:
    """Load the board and return every entry as a :class:`BlackboardEntry`, keyed by name."""
    entries = {}
    for name, payload in (await self._load()).items():
        entries[name] = BlackboardEntry.from_dict(payload)
    return entries
Return all entries on the board, keyed by their key name.
async def delete(self, key: str) -> None:
    """
    Remove one entry from the blackboard.

    Does nothing when ``key`` is not on the board: the board is not saved
    again, so deleting from an empty or missing board no longer writes an
    empty board record to the store.
    """
    board = await self._load()
    if key not in board:
        return
    del board[key]
    await self._save(board)
Remove one entry from the blackboard.
@dataclass
class BlackboardEntry:
    """
    A single entry written to the blackboard.

    Carries the posted value together with who wrote it, when, and any
    free-form metadata. ``to_dict`` / ``from_dict`` convert to and from the
    plain-dict form; both copy the metadata dict so the entry and the
    stored payload never share one mutable object.
    """

    key: str  # board key the value is stored under
    value: Any  # payload, stored as-is
    author: str  # identifier of the writer
    # Defaults to the current UTC time at construction.
    written_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the entry as a plain dict; the timestamp becomes an ISO-8601 string."""
        return {
            "key": self.key,
            "value": self.value,
            "author": self.author,
            "written_at": self.written_at.isoformat(),
            # Shallow copy: later edits to the payload must not leak back.
            "metadata": dict(self.metadata or {}),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BlackboardEntry":
        """
        Rebuild an entry from the dict form produced by :meth:`to_dict`.

        A missing or ``None`` ``"metadata"`` value becomes an empty dict.

        Raises:
            KeyError: If ``key``, ``value``, ``author`` or ``written_at``
                is missing from ``data``.
        """
        return cls(
            key=data["key"],
            value=data["value"],
            author=data["author"],
            written_at=datetime.fromisoformat(data["written_at"]),
            metadata=dict(data.get("metadata") or {}),
        )
A single entry written to the blackboard.
class BaseBehavior(ABC):
    """
    Base class for pluggable behaviors wrapped around an agent run.

    A behavior contributes up to three hooks to the execution pipeline:
    ``before_execute`` → agent runs → ``after_execute`` (or ``on_error``).

    Every hook has a pass-through default, so a subclass overrides only
    the ones it cares about.
    """

    # ------------------------------------------------------------------
    # Hooks — override as needed
    # ------------------------------------------------------------------

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """
        Hook invoked ahead of the agent's task.

        Implementations may adjust the context and must return the context
        to use. The default returns it unchanged.
        """
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """
        Hook invoked once the agent has finished successfully.

        Implementations may adjust the result and must return the result
        to use. The default returns it unchanged.
        """
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """
        Hook invoked when the agent raised ``error``.

        Returning a fallback value suppresses the error; re-raising or
        returning ``None`` lets it propagate. The default returns ``None``.
        """
        return None
Composable agent behavior.
Behaviors are applied as a pipeline around every agent execution:
before_execute → agent runs → after_execute (or on_error).
Subclasses override only the hooks they need.
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Hook invoked ahead of the agent's task.

    Implementations may adjust the context (e.g. to inject guardrails) and
    must return the context to use. The default returns it unchanged.
    """
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Hook invoked once the agent has finished successfully.

    Implementations may adjust the result and must return the result to
    use. The default returns it unchanged.
    """
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """
    Hook invoked when the agent raised ``error``.

    Returning a fallback value suppresses the error; re-raising or
    returning ``None`` lets it propagate. The default returns ``None``.
    """
    return None
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
@dataclass
class BehaviorContext:
    """Per-execution data handed to each behavior hook."""

    # Required: which agent is running and what it was asked to do.
    agent_id: str
    task: str

    # Optional bookkeeping, all with neutral defaults.
    attempt: int = 1
    execution_id: str = ""
    last_completed_step: int = -1  # -1 is the initial value before any step index is recorded
    metadata: Dict[str, Any] = field(default_factory=dict)
Execution context passed to every behavior hook.
class RetryBehavior(BaseBehavior):
    """
    Retries the agent on failure with exponential backoff.

    Args:
        max_retries: Maximum number of retry attempts (default: 3).
        base_delay: Initial delay in seconds before the first retry (default: 1.0).
        backoff_factor: Multiplier applied to delay after each failure (default: 2.0).
        max_delay: Upper bound on delay between retries in seconds (default: 60.0).
        exceptions_to_retry: Tuple of exception types that trigger a retry.
            Defaults to ``(Exception,)`` — all exceptions.
        logger: Optional logger; when set, retry decisions are logged at info level.

    Example::

        agent = ReActAgent(
            llm_gateway=gateway,
            behaviors=[RetryBehavior(max_retries=3, base_delay=0.5)],
        )
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        backoff_factor: float = 2.0,
        max_delay: float = 60.0,
        exceptions_to_retry: Tuple[Type[Exception], ...] = (Exception,),
        logger: Optional[BasicLogger] = None,
    ) -> None:
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay
        self.exceptions_to_retry = exceptions_to_retry
        self._logger = logger

        # Tracks remaining retries across on_error calls for a single execution
        self._remaining: int = max_retries

    def _backoff_delay(self, attempt: int) -> float:
        """Return the capped delay to sleep before retry number ``attempt`` (1-based)."""
        try:
            raw = self.base_delay * (self.backoff_factor ** (attempt - 1))
        except OverflowError:
            # A float power raises instead of yielding inf once the result
            # is out of range; the cap would have applied anyway.
            return self.max_delay if self.base_delay > 0 else 0.0
        return min(raw, self.max_delay)

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Reset the retry budget and pass the context through unchanged."""
        # Reset retry counter at the start of each execution
        self._remaining = self.max_retries
        return context

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """
        Decide whether a failed run is retried.

        Errors outside ``exceptions_to_retry`` and errors arriving after the
        budget is used up are re-raised. Otherwise one retry is consumed,
        the backoff delay is slept, and either the result of
        ``resume_agent.resume_from(execution_id)`` is returned (when the
        context carries an execution id, a completed step and a
        ``resume_agent`` in its metadata) or ``RETRY_SENTINEL`` is returned.

        Raises:
            Exception: ``error`` itself when it is not retried.
        """
        if not isinstance(error, self.exceptions_to_retry):
            raise error

        if self._remaining <= 0:
            if self._logger:
                self._logger.info(
                    "RetryBehavior: all retries exhausted — propagating error",
                    agent_id=context.agent_id,
                    max_retries=self.max_retries,
                    error=str(error),
                    error_type=type(error).__name__,
                )
            raise error

        attempt = self.max_retries - self._remaining + 1
        delay = self._backoff_delay(attempt)
        self._remaining -= 1

        if self._logger:
            self._logger.info(
                "RetryBehavior: scheduling retry",
                agent_id=context.agent_id,
                attempt=attempt,
                retries_remaining=self._remaining,
                delay_s=round(delay, 3),
                error=str(error),
                error_type=type(error).__name__,
            )

        await asyncio.sleep(delay)

        # Prefer resume-from-checkpoint when execution metadata is available.
        resume_agent = context.metadata.get("resume_agent")
        if (
            context.execution_id
            and context.last_completed_step >= 0
            and resume_agent is not None
            and hasattr(resume_agent, "resume_from")
        ):
            if self._logger:
                self._logger.info(
                    "RetryBehavior: resuming from checkpoint",
                    agent_id=context.agent_id,
                    execution_id=context.execution_id,
                    last_completed_step=context.last_completed_step,
                )
            return await resume_agent.resume_from(context.execution_id)

        # NOTE(review): presumably the caller re-runs the agent on this
        # sentinel — confirm against the agent pipeline.
        return RETRY_SENTINEL
Retries the agent on failure with exponential backoff.
Args:
max_retries: Maximum number of retry attempts (default: 3).
base_delay: Initial delay in seconds before the first retry (default: 1.0).
backoff_factor: Multiplier applied to delay after each failure (default: 2.0).
max_delay: Upper bound on delay between retries in seconds (default: 60.0).
exceptions_to_retry: Tuple of exception types that trigger a retry.
Defaults to (Exception,) — all exceptions.
Example::
agent = ReActAgent(
llm_gateway=gateway,
behaviors=[RetryBehavior(max_retries=3, base_delay=0.5)],
)
def __init__(
    self,
    max_retries: int = 3,
    base_delay: float = 1.0,
    backoff_factor: float = 2.0,
    max_delay: float = 60.0,
    exceptions_to_retry: Tuple[Type[Exception], ...] = (Exception,),
    logger: Optional[BasicLogger] = None,
) -> None:
    """
    Store the retry policy and start with a full retry budget.

    Args:
        max_retries: Maximum number of retry attempts.
        base_delay: Delay in seconds before the first retry.
        backoff_factor: Multiplier applied to the delay after each failure.
        max_delay: Upper bound on the delay between retries, in seconds.
        exceptions_to_retry: Exception types that trigger a retry.
        logger: Optional logger for retry decisions.
    """
    self._logger = logger
    self.exceptions_to_retry = exceptions_to_retry

    # Backoff schedule parameters.
    self.base_delay = base_delay
    self.backoff_factor = backoff_factor
    self.max_delay = max_delay

    # Retry budget: the configured ceiling and the live countdown that
    # on_error decrements during a single execution.
    self.max_retries = max_retries
    self._remaining: int = max_retries
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Refill the retry budget to ``max_retries`` and return the context unchanged."""
    # Each execution starts with the full number of retries available.
    self._remaining = self.max_retries
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """
    Decide whether a failed run is retried.

    Errors outside ``exceptions_to_retry`` and errors arriving after the
    budget is used up are re-raised. Otherwise one retry is consumed, the
    capped exponential backoff delay is slept, and either the result of
    ``resume_agent.resume_from(execution_id)`` is returned (when the
    context carries an execution id, a completed step and a
    ``resume_agent`` in its metadata) or ``RETRY_SENTINEL`` is returned.

    Raises:
        Exception: ``error`` itself when it is not retried.
    """
    if not isinstance(error, self.exceptions_to_retry):
        raise error

    if self._remaining <= 0:
        if self._logger:
            self._logger.info(
                "RetryBehavior: all retries exhausted — propagating error",
                agent_id=context.agent_id,
                max_retries=self.max_retries,
                error=str(error),
                error_type=type(error).__name__,
            )
        raise error

    attempt = self.max_retries - self._remaining + 1
    try:
        delay = min(
            self.base_delay * (self.backoff_factor ** (attempt - 1)),
            self.max_delay,
        )
    except OverflowError:
        # A float power raises instead of yielding inf once the result is
        # out of range; the cap would have applied anyway.
        delay = self.max_delay if self.base_delay > 0 else 0.0
    self._remaining -= 1

    if self._logger:
        self._logger.info(
            "RetryBehavior: scheduling retry",
            agent_id=context.agent_id,
            attempt=attempt,
            retries_remaining=self._remaining,
            delay_s=round(delay, 3),
            error=str(error),
            error_type=type(error).__name__,
        )

    await asyncio.sleep(delay)

    # Prefer resume-from-checkpoint when execution metadata is available.
    resume_agent = context.metadata.get("resume_agent")
    if (
        context.execution_id
        and context.last_completed_step >= 0
        and resume_agent is not None
        and hasattr(resume_agent, "resume_from")
    ):
        if self._logger:
            self._logger.info(
                "RetryBehavior: resuming from checkpoint",
                agent_id=context.agent_id,
                execution_id=context.execution_id,
                last_completed_step=context.last_completed_step,
            )
        return await resume_agent.resume_from(context.execution_id)

    # NOTE(review): presumably the caller re-runs the agent on this
    # sentinel — confirm against the agent pipeline.
    return RETRY_SENTINEL
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class GuardrailBehavior(BaseBehavior):
    """
    Validates task inputs and agent outputs against configurable rules.

    Raises :class:`GuardrailViolationError` if any rule is violated.

    Args:
        rules: List of :class:`GuardrailRule` to enforce.

    Example::

        behavior = GuardrailBehavior(rules=[
            GuardrailRule(
                name="no_pii",
                blocked_words=["ssn", "social security"],
                apply_to="both",
            ),
            GuardrailRule(name="max_input", max_length=4000, apply_to="input"),
        ])
    """

    def __init__(self, rules: Optional[List[GuardrailRule]] = None) -> None:
        self._rules = rules or []
        # Pair each rule with its regex, compiled once (case-insensitive).
        self._compiled: List[tuple] = []
        for rule in self._rules:
            compiled_pattern: Optional[Pattern] = (
                re.compile(rule.pattern, re.IGNORECASE) if rule.pattern else None
            )
            self._compiled.append((rule, compiled_pattern))

    def _validate(self, text: str, target: str) -> None:
        """
        Run all rules applicable to ``target`` (``"input"`` or ``"output"``).

        Per rule, the checks run in order: length limit, blocked words
        (case-insensitive substring match), then the regex pattern.

        Raises:
            GuardrailViolationError: On the first violated check.
        """
        text_lower = text.lower()
        for rule, compiled_pattern in self._compiled:
            if rule.apply_to not in (target, "both"):
                continue

            # Compare against None so that a limit of 0 is enforced too.
            if rule.max_length is not None and len(text) > rule.max_length:
                raise GuardrailViolationError(
                    rule.name,
                    f"Length {len(text)} exceeds max {rule.max_length}",
                )

            for word in rule.blocked_words:
                if word.lower() in text_lower:
                    raise GuardrailViolationError(
                        rule.name, f"Blocked term detected: '{word}'"
                    )

            if compiled_pattern and compiled_pattern.search(text):
                raise GuardrailViolationError(
                    rule.name, f"Pattern '{rule.pattern}' matched"
                )

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Validate the task text as input and return the context unchanged."""
        self._validate(context.task, "input")
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """
        Validate the agent's output text and return the result unchanged.

        Uses ``result.output`` when the result has that attribute, otherwise
        ``str(result)``. A non-string ``output`` is converted with ``str``
        (``None`` becomes the empty string) instead of crashing validation.
        """
        if hasattr(result, "output"):
            output = result.output
            if isinstance(output, str):
                output_text = output
            else:
                output_text = "" if output is None else str(output)
        else:
            output_text = str(result)
        self._validate(output_text, "output")
        return result
Validates task inputs and agent outputs against configurable rules.
Raises GuardrailViolationError if any rule is violated.
Args:
rules: List of GuardrailRule to enforce.
Example::
behavior = GuardrailBehavior(rules=[
GuardrailRule(
name="no_pii",
blocked_words=["ssn", "social security"],
apply_to="both",
),
GuardrailRule(name="max_input", max_length=4000, apply_to="input"),
])
def __init__(self, rules: Optional[List[GuardrailRule]] = None) -> None:
    """
    Keep the rules and pre-compile their regex patterns.

    Args:
        rules: Rules to enforce; no rules when omitted or empty.
    """
    self._rules = rules or []
    # One (rule, compiled-regex-or-None) pair per rule; patterns are
    # compiled case-insensitively.
    self._compiled: List[tuple] = [
        (rule, re.compile(rule.pattern, re.IGNORECASE) if rule.pattern else None)
        for rule in self._rules
    ]
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Check the task text against all input rules before the agent runs.

    Returns the context unchanged; ``_validate`` raises
    :class:`GuardrailViolationError` on a violation.
    """
    task_text = context.task
    self._validate(task_text, "input")
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Validate the agent's output text and return the result unchanged.

    Uses ``result.output`` when the result has that attribute, otherwise
    ``str(result)``. A non-string ``output`` is converted with ``str``
    (``None`` becomes the empty string) so validation no longer fails
    with ``AttributeError`` on ``.lower()``.

    Raises:
        GuardrailViolationError: If an output rule is violated.
    """
    if hasattr(result, "output"):
        output = result.output
        if isinstance(output, str):
            output_text = output
        else:
            output_text = "" if output is None else str(output)
    else:
        output_text = str(result)
    self._validate(output_text, "output")
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
@dataclass
class GuardrailRule:
    """
    A single validation rule.

    Raises:
        ValueError: If ``apply_to`` is not ``"input"``, ``"output"`` or
            ``"both"``. Any other value would never match a validation
            target, silently disabling the rule.
    """

    name: str
    pattern: Optional[str] = None  # compiled as regex
    blocked_words: List[str] = field(default_factory=list)
    max_length: Optional[int] = None
    apply_to: str = "both"  # "input" | "output" | "both"

    def __post_init__(self) -> None:
        if self.apply_to not in ("input", "output", "both"):
            raise ValueError(
                f"GuardrailRule '{self.name}': apply_to must be 'input', "
                f"'output' or 'both', got {self.apply_to!r}"
            )
A single validation rule.
class GuardrailViolationError(Exception):
    """
    Raised when a guardrail rule is violated.

    Attributes:
        rule_name: Name of the violated rule.
        message: The violation detail passed by the caller.
    """

    def __init__(self, rule_name: str, message: str) -> None:
        self.rule_name = rule_name
        self.message = message
        super().__init__(f"Guardrail '{rule_name}' violated: {message}")

    def __reduce__(self):
        # The default reduce re-calls __init__ with ``self.args`` (the one
        # formatted string), which fails for this two-argument constructor;
        # supply the original arguments so copy/pickle round-trips work.
        return (self.__class__, (self.rule_name, self.message), dict(self.__dict__))
Raised when a guardrail rule is violated.
class HumanInLoopBehavior(BaseBehavior):
    """
    Gate an agent's result behind a human decision.

    The check runs in ``after_execute`` and works in one of two modes,
    selected by which callback is supplied:

    **Sync mode** — ``review_callback``, ``async (context, result) -> bool``.
        The callback is awaited until the human decides. A truthy return
        releases the result; a falsy one raises
        :class:`HumanApprovalRequired`.

    **Async (offline) mode** — ``submit_callback``,
    ``async (context, result) -> str``.
        The callback submits the approval request and returns an opaque
        ``approval_id``. :class:`PendingApproval` is then raised right away
        so the agent call ends without blocking; persisting the pending
        state and resuming later is the caller's job.

    Args:
        review_callback: Sync-mode callback. Mutually exclusive with
            ``submit_callback``.
        submit_callback: Async-mode callback. Mutually exclusive with
            ``review_callback``.

    Raises:
        ValueError: If neither or both callbacks are supplied.

    Example — sync mode::

        async def my_reviewer(ctx, result):
            answer = input(f"Approve '{ctx.task}'? [y/n]: ")
            return answer.strip().lower() == "y"

        behavior = HumanInLoopBehavior(review_callback=my_reviewer)

    Example — async (offline) mode::

        behavior = HumanInLoopBehavior(submit_callback=submit_for_approval)

        try:
            result = await agent.execute(task)
        except PendingApproval as pa:
            store.save(pa.approval_id, pending_result=pa.result, task=pa.task)
    """

    def __init__(
        self,
        review_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[bool]]] = None,
        submit_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[str]]] = None,
    ) -> None:
        has_review = review_callback is not None
        has_submit = submit_callback is not None

        if not (has_review or has_submit):
            raise ValueError(
                "HumanInLoopBehavior requires either 'review_callback' (sync) "
                "or 'submit_callback' (async/offline)."
            )
        if has_review and has_submit:
            raise ValueError(
                "HumanInLoopBehavior accepts only one of 'review_callback' or "
                "'submit_callback', not both."
            )

        self._review_callback = review_callback
        self._submit_callback = submit_callback

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """Apply the configured approval mode to ``result``."""
        submit = self._submit_callback

        if submit is None:
            # Sync mode: wait for the reviewer's verdict.
            verdict = await self._review_callback(context, result)  # type: ignore[misc]
            if verdict:
                return result
            raise HumanApprovalRequired(
                f"Human reviewer rejected the result for task: '{context.task}'"
            )

        # Offline mode: hand the request off, then stop without a result.
        ticket = await submit(context, result)
        raise PendingApproval(
            approval_id=ticket,
            task=context.task,
            result=result,
        )
Pauses or suspends execution after the agent produces a result and waits for human approval before the result is returned.
Two modes:
Sync mode — pass review_callback.
The callback blocks until the human decides, returning True to
approve or False to reject. Suitable for interactive CLIs or
short-lived approval windows.
``async (context, result) -> bool``
Async (offline) mode — pass submit_callback.
The callback fires the approval request (e.g. POST to a REST API or
queue) and returns an opaque approval_id string. The behavior
immediately raises PendingApproval so the agent call
terminates without blocking. The caller is responsible for persisting
the pending state and resuming once the human's decision arrives via
webhook or callback.
``async (context, result) -> str`` (returns ``approval_id``)
Args:
review_callback: Sync-mode callback. Mutually exclusive with
submit_callback.
submit_callback: Async-mode callback. Mutually exclusive with
review_callback.
Raises: ValueError: If neither or both callbacks are supplied.
Example — sync mode::
async def my_reviewer(ctx, result):
answer = input(f"Approve '{ctx.task}'? [y/n]: ")
return answer.strip().lower() == "y"
behavior = HumanInLoopBehavior(review_callback=my_reviewer)
Example — async (offline) mode::
async def submit_for_approval(ctx, result):
resp = await http_client.post("/approvals", json={
"task": ctx.task,
"output": result.output,
})
return resp.json()["approval_id"]
behavior = HumanInLoopBehavior(submit_callback=submit_for_approval)
# Caller handles PendingApproval:
try:
result = await agent.execute(task)
except PendingApproval as pa:
store.save(pa.approval_id, pending_result=pa.result, task=pa.task)
# ... webhook delivers decision later, caller resumes from store
def __init__(
    self,
    review_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[bool]]] = None,
    submit_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[str]]] = None,
) -> None:
    """
    Select the approval mode from exactly one callback.

    Args:
        review_callback: Sync-mode callback returning the human's verdict.
        submit_callback: Offline-mode callback returning an approval id.

    Raises:
        ValueError: If neither or both callbacks are supplied.
    """
    has_review = review_callback is not None
    has_submit = submit_callback is not None

    if not (has_review or has_submit):
        raise ValueError(
            "HumanInLoopBehavior requires either 'review_callback' (sync) "
            "or 'submit_callback' (async/offline)."
        )
    if has_review and has_submit:
        raise ValueError(
            "HumanInLoopBehavior accepts only one of 'review_callback' or "
            "'submit_callback', not both."
        )

    self._review_callback = review_callback
    self._submit_callback = submit_callback
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Apply the configured approval mode to ``result``.

    Returns:
        ``result`` unchanged when the sync reviewer approves it.

    Raises:
        HumanApprovalRequired: Sync mode, when the reviewer's verdict is falsy.
        PendingApproval: Offline mode, always, after the request was submitted.
    """
    submit = self._submit_callback

    if submit is None:
        # Sync mode: wait for the reviewer's verdict.
        verdict = await self._review_callback(context, result)  # type: ignore[misc]
        if verdict:
            return result
        raise HumanApprovalRequired(
            f"Human reviewer rejected the result for task: '{context.task}'"
        )

    # Offline mode: hand the request off, then stop without a result.
    ticket = await submit(context, result)
    raise PendingApproval(
        approval_id=ticket,
        task=context.task,
        result=result,
    )
Called after the agent successfully executes.
May mutate and return a modified result.
class HumanApprovalRequired(Exception):
    """Signals that a human reviewer turned down the agent's result."""
Raised when a human reviewer rejects the agent result.
class PendingApproval(Exception):
    """
    Raised in async (offline) mode to signal that an approval request has been
    submitted and the agent result is awaiting a human decision.

    The caller should catch this, persist state keyed on ``approval_id``, and
    resume processing when the approval webhook/callback delivers the decision.

    Attributes:
        approval_id: Opaque identifier returned by the ``submit_callback``.
        task: The original task string passed to the agent.
        result: The result produced by the agent, ready to deliver once
            approved.
    """

    def __init__(self, approval_id: str, task: str, result: Any) -> None:
        self.approval_id = approval_id
        self.task = task
        self.result = result
        super().__init__(
            f"Approval pending for task: '{task}' (approval_id={approval_id})"
        )

    def __reduce__(self):
        # The default reduce re-calls __init__ with ``self.args`` (the one
        # formatted string), which fails for this three-argument
        # constructor; supply the original arguments so copy/pickle
        # round-trips work.
        return (
            self.__class__,
            (self.approval_id, self.task, self.result),
            dict(self.__dict__),
        )
Raised in async (offline) mode to signal that an approval request has been submitted and the agent result is awaiting a human decision.
The caller should catch this, persist state keyed on approval_id, and
resume processing when the approval webhook/callback delivers the decision.
Attributes:
approval_id: Opaque identifier returned by the submit_callback.
task: The original task string passed to the agent.
result: The AgentResult produced by the agent, ready to deliver
once approved.
class CircuitBreakerBehavior(BaseBehavior):
    """
    Prevents repeated calls to a failing agent by tripping an open circuit.

    State transitions::

        CLOSED ──(failures >= threshold)──► OPEN
        OPEN ──(recovery_timeout elapsed)──► HALF_OPEN
        HALF_OPEN ──(success)──► CLOSED
        HALF_OPEN ──(failure)──► OPEN

    Args:
        failure_threshold: Consecutive failures before opening (default: 5).
        recovery_timeout: Seconds to wait in OPEN state before probing (default: 60).
        metrics: Optional metrics collector for state/failure tracking.
        logger: Optional logger.

    Example::

        behavior = CircuitBreakerBehavior(
            failure_threshold=3,
            recovery_timeout=30,
            metrics=BasicMetricsCollector(),
            logger=BasicLogger("circuit_breaker"),
        )
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        metrics: Optional[Any] = None,
        logger: Optional[Any] = None,
    ) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._metrics = metrics
        self._logger = logger

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        # time.monotonic() reading taken when the circuit last opened.
        self._opened_at: Optional[float] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _transition(self, new_state: CircuitState) -> None:
        """Switch to ``new_state``, logging and recording the change (no-op if unchanged)."""
        if self._state == new_state:
            return
        old = self._state.value
        self._state = new_state
        if self._logger:
            self._logger.info(
                "Circuit state transition",
                from_state=old,
                to_state=new_state.value,
            )
        if self._metrics:
            # Clear the state being left first, so that only the current
            # state reads 1.0 instead of every state ever visited.
            self._metrics.gauge("circuit_breaker.state", 0.0, state=old)
            self._metrics.gauge("circuit_breaker.state", 1.0, state=new_state.value)

    # ------------------------------------------------------------------
    # BaseBehavior hooks
    # ------------------------------------------------------------------

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """
        Reject the call while the circuit is OPEN and the timeout has not elapsed.

        Once ``recovery_timeout`` has passed, the circuit moves to HALF_OPEN
        and the call is let through as a probe.

        Raises:
            CircuitOpenError: If the circuit is OPEN and still cooling down.
        """
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - (self._opened_at or 0)
            if elapsed >= self.recovery_timeout:
                self._transition(CircuitState.HALF_OPEN)
            else:
                raise CircuitOpenError(
                    f"Circuit is OPEN for agent '{context.agent_id}'. "
                    f"Retry in {self.recovery_timeout - elapsed:.1f}s."
                )
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """Record a success: reset the failure count and close a HALF_OPEN circuit."""
        # Success — reset failures, close circuit if probing
        self._failure_count = 0
        if self._state == CircuitState.HALF_OPEN:
            self._transition(CircuitState.CLOSED)
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """
        Record a failure, open the circuit if warranted, and re-raise ``error``.

        The circuit opens when a HALF_OPEN probe fails or the failure count
        reaches ``failure_threshold``.

        Raises:
            Exception: Always re-raises ``error``.
        """
        if isinstance(error, CircuitOpenError):
            # A rejection issued by the breaker itself is not an agent
            # failure: counting it would re-arm the recovery timer on every
            # rejected call.
            raise error

        self._failure_count += 1
        if self._metrics:
            self._metrics.increment("circuit_breaker.failures", agent_id=context.agent_id)

        if self._state == CircuitState.HALF_OPEN or self._failure_count >= self.failure_threshold:
            self._opened_at = time.monotonic()
            self._transition(CircuitState.OPEN)

        raise error
Prevents repeated calls to a failing agent by tripping an open circuit.
State transitions::
CLOSED ──(failures >= threshold)──► OPEN
OPEN ──(recovery_timeout elapsed)──► HALF_OPEN
HALF_OPEN ──(success)──► CLOSED
HALF_OPEN ──(failure)──► OPEN
Args:
failure_threshold: Consecutive failures before opening (default: 5).
recovery_timeout: Seconds to wait in OPEN state before probing (default: 60).
metrics: Optional BasicMetricsCollector for state/failure tracking.
logger: Optional BasicLogger.
Example::
from gmf_forge_ai_shared_core.observability import BasicMetricsCollector, BasicLogger
behavior = CircuitBreakerBehavior(
failure_threshold=3,
recovery_timeout=30,
metrics=BasicMetricsCollector(),
logger=BasicLogger("circuit_breaker"),
)
def __init__(
    self,
    failure_threshold: int = 5,
    recovery_timeout: float = 60.0,
    metrics: Optional[Any] = None,
    logger: Optional[Any] = None,
) -> None:
    """Configure the breaker and start it in the CLOSED state.

    Args:
        failure_threshold: Consecutive failures before opening.
        recovery_timeout: Seconds to stay OPEN before allowing a probe.
        metrics: Optional metrics collector; used only when truthy.
        logger: Optional logger; used only when truthy.
    """
    # Optional observability collaborators.
    self._logger = logger
    self._metrics = metrics

    # Public tunables.
    self.recovery_timeout = recovery_timeout
    self.failure_threshold = failure_threshold

    # Mutable runtime state.
    self._opened_at: Optional[float] = None
    self._failure_count = 0
    self._state = CircuitState.CLOSED
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Gate the call on the circuit state.

    Returns ``context`` unchanged unless the circuit is OPEN and the
    recovery timeout has not elapsed, in which case the call is rejected.
    An OPEN circuit whose timeout has elapsed moves to HALF_OPEN first.

    Raises:
        CircuitOpenError: While the circuit is OPEN and still cooling down.
    """
    if self._state != CircuitState.OPEN:
        return context

    waited = time.monotonic() - (self._opened_at or 0)
    if waited >= self.recovery_timeout:
        # Cool-down is over: let this call through as a probe.
        self._transition(CircuitState.HALF_OPEN)
        return context

    raise CircuitOpenError(
        f"Circuit is OPEN for agent '{context.agent_id}'. "
        f"Retry in {self.recovery_timeout - waited:.1f}s."
    )
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """Record a successful call and hand ``result`` back untouched.

    The failure counter is cleared on every success; a circuit that was
    probing (HALF_OPEN) is closed again.
    """
    was_probing = self._state == CircuitState.HALF_OPEN
    self._failure_count = 0
    if was_probing:
        self._transition(CircuitState.CLOSED)
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """Record a failure, trip the circuit when warranted, and re-raise.

    The circuit opens when a HALF_OPEN probe fails or when the failure
    count reaches ``failure_threshold``.

    Raises:
        Exception: Always re-raises ``error``; this behavior never supplies
            a fallback value.
    """
    self._failure_count += 1
    if self._metrics:
        self._metrics.increment("circuit_breaker.failures", agent_id=context.agent_id)

    should_trip = (
        self._state == CircuitState.HALF_OPEN
        or self._failure_count >= self.failure_threshold
    )
    # Only stamp the open time on an actual transition. A failure that
    # arrives while already OPEN must not restart the recovery timer, or
    # the circuit could stay open longer than recovery_timeout.
    if should_trip and self._state != CircuitState.OPEN:
        self._opened_at = time.monotonic()
        self._transition(CircuitState.OPEN)

    raise error
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation: calls pass through.
    CLOSED = "closed"
    # Blocking all calls.
    OPEN = "open"
    # Allowing a probe call to test recovery.
    HALF_OPEN = "half_open"
class CircuitOpenError(Exception):
    """Signals that a call was rejected because the circuit is OPEN."""
Raised when a call is attempted while the circuit is OPEN.
class RateLimitBehavior(BaseBehavior):
    """
    Enforces a maximum call rate using a token bucket algorithm.

    Args:
        calls_per_second: Maximum sustained calls per second (default: 1.0).
            Must be greater than 0.
        burst: Maximum burst capacity — tokens that can accumulate while idle
            (default: ``calls_per_second``, but never below 1 so that a whole
            token can always accumulate). An explicit value must be >= 1.
        wait_on_limit: If ``True`` (default), sleep until a token is available.
            If ``False``, raise :class:`RateLimitExceededError` immediately.

    Raises:
        ValueError: If ``calls_per_second`` is not positive, or an explicit
            ``burst`` is below 1.

    Example::

        behavior = RateLimitBehavior(calls_per_second=2.0, burst=5)
    """

    def __init__(
        self,
        calls_per_second: float = 1.0,
        burst: Optional[float] = None,
        wait_on_limit: bool = True,
    ) -> None:
        # A non-positive rate never refills the bucket and divides by zero
        # when computing the wait time, so reject it up front.
        if not calls_per_second > 0:
            raise ValueError("calls_per_second must be greater than 0")
        if burst is None:
            # The bucket is capped at `burst`; with a cap below 1.0 a whole
            # token can never accumulate and before_execute would spin
            # forever (e.g. calls_per_second=0.5 with the default burst).
            burst = max(1.0, calls_per_second)
        elif burst < 1:
            raise ValueError("burst must be at least 1")

        self.calls_per_second = calls_per_second
        self.burst = burst
        self.wait_on_limit = wait_on_limit

        self._tokens: float = self.burst
        self._last_refill: float = time.monotonic()
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Token bucket
    # ------------------------------------------------------------------

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the last refill, capped at ``burst``."""
        now = time.monotonic()
        elapsed = now - self._last_refill
        self._tokens = min(self.burst, self._tokens + elapsed * self.calls_per_second)
        self._last_refill = now

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Consume one token, waiting or failing when none is available.

        Raises:
            RateLimitExceededError: If no token is available and
                ``wait_on_limit`` is ``False``.
        """
        # The lock is held across the sleep, so waiting callers are served
        # one at a time.
        async with self._lock:
            while True:
                self._refill()
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return context

                if not self.wait_on_limit:
                    raise RateLimitExceededError(
                        f"Rate limit of {self.calls_per_second} call/s exceeded "
                        f"for agent '{context.agent_id}'."
                    )

                # Sleep just long enough for the missing fraction of a token.
                wait = (1.0 - self._tokens) / self.calls_per_second
                await asyncio.sleep(wait)
Enforces a maximum call rate using a token bucket algorithm.
Args:
calls_per_second: Maximum calls allowed per second (default: 1.0).
burst: Maximum burst capacity — tokens that can accumulate while idle
(default: equals calls_per_second).
wait_on_limit: If True (default), sleep until a token is available.
If False, raise RateLimitExceededError immediately.
Example::
behavior = RateLimitBehavior(calls_per_second=2.0, burst=5)
def __init__(
    self,
    calls_per_second: float = 1.0,
    burst: Optional[float] = None,
    wait_on_limit: bool = True,
) -> None:
    """Configure the token bucket and start it full.

    Args:
        calls_per_second: Sustained refill rate; must be greater than 0.
        burst: Bucket capacity. Defaults to ``calls_per_second`` but never
            below 1; an explicit value must be >= 1.
        wait_on_limit: Whether callers wait for a token instead of failing.

    Raises:
        ValueError: If ``calls_per_second`` is not positive, or an explicit
            ``burst`` is below 1.
    """
    # A non-positive rate never refills the bucket and divides by zero when
    # the wait time is computed, so reject it up front.
    if not calls_per_second > 0:
        raise ValueError("calls_per_second must be greater than 0")
    if burst is None:
        # The bucket is capped at `burst`; with a cap below 1.0 a whole token
        # can never accumulate and a waiting caller would spin forever
        # (e.g. calls_per_second=0.5 with the default burst).
        burst = max(1.0, calls_per_second)
    elif burst < 1:
        raise ValueError("burst must be at least 1")

    self.calls_per_second = calls_per_second
    self.burst = burst
    self.wait_on_limit = wait_on_limit

    self._tokens: float = self.burst
    self._last_refill: float = time.monotonic()
    self._lock = asyncio.Lock()
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Take one token from the bucket before the agent runs.

    When the bucket holds less than one token the call either sleeps until
    enough has been refilled (``wait_on_limit`` true) or is rejected.

    Raises:
        RateLimitExceededError: If no token is available and
            ``wait_on_limit`` is false.
    """
    # The lock stays held while sleeping, so waiters are served one by one.
    async with self._lock:
        self._refill()
        while not (self._tokens >= 1.0):
            if not self.wait_on_limit:
                raise RateLimitExceededError(
                    f"Rate limit of {self.calls_per_second} call/s exceeded "
                    f"for agent '{context.agent_id}'."
                )

            # Time needed to refill the missing fraction of a token.
            shortfall = 1.0 - self._tokens
            await asyncio.sleep(shortfall / self.calls_per_second)
            self._refill()

        self._tokens -= 1.0
        return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
class RateLimitExceededError(Exception):
    """Signals that the rate limit was hit while ``wait_on_limit`` is False."""
Raised when the rate limit is exceeded and wait_on_limit=False.
class AuditBehavior(BaseBehavior):
    """
    Logs every agent input, output, and error using shared-core observability.

    Uses :class:`BasicLogger` for structured log lines and
    :class:`BasicMetricsCollector` to track invocation counts.

    Args:
        logger: :class:`BasicLogger` instance. A default named
            ``"gmf_forge_ai.audit"`` is created if not provided.
        metrics: Optional :class:`BasicMetricsCollector` for counting events.

    Example::

        from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector

        behavior = AuditBehavior(
            logger=BasicLogger("my_app.audit"),
            metrics=BasicMetricsCollector(),
        )
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
    ) -> None:
        self._logger = logger or BasicLogger("gmf_forge_ai.audit")
        self._metrics = metrics
        # Start timestamps keyed by "agent_id:task", used to compute duration.
        # NOTE(review): two overlapping runs of the same agent and task share
        # a key, so the later start overwrites the earlier one.
        self._start_times: dict = {}

    def _pop_duration_ms(self, context: BehaviorContext) -> int:
        """Return the milliseconds since ``before_execute`` ran and forget the start time."""
        key = f"{context.agent_id}:{context.task}"
        now = time.monotonic()
        # Fall back to `now` (duration 0) when no start was recorded.
        started = self._start_times.pop(key, now)
        return round((now - started) * 1000)

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Record the start time and log that the task started."""
        key = f"{context.agent_id}:{context.task}"
        self._start_times[key] = time.monotonic()
        self._logger.info(
            "Agent task started",
            agent_id=context.agent_id,
            task=context.task,
            attempt=context.attempt,
        )
        if self._metrics:
            self._metrics.increment("behavior.audit.invocations", agent_id=context.agent_id)
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """Log completion with duration and an output preview; return ``result`` unchanged."""
        duration_ms = self._pop_duration_ms(context)
        raw_output = result.output if hasattr(result, "output") else result
        self._logger.info(
            "Agent task completed",
            agent_id=context.agent_id,
            task=context.task,
            duration_ms=duration_ms,
            success=True,
            # Coerce before slicing: `output` may be None or another
            # non-string value, which cannot be sliced directly.
            output_preview=str(raw_output)[:200],
        )
        if self._metrics:
            self._metrics.histogram("behavior.audit.duration_ms", duration_ms, agent_id=context.agent_id)
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """Log the failure with its duration; return ``None`` (no fallback value)."""
        duration_ms = self._pop_duration_ms(context)
        self._logger.error(
            "Agent task failed",
            agent_id=context.agent_id,
            task=context.task,
            duration_ms=duration_ms,
            error=str(error),
            error_type=type(error).__name__,
        )
        if self._metrics:
            self._metrics.increment("behavior.audit.errors", agent_id=context.agent_id)
        return None
Logs every agent input, output, and error using shared-core observability.
Uses BasicLogger for structured log lines and
BasicMetricsCollector to track invocation counts.
Args:
logger: BasicLogger instance. A default named
"gmf_forge_ai.audit" is created if not provided.
metrics: Optional BasicMetricsCollector for counting events.
Example::
from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector
behavior = AuditBehavior(
logger=BasicLogger("my_app.audit"),
metrics=BasicMetricsCollector(),
)
def __init__(
    self,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
) -> None:
    """Set up the audit logger, optional metrics, and the start-time table.

    Args:
        logger: Logger to write audit lines to. When falsy, one named
            ``"gmf_forge_ai.audit"`` is created.
        metrics: Optional metrics collector; used only when truthy.
    """
    # Start timestamps keyed by "agent_id:task", used to compute duration.
    self._start_times: dict = {}
    self._metrics = metrics
    self._logger = logger or BasicLogger("gmf_forge_ai.audit")
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Stamp the start time for this agent/task and log the start.

    Returns ``context`` unchanged.
    """
    agent_id = context.agent_id
    task = context.task

    # Remembered so the completion/failure hooks can report a duration.
    self._start_times[f"{agent_id}:{task}"] = time.monotonic()

    self._logger.info(
        "Agent task started",
        agent_id=agent_id,
        task=task,
        attempt=context.attempt,
    )
    if self._metrics:
        self._metrics.increment("behavior.audit.invocations", agent_id=agent_id)
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """Log a successful completion and return ``result`` unchanged.

    The log line carries the duration since ``before_execute`` (0 when no
    start time was recorded) and the first 200 characters of the output.
    The output is ``result.output`` when that attribute exists, otherwise
    ``result`` itself; either is converted with ``str`` before truncation.
    """
    key = f"{context.agent_id}:{context.task}"
    finished = time.monotonic()
    started = self._start_times.pop(key, finished)
    duration_ms = round((finished - started) * 1000)

    raw_output = result.output if hasattr(result, "output") else result
    self._logger.info(
        "Agent task completed",
        agent_id=context.agent_id,
        task=context.task,
        duration_ms=duration_ms,
        success=True,
        # Coerce before slicing: `output` may be None or another non-string
        # value, which cannot be sliced directly.
        output_preview=str(raw_output)[:200],
    )
    if self._metrics:
        self._metrics.histogram("behavior.audit.duration_ms", duration_ms, agent_id=context.agent_id)
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """Log a failed task with its duration and error details.

    Always returns ``None``, i.e. this behavior supplies no fallback value.
    """
    agent_id = context.agent_id
    task = context.task

    now = time.monotonic()
    # A missing start time falls back to a fresh timestamp (duration ~0).
    started = self._start_times.pop(f"{agent_id}:{task}", time.monotonic())
    duration_ms = round((now - started) * 1000)

    self._logger.error(
        "Agent task failed",
        agent_id=agent_id,
        task=task,
        duration_ms=duration_ms,
        error=str(error),
        error_type=type(error).__name__,
    )
    if self._metrics:
        self._metrics.increment("behavior.audit.errors", agent_id=agent_id)
    return None
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class AgentDiscoveryBehavior:
    """Refreshes supervisor agent registry on a fixed interval.

    This behavior runs as an asyncio background task. Each refresh fetches the
    latest ``(agents, descriptions)`` from ``discovery_fn`` and updates the
    provided ``SupervisorOrchestrator`` instance when the set of agent ids or
    the descriptions changed.

    Args:
        discovery_fn: Awaitable callable returning ``(agents, descriptions)``.
        interval_seconds: Seconds between refreshes; must be greater than 0.
        logger: Optional logger; a default one is created when omitted.

    Raises:
        ValueError: If ``interval_seconds`` is not greater than 0.
    """

    def __init__(
        self,
        discovery_fn: DiscoveryFn,
        interval_seconds: float = 30.0,
        logger: Optional[BasicLogger] = None,
    ) -> None:
        if interval_seconds <= 0:
            raise ValueError("interval_seconds must be greater than 0")

        self.discovery_fn = discovery_fn
        self.interval_seconds = interval_seconds
        self._logger = logger or BasicLogger("gmf_forge_ai.behavior.AgentDiscoveryBehavior")

        self._task: Optional[asyncio.Task] = None
        self._supervisor: Optional[SupervisorOrchestrator] = None

    async def refresh_once(self) -> None:
        """Run one discovery cycle and update the supervisor in-place.

        Raises:
            RuntimeError: If no supervisor has been attached via :meth:`start`.
        """
        if self._supervisor is None:
            raise RuntimeError("AgentDiscoveryBehavior.start() must be called first")

        new_agents, new_descriptions = await self.discovery_fn()
        supervisor = self._supervisor

        current_agents = set(supervisor.agents.keys())
        discovered_agents = set(new_agents.keys())
        added = sorted(discovered_agents - current_agents)
        removed = sorted(current_agents - discovered_agents)
        # Descriptions can change while the set of agent ids stays the same;
        # that must still be propagated or the supervisor keeps stale text.
        descriptions_changed = (
            getattr(supervisor, "agent_descriptions", None) != new_descriptions
        )

        if not added and not removed and not descriptions_changed:
            return

        supervisor.agents = new_agents
        supervisor.agent_descriptions = new_descriptions

        router = supervisor.router
        if router is not None and hasattr(router, "agent_descriptions"):
            router.agent_descriptions = new_descriptions
        if router is not None and hasattr(router, "fallback_target"):
            if new_agents:
                # Re-point the fallback only when its target disappeared.
                fallback = getattr(router, "fallback_target", None)
                if fallback not in new_agents:
                    router.fallback_target = next(iter(new_agents))
            else:
                router.fallback_target = None

        self._logger.info(
            "Agent registry updated via periodic discovery",
            added=added,
            removed=removed,
            total_agents=len(new_agents),
        )

    async def start(self, supervisor: SupervisorOrchestrator) -> None:
        """Start periodic discovery for the given supervisor.

        Does nothing when a discovery task is already running.
        """
        if self._task is not None and not self._task.done():
            return

        self._supervisor = supervisor

        async def _runner() -> None:
            while True:
                try:
                    await asyncio.sleep(self.interval_seconds)
                    await self.refresh_once()
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    # Keep polling: one failed discovery must not end the loop.
                    self._logger.warning(
                        "Error during periodic agent discovery",
                        error=str(exc),
                    )

        self._task = asyncio.create_task(_runner())

    async def stop(self) -> None:
        """Stop periodic discovery and wait for task cancellation."""
        if self._task is None:
            return

        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        finally:
            self._task = None
Refreshes supervisor agent registry on a fixed interval.
This behavior runs as an asyncio background task. Each refresh fetches the
latest (agents, descriptions) from discovery_fn and atomically
updates the provided SupervisorOrchestrator instance.
def __init__(
    self,
    discovery_fn: DiscoveryFn,
    interval_seconds: float = 30.0,
    logger: Optional[BasicLogger] = None,
) -> None:
    """Store the discovery callable and polling interval.

    Args:
        discovery_fn: Callable awaited on each refresh.
        interval_seconds: Seconds between refreshes.
        logger: Logger to use; a default one is created when falsy.

    Raises:
        ValueError: If ``interval_seconds`` is zero or negative.
    """
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be greater than 0")

    # Nothing is running or attached until start() is called.
    self._supervisor: Optional[SupervisorOrchestrator] = None
    self._task: Optional[asyncio.Task] = None

    self._logger = logger or BasicLogger("gmf_forge_ai.behavior.AgentDiscoveryBehavior")
    self.interval_seconds = interval_seconds
    self.discovery_fn = discovery_fn
async def refresh_once(self) -> None:
    """Run one discovery cycle and update the supervisor in-place.

    The supervisor (and its router, where it exposes the matching
    attributes) is updated when the set of agent ids or the descriptions
    differ from what the supervisor currently holds.

    Raises:
        RuntimeError: If no supervisor has been attached yet.
    """
    if self._supervisor is None:
        raise RuntimeError("AgentDiscoveryBehavior.start() must be called first")

    new_agents, new_descriptions = await self.discovery_fn()
    supervisor = self._supervisor

    current_agents = set(supervisor.agents.keys())
    discovered_agents = set(new_agents.keys())
    added = sorted(discovered_agents - current_agents)
    removed = sorted(current_agents - discovered_agents)
    # Descriptions can change while the set of agent ids stays the same;
    # that must still be propagated or the supervisor keeps stale text.
    descriptions_changed = (
        getattr(supervisor, "agent_descriptions", None) != new_descriptions
    )

    if not added and not removed and not descriptions_changed:
        return

    supervisor.agents = new_agents
    supervisor.agent_descriptions = new_descriptions

    router = supervisor.router
    if router is not None and hasattr(router, "agent_descriptions"):
        router.agent_descriptions = new_descriptions
    if router is not None and hasattr(router, "fallback_target"):
        if new_agents:
            # Re-point the fallback only when its target disappeared.
            fallback = getattr(router, "fallback_target", None)
            if fallback not in new_agents:
                router.fallback_target = next(iter(new_agents))
        else:
            router.fallback_target = None

    self._logger.info(
        "Agent registry updated via periodic discovery",
        added=added,
        removed=removed,
        total_agents=len(new_agents),
    )
Run one discovery cycle and update the supervisor in-place.
async def start(self, supervisor: SupervisorOrchestrator) -> None:
    """Attach ``supervisor`` and launch the background polling task.

    Returns without doing anything when a polling task is already running.
    """
    existing = self._task
    if existing is not None and not existing.done():
        return

    self._supervisor = supervisor

    async def _runner() -> None:
        # Sleep first, then refresh, forever; only cancellation ends the loop.
        while True:
            try:
                await asyncio.sleep(self.interval_seconds)
                await self.refresh_once()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                self._logger.warning(
                    "Error during periodic agent discovery",
                    error=str(exc),
                )

    self._task = asyncio.create_task(_runner())
Start periodic discovery for the given supervisor.
async def stop(self) -> None:
    """Cancel the polling task, wait for it to finish, and forget it.

    Does nothing when no task was ever started.
    """
    task = self._task
    if task is None:
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        pass
    finally:
        self._task = None
Stop periodic discovery and wait for task cancellation.
@dataclass
class AgentStep:
    """A single thought/action/observation cycle recorded during an agent run."""

    # Required fields.
    thought: str
    action: str
    # Optional fields; each instance gets its own fresh dict.
    action_input: Dict[str, Any] = field(default_factory=dict)
    observation: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
One thought/action/observation cycle within an agent execution.
@dataclass
class AgentResult:
    """Outcome of one agent execution: the output plus how it was reached."""

    # Required field.
    output: str
    # Optional fields; list/dict defaults are created per instance.
    steps: List[AgentStep] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    success: bool = True
    error: Optional[str] = None
The final result of an agent execution.
class BaseAgent(ABC):
    """
    Abstract base class for all agents.

    Holds the collaborators an agent works with:

    - an LLM gateway for completions
    - a tool registry for tool discovery and execution
    - an ordered list of behaviors applied around every execution
    - a state store and a checkpoint manager for persistence
    - logger, metrics collector, performance monitor and tracer

    Args:
        llm_gateway: The LLM gateway to use for completions. Defaults to
            ``None`` in the signature.
        tool_registry: Optional tool registry. Provides available tools to the agent.
        behaviors: Ordered list of behaviors applied around each execution.
        state_store: Optional state store for persisting steps and conversation.
        checkpoint_manager: Optional checkpoint manager; stored as given.
        agent_id: Stable identifier used in logging and metrics. Defaults to class name.
        logger: Optional :class:`BasicLogger`. Created automatically if omitted.
        metrics: Optional :class:`BasicMetricsCollector`.
        performance_monitor: Optional :class:`BasicPerformanceMonitor`.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        llm_gateway: Optional["UnifiedLLMGateway"] = None,
        tool_registry: Optional["ToolRegistry"] = None,
        behaviors: Optional[List["BaseBehavior"]] = None,
        state_store: Optional["BaseStateStore"] = None,
        checkpoint_manager: Optional["CheckpointManager"] = None,
        agent_id: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        performance_monitor: Optional[BasicPerformanceMonitor] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        # Identity first: the default logger name is derived from it.
        self.agent_id = agent_id or self.__class__.__name__

        # Core collaborators.
        self.llm_gateway = llm_gateway
        self.tool_registry = tool_registry
        self.behaviors: List["BaseBehavior"] = behaviors or []

        # Persistence.
        self.state_store = state_store
        self.checkpoint_manager = checkpoint_manager

        # Observability.
        self._logger = logger or BasicLogger(f"gmf_forge_ai.agent.{self.agent_id}")
        self._metrics = metrics
        self._performance_monitor = performance_monitor
        self._tracer = tracer or get_tracer()

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------

    @abstractmethod
    async def execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """Run ``task`` to completion and return its result."""

    @abstractmethod
    async def stream_execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AsyncIterator[AgentStep]:
        """Run ``task``, yielding each step as it completes."""

    # ------------------------------------------------------------------
    # Behavior pipeline helpers (called by subclasses)
    # ------------------------------------------------------------------

    async def _apply_behaviors_before(
        self, context: "BehaviorContext"
    ) -> "BehaviorContext":
        """Thread ``context`` through every behavior's ``before_execute`` in order."""
        current = context
        for hook in self.behaviors:
            current = await hook.before_execute(current)
        return current

    async def _apply_behaviors_after(
        self, context: "BehaviorContext", result: AgentResult
    ) -> AgentResult:
        """Thread ``result`` through every behavior's ``after_execute`` in order."""
        current = result
        for hook in self.behaviors:
            current = await hook.after_execute(context, current)
        return current

    async def _apply_behaviors_on_error(
        self, context: "BehaviorContext", error: Exception
    ) -> Optional[AgentResult]:
        """Offer ``error`` to each behavior; return the first non-``None`` fallback.

        Behaviors after the one that supplied a fallback are not consulted.
        Returns ``None`` when no behavior supplies one.
        """
        for hook in self.behaviors:
            recovered = await hook.on_error(context, error)
            if recovered is None:
                continue
            return recovered
        return None

    # ------------------------------------------------------------------
    # Shared observability helpers (called by subclasses)
    # ------------------------------------------------------------------

    def _log_execution_start(self, task: str) -> None:
        """Log the start of an execution and count it when metrics are set."""
        self._logger.info(
            "Agent execution started", agent_id=self.agent_id, task=task
        )
        if not self._metrics:
            return
        self._metrics.increment("agent.executions", agent_id=self.agent_id)

    def _log_execution_end(self, task: str, success: bool, steps: int) -> None:
        """Log the end of an execution and record the step count when metrics are set."""
        self._logger.info(
            "Agent execution finished",
            agent_id=self.agent_id,
            task=task,
            success=success,
            steps=steps,
        )
        if not self._metrics:
            return
        self._metrics.histogram("agent.steps", steps, agent_id=self.agent_id)

    def _log_execution_error(self, task: str, error: Exception) -> None:
        """Log an execution error and count it when metrics are set."""
        self._logger.error(
            "Agent execution error",
            agent_id=self.agent_id,
            task=task,
            error=str(error),
            error_type=type(error).__name__,
        )
        if not self._metrics:
            return
        self._metrics.increment("agent.errors", agent_id=self.agent_id)
Abstract base class for all agents.
Wires together:
- UnifiedLLMGateway for LLM calls
- ToolRegistry for tool discovery and execution
- BaseBehavior pipeline applied around every execution
- BaseStateStore for persisting conversation/step state
- Full shared-core observability stack (Logger, Metrics, PerformanceMonitor, Tracing)
Args:
llm_gateway: The LLM gateway to use for completions. The constructor accepts None, but an agent that issues LLM calls needs one.
tool_registry: Optional tool registry. Provides available tools to the agent.
behaviors: Ordered list of behaviors applied around each execution.
state_store: Optional state store for persisting steps and conversation.
agent_id: Stable identifier used in logging and metrics. Defaults to class name.
logger: Optional BasicLogger. Created automatically if omitted.
metrics: Optional BasicMetricsCollector.
performance_monitor: Optional BasicPerformanceMonitor.
tracer: Optional TracingProvider. Falls back to get_tracer().
@abstractmethod
async def execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """Run ``task`` to completion and return its result.

    Args:
        task: The task to execute.
        context: Optional extra key/value data for this run.
    """
Execute the task and return a result.
@abstractmethod
async def stream_execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AsyncIterator[AgentStep]:
    """Run ``task``, yielding each step as it completes.

    Args:
        task: The task to execute.
        context: Optional extra key/value data for this run.
    """
Execute the task, yielding each step as it completes.
46class ReActAgent(BaseAgent): 47 """ 48 ReAct agent: interleaves Reasoning (Thought) and Acting (Action/Observation). 49 50 **When to use:** Any task that requires tool calls (search, API lookup, 51 database queries). The LLM reasons step by step and acts on each step. 52 53 **When NOT to use:** Pure reasoning/analysis tasks with no tools — use 54 :class:`ChainOfThoughtAgent` instead (cheaper: one LLM call vs. many). 55 56 On each step the LLM produces a Thought, an Action (tool name or 57 ``"Final Answer"``), and an Action Input. If the action is a tool call, 58 the tool's output is fed back as an Observation and the loop continues. 59 60 .. warning:: 61 The system prompt MUST instruct the LLM to respond using the 62 ``Thought:/Action:/Action Input:`` format. Without this contract the 63 regex parser will not match, the agent will treat the first response 64 as a Final Answer, and **no tool calls will ever be made**. 65 Always include ``{tool_descriptions}`` in a custom prompt so the LLM 66 knows what tools are available. 67 68 Args: 69 max_steps: Maximum thought/action cycles before stopping (default: 10). 70 model: LLM model name passed to the gateway (optional). 71 temperature: Sampling temperature (default: 0.0 for determinism). 72 system_prompt: Override the default ReAct system prompt. Use 73 ``{tool_descriptions}`` as a placeholder if you want the agent's 74 available tools listed in your prompt. The task is always appended 75 separately and does not need a placeholder here. 76 77 All other args inherited from :class:`BaseAgent`. 78 """ 79 80 def __init__(self, *args: Any, max_steps: int = 10, model: Optional[str] = None, 81 temperature: float = 0.0, system_prompt: Optional[str] = None, 82 **kwargs: Any) -> None: 83 super().__init__(*args, **kwargs) 84 self.max_steps = max_steps 85 self.model = model 86 self.temperature = temperature 87 self._system_prompt = system_prompt 88 89 #: Default ReAct system prompt used when no ``system_prompt`` is passed. 
90 #: Inspect this to understand the expected format or use it as a base 91 #: for your own customisations. 92 DEFAULT_SYSTEM_PROMPT: str = _REACT_SYSTEM 93 94 # ------------------------------------------------------------------ 95 # Internal prompt helpers 96 # ------------------------------------------------------------------ 97 98 def _tool_descriptions(self) -> str: 99 if not self.tool_registry: 100 return "No tools available." 101 tools = self.tool_registry.list_tools() 102 if not tools: 103 return "No tools available." 104 lines = [] 105 for t in tools: 106 lines.append(f"- {t.name}: {t.description}") 107 return "\n".join(lines) 108 109 def _build_prompt(self, task: str, history: List[AgentStep]) -> str: 110 template = self._system_prompt if self._system_prompt is not None else _REACT_SYSTEM 111 if "{tool_descriptions}" in template: 112 system = template.format(tool_descriptions=self._tool_descriptions()) 113 else: 114 system = template 115 turns = [f"Task: {task}\n"] 116 for step in history: 117 turns.append(f"Thought: {step.thought}") 118 turns.append(f"Action: {step.action}") 119 turns.append(f"Action Input: {json.dumps(step.action_input)}") 120 if step.observation: 121 turns.append(f"Observation: {step.observation}") 122 return system + "\n" + "\n".join(turns) 123 124 def _parse_step(self, text: str) -> Optional[AgentStep]: 125 def _coerce_action_input(raw_input: str) -> Dict[str, Any]: 126 try: 127 action_input = json.loads(raw_input) 128 if isinstance(action_input, dict): 129 return action_input 130 except (json.JSONDecodeError, ValueError): 131 pass 132 return {"raw": raw_input} 133 134 match = _REACT_STEP_PATTERN.search(text) 135 if match: 136 thought = match.group("thought").strip() 137 action = match.group("action").strip() 138 raw_input = match.group("action_input").strip() 139 action_input = _coerce_action_input(raw_input) 140 return AgentStep(thought=thought, action=action, action_input=action_input) 141 142 # Some models omit "Thought:" and 
return only Action/Action Input. 143 final_only = _REACT_FINAL_ONLY_PATTERN.search(text) 144 if final_only: 145 raw_input = final_only.group("action_input").strip() 146 action_input = _coerce_action_input(raw_input) 147 return AgentStep(thought="", action="Final Answer", action_input=action_input) 148 149 return None 150 151 def _extract_final_answer(self, step: AgentStep) -> str: 152 """Return normalized final answer text from a Final Answer step.""" 153 answer_obj: Any 154 if isinstance(step.action_input, dict) and "answer" in step.action_input: 155 answer_obj = step.action_input["answer"] 156 else: 157 answer_obj = step.action_input 158 159 answer = str(answer_obj).strip() 160 161 # Unwrap nested ReAct control text accidentally returned as answer body. 162 for _ in range(3): 163 nested = _REACT_FINAL_ONLY_PATTERN.search(answer) 164 if nested: 165 raw_input = nested.group("action_input").strip() 166 elif answer.startswith("Final Answer:"): 167 raw_input = answer[len("Final Answer:"):].strip() 168 else: 169 break 170 try: 171 parsed = json.loads(raw_input) 172 if isinstance(parsed, dict) and "answer" in parsed: 173 answer = str(parsed["answer"]).strip() 174 else: 175 answer = raw_input 176 except (json.JSONDecodeError, ValueError): 177 answer = raw_input 178 179 return answer.strip() 180 181 async def _save_checkpoint( 182 self, 183 execution_id: str, 184 task: str, 185 context: Dict[str, Any], 186 steps: List[AgentStep], 187 step_number: int, 188 ) -> None: 189 if not self.checkpoint_manager: 190 return 191 await self.checkpoint_manager.save( 192 agent_id=self.agent_id, 193 execution_id=execution_id, 194 state={ 195 "execution_id": execution_id, 196 "task": task, 197 "context": context, 198 "step_number": step_number, 199 "steps": [dataclasses.asdict(s) for s in steps], 200 }, 201 metadata={"step_number": step_number}, 202 ) 203 204 async def resume_from( 205 self, execution_id: str, from_step: Optional[int] = None 206 ) -> AgentResult: 207 """Resume a ReAct 
execution from its latest (or selected) checkpoint.""" 208 if not self.checkpoint_manager: 209 raise ValueError("Checkpoint manager is not configured for this agent") 210 211 checkpoint = None 212 if from_step is not None: 213 checkpoints = await self.checkpoint_manager.list_by_execution(execution_id) 214 for ckpt in checkpoints: 215 if ckpt.state.get("step_number") == from_step: 216 checkpoint = ckpt 217 break 218 else: 219 checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id) 220 221 if checkpoint is None: 222 raise ValueError(f"No checkpoint found for execution_id={execution_id}") 223 224 state = checkpoint.state 225 task = str(state.get("task", "")) 226 context = state.get("context") or {} 227 raw_steps = state.get("steps") or [] 228 steps = [AgentStep(**raw) for raw in raw_steps if isinstance(raw, dict)] 229 230 # If the checkpoint already contains a final answer, return idempotently. 231 if steps and steps[-1].action == "Final Answer": 232 answer = self._extract_final_answer(steps[-1]) 233 return AgentResult( 234 output=str(answer), 235 steps=steps, 236 success=True, 237 metadata={"execution_id": execution_id}, 238 ) 239 240 ctx = BehaviorContext( 241 agent_id=self.agent_id, 242 task=task, 243 execution_id=execution_id, 244 last_completed_step=int(state.get("step_number", len(steps) - 1)), 245 metadata={"resume_agent": self}, 246 ) 247 ctx = await self._apply_behaviors_before(ctx) 248 self._log_execution_start(task) 249 250 for step_num in range(len(steps), self.max_steps): 251 prompt = self._build_prompt(task, steps) 252 response = await self.llm_gateway.complete( 253 prompt, model=self.model, temperature=self.temperature 254 ) 255 step = self._parse_step(response.content) 256 if step is None: 257 step = AgentStep( 258 thought="Could not parse structured response.", 259 action="Final Answer", 260 action_input={"answer": self._extract_final_answer(AgentStep(thought="", action="Final Answer", action_input={"answer": 
response.content}))}, 261 ) 262 263 if step.action == "Final Answer": 264 answer = self._extract_final_answer(step) 265 step.observation = answer 266 steps.append(step) 267 await self._save_checkpoint(execution_id, task, context, steps, step_num) 268 ctx.last_completed_step = step_num 269 result = AgentResult( 270 output=answer, 271 steps=steps, 272 success=True, 273 metadata={"execution_id": execution_id}, 274 ) 275 result = await self._apply_behaviors_after(ctx, result) 276 self._log_execution_end(task, success=True, steps=len(steps)) 277 return result 278 279 observation = "" 280 if self.tool_registry: 281 try: 282 tool_result = await self.tool_registry.execute(step.action, **step.action_input) 283 observation = str(tool_result) 284 except Exception as tool_exc: 285 observation = f"Tool error: {tool_exc}" 286 else: 287 observation = f"No tool registry — cannot execute '{step.action}'" 288 289 step.observation = observation 290 steps.append(step) 291 await self._save_checkpoint(execution_id, task, context, steps, step_num) 292 ctx.last_completed_step = step_num 293 294 final_output = steps[-1].observation if steps else "No output" 295 result = AgentResult( 296 output=final_output, 297 steps=steps, 298 success=False, 299 error=f"Reached max_steps={self.max_steps} without Final Answer", 300 metadata={"execution_id": execution_id}, 301 ) 302 result = await self._apply_behaviors_after(ctx, result) 303 self._log_execution_end(task, success=False, steps=len(steps)) 304 return result 305 306 # ------------------------------------------------------------------ 307 # BaseAgent implementation 308 # ------------------------------------------------------------------ 309 310 async def execute( 311 self, task: str, context: Optional[Dict[str, Any]] = None 312 ) -> AgentResult: 313 execution_id = str(uuid4()) 314 run_context = context or {} 315 ctx = BehaviorContext( 316 agent_id=self.agent_id, 317 task=task, 318 execution_id=execution_id, 319 metadata={"resume_agent": self, 
**run_context}, 320 ) 321 ctx = await self._apply_behaviors_before(ctx) 322 self._log_execution_start(task) 323 324 while True: 325 _retry = False 326 _control_exit = None # HumanApprovalRequired or PendingApproval — not errors 327 steps: List[AgentStep] = [] 328 329 with self._tracer.trace( 330 "react_agent.execute", input=task, metadata={"agent_id": self.agent_id} 331 ) as trace: 332 try: 333 for _ in range(self.max_steps): 334 prompt = self._build_prompt(task, steps) 335 336 with trace.generation( 337 "llm_call", model=self.model or "default", input=prompt 338 ) as gen: 339 perf_id = None 340 if self._performance_monitor: 341 perf_id = self._performance_monitor.start_request( 342 provider="llm_gateway", model=self.model or "default" 343 ) 344 try: 345 response = await self.llm_gateway.complete( 346 prompt, model=self.model, temperature=self.temperature 347 ) 348 if self._performance_monitor and perf_id: 349 self._performance_monitor.end_request( 350 request_id=perf_id, 351 prompt_tokens=response.usage.get("prompt_tokens", 0), 352 completion_tokens=response.usage.get("completion_tokens", 0), 353 success=True, 354 ) 355 gen.set_output(response.content) 356 gen.set_token_usage(**response.usage) 357 except Exception as exc: 358 if self._performance_monitor and perf_id: 359 self._performance_monitor.end_request( 360 request_id=perf_id, 361 prompt_tokens=0, 362 completion_tokens=0, 363 success=False, 364 error=str(exc), 365 ) 366 raise 367 368 step = self._parse_step(response.content) 369 if step is None: 370 # Unparseable response — treat as final answer 371 step = AgentStep( 372 thought="Could not parse structured response.", 373 action="Final Answer", 374 action_input={"answer": self._extract_final_answer(AgentStep(thought="", action="Final Answer", action_input={"answer": response.content}))}, 375 ) 376 377 with trace.span("step", input=step.action) as span: 378 if step.action == "Final Answer": 379 answer = self._extract_final_answer(step) 380 step.observation = 
answer 381 steps.append(step) 382 span.set_output(answer) 383 384 # Persist to state store if available 385 if self.state_store: 386 await self.state_store.set( 387 f"agent:{self.agent_id}:last_steps", 388 [vars(s) for s in steps], 389 ttl=self.checkpoint_manager.default_ttl if self.checkpoint_manager else None, 390 ) 391 392 await self._save_checkpoint( 393 execution_id, 394 task, 395 run_context, 396 steps, 397 len(steps) - 1, 398 ) 399 ctx.last_completed_step = len(steps) - 1 400 401 result = AgentResult( 402 output=answer, 403 steps=steps, 404 success=True, 405 metadata={"execution_id": execution_id}, 406 ) 407 try: 408 result = await self._apply_behaviors_after(ctx, result) 409 except (HumanApprovalRequired, PendingApproval) as ctrl_exc: 410 _control_exit = ctrl_exc 411 412 if _control_exit is None: 413 self._log_execution_end(task, success=True, steps=len(steps)) 414 trace.set_output(answer) 415 return result 416 else: 417 # Control-flow exit (HumanApprovalRequired / PendingApproval) 418 # span output already set above; also set trace output so 419 # it logs the answer rather than finishing with no output. 
420 trace.set_output(answer) 421 422 else: 423 # Tool call 424 observation = "" 425 if self.tool_registry: 426 try: 427 tool_result = await self.tool_registry.execute( 428 step.action, **step.action_input 429 ) 430 observation = str(tool_result) 431 except Exception as tool_exc: 432 observation = f"Tool error: {tool_exc}" 433 else: 434 observation = f"No tool registry — cannot execute '{step.action}'" 435 436 step.observation = observation 437 span.set_output(observation) 438 steps.append(step) 439 await self._save_checkpoint( 440 execution_id, 441 task, 442 run_context, 443 steps, 444 len(steps) - 1, 445 ) 446 ctx.last_completed_step = len(steps) - 1 447 448 if _control_exit is not None: 449 break # exit for loop; span already closed cleanly 450 451 # Max steps reached (or for loop broken by control-flow signal) 452 if _control_exit is None: 453 final_output = steps[-1].observation if steps else "No output" 454 result = AgentResult( 455 output=final_output, 456 steps=steps, 457 success=False, 458 error=f"Reached max_steps={self.max_steps} without Final Answer", 459 metadata={"execution_id": execution_id}, 460 ) 461 result = await self._apply_behaviors_after(ctx, result) 462 self._log_execution_end(task, success=False, steps=len(steps)) 463 trace.set_output(final_output) 464 return result 465 # else: control-flow exit — trace closes cleanly, re-raised below 466 467 except Exception as exc: 468 self._log_execution_error(task, exc) 469 fallback = await self._apply_behaviors_on_error(ctx, exc) 470 if fallback is RETRY_SENTINEL: 471 # RetryBehavior slept and wants us to retry from scratch. 472 # Mark the trace as failed so it logs "failed" not "finished". 
473 trace.set_error(exc) 474 ctx.attempt += 1 475 _retry = True 476 elif fallback is not None: 477 return fallback 478 else: 479 raise 480 481 if _control_exit is not None: 482 raise _control_exit 483 484 if not _retry: 485 break 486 487 async def stream_execute( 488 self, task: str, context: Optional[Dict[str, Any]] = None 489 ) -> AsyncIterator[AgentStep]: 490 ctx = BehaviorContext(agent_id=self.agent_id, task=task) 491 ctx = await self._apply_behaviors_before(ctx) 492 self._log_execution_start(task) 493 494 steps: List[AgentStep] = [] 495 for _ in range(self.max_steps): 496 prompt = self._build_prompt(task, steps) 497 response = await self.llm_gateway.complete( 498 prompt, model=self.model, temperature=self.temperature 499 ) 500 step = self._parse_step(response.content) 501 if step is None: 502 step = AgentStep( 503 thought="Unparseable response.", 504 action="Final Answer", 505 action_input={"answer": response.content}, 506 ) 507 508 if step.action == "Final Answer": 509 step.observation = self._extract_final_answer(step) 510 steps.append(step) 511 yield step 512 return 513 514 if self.tool_registry: 515 try: 516 tool_result = await self.tool_registry.execute(step.action, **step.action_input) 517 step.observation = str(tool_result) 518 except Exception as tool_exc: 519 step.observation = f"Tool error: {tool_exc}" 520 else: 521 step.observation = f"No tool registry — cannot execute '{step.action}'" 522 523 steps.append(step) 524 yield step
ReAct agent: interleaves Reasoning (Thought) and Acting (Action/Observation).
When to use: Any task that requires tool calls (search, API lookup, database queries). The LLM reasons step by step and acts on each step.
When NOT to use: Pure reasoning/analysis tasks with no tools — use
ChainOfThoughtAgent instead (cheaper: one LLM call vs. many).
On each step the LLM produces a Thought, an Action (tool name or
"Final Answer"), and an Action Input. If the action is a tool call,
the tool's output is fed back as an Observation and the loop continues.
The system prompt MUST instruct the LLM to respond using the
Thought:/Action:/Action Input: format. Without this contract the
regex parser will not match, the agent will treat the first response
as a Final Answer, and no tool calls will ever be made.
Always include {tool_descriptions} in a custom prompt so the LLM
knows what tools are available.
Args:
max_steps: Maximum thought/action cycles before stopping (default: 10).
model: LLM model name passed to the gateway (optional).
temperature: Sampling temperature (default: 0.0 for determinism).
system_prompt: Override the default ReAct system prompt. Use
{tool_descriptions} as a placeholder if you want the agent's
available tools listed in your prompt. The task is always appended
separately and does not need a placeholder here.
All other args inherited from BaseAgent.
def __init__(
    self,
    *args: Any,
    max_steps: int = 10,
    model: Optional[str] = None,
    temperature: float = 0.0,
    system_prompt: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Configure the ReAct loop on top of the base agent.

    Args:
        *args: Positional arguments forwarded untouched to the base class.
        max_steps: Maximum thought/action cycles before stopping.
        model: LLM model name passed to the gateway (optional).
        temperature: Sampling temperature used for every LLM call.
        system_prompt: Replacement for the default ReAct system prompt;
            ``None`` keeps the default.
        **kwargs: Keyword arguments forwarded untouched to the base class.
    """
    # The base class consumes everything that is not ReAct-specific.
    super().__init__(*args, **kwargs)

    # Loop and sampling configuration read on every execution.
    self.max_steps = max_steps
    self.model = model
    self.temperature = temperature

    # Kept private: only consulted when the prompt is assembled.
    self._system_prompt = system_prompt
async def resume_from(
    self, execution_id: str, from_step: Optional[int] = None
) -> AgentResult:
    """Continue a checkpointed ReAct run.

    Args:
        execution_id: Identifier of the execution whose checkpoints are read.
        from_step: When given, resume from the first checkpoint whose stored
            ``step_number`` equals this value; otherwise the latest
            checkpoint of the execution is used.

    Returns:
        The result of the resumed run. If the checkpoint already ends with a
        ``Final Answer`` step, that answer is returned without calling the
        LLM or any behaviors.

    Raises:
        ValueError: If no checkpoint manager is configured, or no matching
            checkpoint exists.
    """
    manager = self.checkpoint_manager
    if not manager:
        raise ValueError("Checkpoint manager is not configured for this agent")

    # Locate the checkpoint to restart from.
    if from_step is None:
        checkpoint = await manager.load_latest_for_execution(execution_id)
    else:
        candidates = await manager.list_by_execution(execution_id)
        checkpoint = next(
            (c for c in candidates if c.state.get("step_number") == from_step),
            None,
        )

    if checkpoint is None:
        raise ValueError(f"No checkpoint found for execution_id={execution_id}")

    # Rebuild the in-memory run from the persisted state.
    saved = checkpoint.state
    task = str(saved.get("task", ""))
    context = saved.get("context") or {}
    steps = [
        AgentStep(**entry)
        for entry in (saved.get("steps") or [])
        if isinstance(entry, dict)
    ]

    # A run that already produced its final answer is returned as-is.
    last_step = steps[-1] if steps else None
    if last_step is not None and last_step.action == "Final Answer":
        return AgentResult(
            output=str(self._extract_final_answer(last_step)),
            steps=steps,
            success=True,
            metadata={"execution_id": execution_id},
        )

    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        last_completed_step=int(saved.get("step_number", len(steps) - 1)),
        metadata={"resume_agent": self},
    )
    ctx = await self._apply_behaviors_before(ctx)
    self._log_execution_start(task)

    # Step numbering continues after the steps restored from the checkpoint.
    for step_num in range(len(steps), self.max_steps):
        llm_reply = await self.llm_gateway.complete(
            self._build_prompt(task, steps),
            model=self.model,
            temperature=self.temperature,
        )
        step = self._parse_step(llm_reply.content)
        if step is None:
            # Unparseable output is wrapped as a final answer.
            fallback_answer = self._extract_final_answer(
                AgentStep(
                    thought="",
                    action="Final Answer",
                    action_input={"answer": llm_reply.content},
                )
            )
            step = AgentStep(
                thought="Could not parse structured response.",
                action="Final Answer",
                action_input={"answer": fallback_answer},
            )

        is_final = step.action == "Final Answer"
        if is_final:
            observation = self._extract_final_answer(step)
        elif not self.tool_registry:
            observation = f"No tool registry — cannot execute '{step.action}'"
        else:
            try:
                observation = str(
                    await self.tool_registry.execute(step.action, **step.action_input)
                )
            except Exception as tool_exc:
                # Tool failures become observations so the loop can continue.
                observation = f"Tool error: {tool_exc}"

        # Both kinds of step are recorded and checkpointed the same way.
        step.observation = observation
        steps.append(step)
        await self._save_checkpoint(execution_id, task, context, steps, step_num)
        ctx.last_completed_step = step_num

        if is_final:
            result = AgentResult(
                output=observation,
                steps=steps,
                success=True,
                metadata={"execution_id": execution_id},
            )
            result = await self._apply_behaviors_after(ctx, result)
            self._log_execution_end(task, success=True, steps=len(steps))
            return result

    # The step budget ran out before a final answer was produced.
    exhausted = AgentResult(
        output=steps[-1].observation if steps else "No output",
        steps=steps,
        success=False,
        error=f"Reached max_steps={self.max_steps} without Final Answer",
        metadata={"execution_id": execution_id},
    )
    exhausted = await self._apply_behaviors_after(ctx, exhausted)
    self._log_execution_end(task, success=False, steps=len(steps))
    return exhausted
Resume a ReAct execution from its latest (or selected) checkpoint.
async def execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """Run the ReAct loop for ``task`` and return the outcome.

    Each attempt asks the LLM for a step, then either finishes (action
    ``"Final Answer"``) or runs the named tool and feeds its output back as
    the observation. Every completed step is checkpointed.

    Args:
        task: The task text used to build each prompt.
        context: Optional mapping merged into the behavior context metadata
            and stored with each checkpoint.

    Returns:
        A successful result carrying the final answer, an unsuccessful one
        when ``max_steps`` is exhausted, or whatever fallback the error
        behaviors supply.

    Raises:
        HumanApprovalRequired, PendingApproval: Re-raised after the trace has
            closed when raised by the after-behaviors on the final-answer path.
        Exception: Any other error for which the error behaviors provide
            neither a fallback nor a retry signal.
    """
    # A fresh id per call; resume_from() reuses it to find the checkpoints.
    execution_id = str(uuid4())
    run_context = context or {}
    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        metadata={"resume_agent": self, **run_context},
    )
    ctx = await self._apply_behaviors_before(ctx)
    self._log_execution_start(task)

    # Each pass of this loop is one full attempt; it only repeats when the
    # error behaviors answer with RETRY_SENTINEL.
    while True:
        _retry = False
        _control_exit = None  # HumanApprovalRequired or PendingApproval — not errors
        # Steps are reset per attempt, so a retry starts from scratch.
        steps: List[AgentStep] = []

        with self._tracer.trace(
            "react_agent.execute", input=task, metadata={"agent_id": self.agent_id}
        ) as trace:
            try:
                for _ in range(self.max_steps):
                    prompt = self._build_prompt(task, steps)

                    with trace.generation(
                        "llm_call", model=self.model or "default", input=prompt
                    ) as gen:
                        # perf_id stays None when no performance monitor is set.
                        perf_id = None
                        if self._performance_monitor:
                            perf_id = self._performance_monitor.start_request(
                                provider="llm_gateway", model=self.model or "default"
                            )
                        try:
                            response = await self.llm_gateway.complete(
                                prompt, model=self.model, temperature=self.temperature
                            )
                            if self._performance_monitor and perf_id:
                                self._performance_monitor.end_request(
                                    request_id=perf_id,
                                    prompt_tokens=response.usage.get("prompt_tokens", 0),
                                    completion_tokens=response.usage.get("completion_tokens", 0),
                                    success=True,
                                )
                            gen.set_output(response.content)
                            gen.set_token_usage(**response.usage)
                        except Exception as exc:
                            # Close the performance record as failed, then let
                            # the outer handler decide on fallback or retry.
                            if self._performance_monitor and perf_id:
                                self._performance_monitor.end_request(
                                    request_id=perf_id,
                                    prompt_tokens=0,
                                    completion_tokens=0,
                                    success=False,
                                    error=str(exc),
                                )
                            raise

                    step = self._parse_step(response.content)
                    if step is None:
                        # Unparseable response — treat as final answer
                        step = AgentStep(
                            thought="Could not parse structured response.",
                            action="Final Answer",
                            action_input={"answer": self._extract_final_answer(AgentStep(thought="", action="Final Answer", action_input={"answer": response.content}))},
                        )

                    with trace.span("step", input=step.action) as span:
                        if step.action == "Final Answer":
                            answer = self._extract_final_answer(step)
                            step.observation = answer
                            steps.append(step)
                            span.set_output(answer)

                            # Persist to state store if available
                            if self.state_store:
                                await self.state_store.set(
                                    f"agent:{self.agent_id}:last_steps",
                                    [vars(s) for s in steps],
                                    ttl=self.checkpoint_manager.default_ttl if self.checkpoint_manager else None,
                                )

                            await self._save_checkpoint(
                                execution_id,
                                task,
                                run_context,
                                steps,
                                len(steps) - 1,
                            )
                            ctx.last_completed_step = len(steps) - 1

                            result = AgentResult(
                                output=answer,
                                steps=steps,
                                success=True,
                                metadata={"execution_id": execution_id},
                            )
                            # Approval signals are captured rather than raised
                            # here so the span and trace can close normally.
                            try:
                                result = await self._apply_behaviors_after(ctx, result)
                            except (HumanApprovalRequired, PendingApproval) as ctrl_exc:
                                _control_exit = ctrl_exc

                            if _control_exit is None:
                                self._log_execution_end(task, success=True, steps=len(steps))
                                trace.set_output(answer)
                                return result
                            else:
                                # Control-flow exit (HumanApprovalRequired / PendingApproval)
                                # span output already set above; also set trace output so
                                # it logs the answer rather than finishing with no output.
                                trace.set_output(answer)

                        else:
                            # Tool call
                            observation = ""
                            if self.tool_registry:
                                try:
                                    tool_result = await self.tool_registry.execute(
                                        step.action, **step.action_input
                                    )
                                    observation = str(tool_result)
                                except Exception as tool_exc:
                                    # A failing tool does not abort the run; the
                                    # error text becomes the observation.
                                    observation = f"Tool error: {tool_exc}"
                            else:
                                observation = f"No tool registry — cannot execute '{step.action}'"

                            step.observation = observation
                            span.set_output(observation)
                            steps.append(step)
                            await self._save_checkpoint(
                                execution_id,
                                task,
                                run_context,
                                steps,
                                len(steps) - 1,
                            )
                            ctx.last_completed_step = len(steps) - 1

                    if _control_exit is not None:
                        break  # exit for loop; span already closed cleanly

                # Max steps reached (or for loop broken by control-flow signal)
                if _control_exit is None:
                    final_output = steps[-1].observation if steps else "No output"
                    result = AgentResult(
                        output=final_output,
                        steps=steps,
                        success=False,
                        error=f"Reached max_steps={self.max_steps} without Final Answer",
                        metadata={"execution_id": execution_id},
                    )
                    # NOTE(review): unlike the final-answer path, this call is not
                    # guarded for HumanApprovalRequired / PendingApproval, so such
                    # a signal would reach the generic handler below — confirm
                    # that this is intended.
                    result = await self._apply_behaviors_after(ctx, result)
                    self._log_execution_end(task, success=False, steps=len(steps))
                    trace.set_output(final_output)
                    return result
                # else: control-flow exit — trace closes cleanly, re-raised below

            except Exception as exc:
                self._log_execution_error(task, exc)
                fallback = await self._apply_behaviors_on_error(ctx, exc)
                if fallback is RETRY_SENTINEL:
                    # RetryBehavior slept and wants us to retry from scratch.
                    # Mark the trace as failed so it logs "failed" not "finished".
                    trace.set_error(exc)
                    ctx.attempt += 1
                    _retry = True
                elif fallback is not None:
                    return fallback
                else:
                    raise

        # Raised only after the trace context has exited.
        if _control_exit is not None:
            raise _control_exit

        if not _retry:
            break
Execute the task and return a result.
async def stream_execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AsyncIterator[AgentStep]:
    """Run the ReAct loop for ``task``, yielding each step as it completes.

    Unlike ``execute``, this path does no tracing or checkpointing and does
    not apply after/error behaviors; it only streams the steps.

    Args:
        task: The task text used to build each prompt.
        context: Accepted for interface compatibility; not used here.

    Yields:
        Each completed step with its observation filled in. The stream ends
        after a ``Final Answer`` step or once ``max_steps`` steps were made.
    """
    ctx = BehaviorContext(agent_id=self.agent_id, task=task)
    ctx = await self._apply_behaviors_before(ctx)
    self._log_execution_start(task)

    steps: List[AgentStep] = []
    for _ in range(self.max_steps):
        prompt = self._build_prompt(task, steps)
        response = await self.llm_gateway.complete(
            prompt, model=self.model, temperature=self.temperature
        )
        step = self._parse_step(response.content)
        if step is None:
            # Output that does not follow the step format ends the run.
            step = AgentStep(
                thought="Unparseable response.",
                action="Final Answer",
                action_input={"answer": response.content},
            )

        if step.action == "Final Answer":
            step.observation = self._extract_final_answer(step)
            steps.append(step)
            # Log before the last yield: code placed after a yield only runs
            # if the consumer asks for another item, which it may never do.
            self._log_execution_end(task, success=True, steps=len(steps))
            yield step
            return

        if self.tool_registry:
            try:
                tool_result = await self.tool_registry.execute(step.action, **step.action_input)
                step.observation = str(tool_result)
            except Exception as tool_exc:
                # Tool failures are reported to the LLM as the observation.
                step.observation = f"Tool error: {tool_exc}"
        else:
            step.observation = f"No tool registry — cannot execute '{step.action}'"

        steps.append(step)
        yield step

    # Step budget exhausted without a final answer; mirror execute()'s logging.
    self._log_execution_end(task, success=False, steps=len(steps))
Execute the task, yielding each step as it completes.
class PlanExecuteAgent(BaseAgent):
    """
    Two-phase agent: LLM plans all steps first, then executes each step sequentially.

    Phase 1 — Plan: Ask the LLM to decompose the task into a JSON list of steps.
    Phase 2 — Execute: Feed each step back to the LLM (with accumulated context) to
    produce a result. Each step runs an inner tool-calling loop — if the LLM emits a
    ``Thought/Action/Action Input`` block and a tool registry is configured, the tool
    is invoked and its observation is fed back for the LLM to produce a final step
    result.

    Args:
        max_plan_steps: Maximum number of planned steps allowed (default: 10).
        max_tool_calls_per_step: Maximum tool calls allowed within a single plan step
            (default: 3). After the last permitted call the LLM is asked once more
            so that every observation is used; ``0`` disables tool calls.
        model: LLM model name (optional).
        temperature: Sampling temperature (default: 0.1).
        plan_prompt: Override the planning prompt. Must contain ``{task}``.
        execute_prompt: Override the step-execution prompt. Must contain
            ``{task}``, ``{plan}``, ``{previous_results}``, ``{step_num}``,
            ``{current_step}``, and ``{tool_section}``.

    All other args inherited from :class:`BaseAgent`.
    """

    def __init__(
        self,
        *args: Any,
        max_plan_steps: int = 10,
        max_tool_calls_per_step: int = 3,
        model: Optional[str] = None,
        temperature: float = 0.1,
        plan_prompt: Optional[str] = None,
        execute_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store the plan/execute configuration; the rest goes to the base class."""
        super().__init__(*args, **kwargs)
        self.max_plan_steps = max_plan_steps
        self.max_tool_calls_per_step = max_tool_calls_per_step
        self.model = model
        self.temperature = temperature
        self._plan_prompt = plan_prompt
        self._execute_prompt = execute_prompt

    #: Default planning prompt. Must contain ``{task}`` if overriding.
    DEFAULT_PLAN_PROMPT: str = _PLAN_PROMPT
    #: Default step-execution prompt. Must contain ``{task}``, ``{plan}``,
    #: ``{previous_results}``, ``{step_num}``, ``{current_step}``, and
    #: ``{tool_section}`` if overriding (``{tool_section}`` is auto-populated
    #: from the tool registry; pass an empty string if no tools are needed).
    DEFAULT_EXECUTE_PROMPT: str = _EXECUTE_PROMPT

    # ------------------------------------------------------------------
    # Phase 1: Plan
    # ------------------------------------------------------------------

    async def _plan(self, task: str) -> List[str]:
        """Ask the LLM for a plan and return at most ``max_plan_steps`` steps.

        The first ``[...]`` span of the reply is parsed as JSON. If that fails
        or is not a list, every non-empty line of the reply becomes one step,
        with leading list markers stripped.
        """
        template = self._plan_prompt if self._plan_prompt is not None else _PLAN_PROMPT
        prompt = template.format(task=task)
        # Planning always runs at temperature 0.0, regardless of self.temperature.
        response = await self.llm_gateway.complete(
            prompt, model=self.model, temperature=0.0
        )
        raw = response.content.strip()
        # Extract JSON array even if wrapped in markdown fences
        match = re.search(r"\[.*\]", raw, re.DOTALL)
        if match:
            try:
                steps = json.loads(match.group())
                if isinstance(steps, list):
                    return [str(s) for s in steps[: self.max_plan_steps]]
            except (json.JSONDecodeError, ValueError):
                pass
        # Fallback: treat each line as a step
        lines = [ln.strip().lstrip("0123456789.-) ") for ln in raw.splitlines() if ln.strip()]
        return lines[: self.max_plan_steps]

    # ------------------------------------------------------------------
    # Phase 2: Execute a single step (with optional tool calls)
    # ------------------------------------------------------------------

    def _tool_descriptions(self) -> str:
        """Return one ``- name: description`` line per tool, or ``""`` if none."""
        if not self.tool_registry:
            return ""
        tools = self.tool_registry.list_tools()
        if not tools:
            return ""
        return "\n".join(f"- {t.name}: {t.description}" for t in tools)

    async def _execute_step(
        self, task: str, plan: List[str], step_num: int, previous: List[str]
    ) -> str:
        """Execute one plan step and return its result text.

        Args:
            task: The overall task.
            plan: The full list of plan steps.
            step_num: 1-based index of the step to execute.
            previous: Result texts of the steps already executed.

        Returns:
            The stripped LLM reply that concludes the step.
        """
        tool_descriptions = self._tool_descriptions()
        tool_section = (
            _TOOL_SECTION.format(tool_descriptions=tool_descriptions)
            if tool_descriptions
            else ""
        )
        template = self._execute_prompt if self._execute_prompt is not None else _EXECUTE_PROMPT
        prompt = template.format(
            task=task,
            plan="\n".join(f"{i+1}. {s}" for i, s in enumerate(plan)),
            previous_results="\n".join(
                f"Step {i+1} result: {r}" for i, r in enumerate(previous)
            ) or "None yet.",
            step_num=step_num,
            current_step=plan[step_num - 1],
            tool_section=tool_section,
        )

        # Inner tool-calling loop. Up to `tool_budget` tool calls are made, and
        # there is always one more LLM call than tool calls, so the observation
        # of the last tool call is fed back instead of being thrown away. This
        # also guarantees at least one LLM call when the budget is 0.
        tool_budget = max(self.max_tool_calls_per_step, 0)
        conversation = prompt
        raw = ""
        for calls_made in range(tool_budget + 1):
            response = await self.llm_gateway.complete(
                conversation, model=self.model, temperature=self.temperature
            )
            raw = response.content.strip()

            # Check if the LLM wants to call a tool
            match = _STEP_PATTERN.search(raw)
            if not match or not self.tool_registry or calls_made >= tool_budget:
                # No tool call requested, no registry to serve it, or the
                # budget is spent — this reply is the step result.
                return raw

            action = match.group("action").strip()
            raw_input = match.group("action_input").strip()
            try:
                action_input = json.loads(raw_input)
                if not isinstance(action_input, dict):
                    action_input = {"raw": raw_input}
            except (json.JSONDecodeError, ValueError):
                # Non-JSON input is passed to the tool as a single raw string.
                action_input = {"raw": raw_input}

            try:
                tool_result = await self.tool_registry.execute(action, **action_input)
                observation = str(tool_result)
            except Exception as exc:
                # Tool failures are reported to the LLM rather than raised.
                observation = f"Tool error: {exc}"

            # Append the tool call + observation to the conversation and continue
            conversation = (
                conversation
                + f"\n{raw}\nObservation: {observation}\n"
                + "Now provide the final result for this step based on the above observation."
            )

        # Not reachable (the last iteration always returns); kept as a safe default.
        return raw

    async def _save_checkpoint(
        self,
        execution_id: str,
        task: str,
        context: Dict[str, Any],
        plan: List[str],
        steps: List[AgentStep],
        step_results: List[str],
        step_number: int,
    ) -> None:
        """Persist the run state after ``step_number``; no-op without a manager.

        ``step_number`` is the 0-based index of the last completed plan step
        (``-1`` right after planning).
        """
        if not self.checkpoint_manager:
            return
        await self.checkpoint_manager.save(
            agent_id=self.agent_id,
            execution_id=execution_id,
            state={
                "execution_id": execution_id,
                "task": task,
                "context": context,
                "plan": plan,
                "step_number": step_number,
                # Copied so later appends do not alter the saved snapshot.
                "step_results": list(step_results),
                "steps": [dataclasses.asdict(s) for s in steps],
            },
            metadata={"step_number": step_number},
            ttl=86400,
        )

    async def resume_from(
        self, execution_id: str, from_step: Optional[int] = None
    ) -> AgentResult:
        """Resume a Plan-Execute run from latest (or selected) checkpoint.

        Args:
            execution_id: Identifier of the execution whose checkpoints are read.
            from_step: When given, resume from the first checkpoint whose stored
                ``step_number`` equals this value; otherwise the latest one.

        Returns:
            The result after executing the remaining plan steps.

        Raises:
            ValueError: If no checkpoint manager is configured, or no matching
                checkpoint exists.
        """
        if not self.checkpoint_manager:
            raise ValueError("Checkpoint manager is not configured for this agent")

        checkpoint = None
        if from_step is not None:
            checkpoints = await self.checkpoint_manager.list_by_execution(execution_id)
            for ckpt in checkpoints:
                if ckpt.state.get("step_number") == from_step:
                    checkpoint = ckpt
                    break
        else:
            checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)

        if checkpoint is None:
            raise ValueError(f"No checkpoint found for execution_id={execution_id}")

        state = checkpoint.state
        task = str(state.get("task", ""))
        context = state.get("context") or {}
        plan = [str(s) for s in state.get("plan") or []]
        raw_steps = state.get("steps") or []
        steps = [AgentStep(**raw) for raw in raw_steps if isinstance(raw, dict)]
        step_results = [str(r) for r in state.get("step_results") or []]
        # The checkpoint stores the last completed index; continue with the next.
        current_step = int(state.get("step_number", len(step_results) - 1)) + 1

        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            last_completed_step=max(current_step - 1, -1),
            metadata={"resume_agent": self},
        )
        ctx = await self._apply_behaviors_before(ctx)
        # Logged for parity with execute(), which brackets every run this way.
        self._log_execution_start(task)

        for i in range(current_step, len(plan)):
            step_num = i + 1
            result_text = await self._execute_step(task, plan, step_num, step_results)
            step_results.append(result_text)
            agent_step = AgentStep(
                thought=f"Executing plan step {step_num}: {plan[i]}",
                action="llm_execution",
                action_input={"step": plan[i]},
                observation=result_text,
            )
            steps.append(agent_step)
            await self._save_checkpoint(
                execution_id,
                task,
                context,
                plan,
                steps,
                step_results,
                i,
            )
            ctx.last_completed_step = i

        final_output = step_results[-1] if step_results else ""
        result = AgentResult(
            output=final_output,
            steps=steps,
            success=True,
            metadata={"plan": plan, "execution_id": execution_id},
        )
        result = await self._apply_behaviors_after(ctx, result)
        self._log_execution_end(task, success=True, steps=len(steps))
        return result

    # ------------------------------------------------------------------
    # BaseAgent implementation
    # ------------------------------------------------------------------

    async def execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """Plan the task, execute every plan step in order, and return the result.

        Args:
            task: The task to plan and execute.
            context: Optional mapping merged into the behavior context metadata
                and stored with each checkpoint.

        Returns:
            A successful result whose output is the last step's result text, or
            the fallback supplied by the error behaviors.

        Raises:
            Exception: Any error for which the error behaviors give no fallback.
        """
        execution_id = str(uuid4())
        run_context = context or {}
        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            metadata={"resume_agent": self, **run_context},
        )
        ctx = await self._apply_behaviors_before(ctx)
        self._log_execution_start(task)

        with self._tracer.trace(
            "plan_execute_agent.execute", input=task, metadata={"agent_id": self.agent_id}
        ) as trace:
            try:
                # Phase 1 — Plan
                with trace.span("plan", input=task) as plan_span:
                    plan = await self._plan(task)
                    plan_span.set_output({"steps": plan})
                self._logger.info(
                    "Plan created",
                    agent_id=self.agent_id,
                    step_count=len(plan),
                )
                # Step number -1 marks "planned, nothing executed yet".
                await self._save_checkpoint(
                    execution_id,
                    task,
                    run_context,
                    plan,
                    [],
                    [],
                    -1,
                )

                # Phase 2 — Execute
                steps: List[AgentStep] = []
                step_results: List[str] = []
                for i, plan_step in enumerate(plan):
                    step_num = i + 1
                    with trace.span(f"execute_step_{step_num}", input=plan_step) as step_span:
                        perf_id = None
                        if self._performance_monitor:
                            perf_id = self._performance_monitor.start_request(
                                provider="llm_gateway",
                                model=self.model or "default",
                            )
                        try:
                            result_text = await self._execute_step(
                                task, plan, step_num, step_results
                            )
                            # Token counts are reported as 0: _execute_step
                            # returns only the text, not the usage figures.
                            if self._performance_monitor and perf_id:
                                self._performance_monitor.end_request(
                                    request_id=perf_id,
                                    prompt_tokens=0,
                                    completion_tokens=0,
                                    success=True,
                                )
                        except Exception as exc:
                            if self._performance_monitor and perf_id:
                                self._performance_monitor.end_request(
                                    request_id=perf_id,
                                    prompt_tokens=0,
                                    completion_tokens=0,
                                    success=False,
                                    error=str(exc),
                                )
                            raise

                        step_results.append(result_text)
                        agent_step = AgentStep(
                            thought=f"Executing plan step {step_num}: {plan_step}",
                            action="llm_execution",
                            action_input={"step": plan_step},
                            observation=result_text,
                        )
                        steps.append(agent_step)
                        step_span.set_output(result_text)
                        await self._save_checkpoint(
                            execution_id,
                            task,
                            run_context,
                            plan,
                            steps,
                            step_results,
                            i,
                        )
                        ctx.last_completed_step = i

                final_output = step_results[-1] if step_results else ""
                if self.state_store:
                    await self.state_store.set(
                        f"agent:{self.agent_id}:last_steps",
                        [vars(s) for s in steps],
                        ttl=3600,
                    )

                result = AgentResult(
                    output=final_output,
                    steps=steps,
                    success=True,
                    metadata={"plan": plan, "execution_id": execution_id},
                )
                result = await self._apply_behaviors_after(ctx, result)
                self._log_execution_end(task, success=True, steps=len(steps))
                trace.set_output(final_output)
                return result

            except Exception as exc:
                self._log_execution_error(task, exc)
                fallback = await self._apply_behaviors_on_error(ctx, exc)
                # NOTE(review): any non-None value is returned as the result. If
                # the error behaviors can answer with a retry sentinel, it would
                # be returned to the caller as-is — confirm and handle if so.
                if fallback is not None:
                    return fallback
                raise

    async def stream_execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AsyncIterator[AgentStep]:
        """Plan the task, then yield one step per executed plan entry.

        This path does no tracing or checkpointing and does not apply
        after/error behaviors. ``context`` is accepted for interface
        compatibility and is not used here.
        """
        ctx = BehaviorContext(agent_id=self.agent_id, task=task)
        ctx = await self._apply_behaviors_before(ctx)
        self._log_execution_start(task)

        plan = await self._plan(task)
        step_results: List[str] = []
        for i, plan_step in enumerate(plan):
            result_text = await self._execute_step(task, plan, i + 1, step_results)
            step_results.append(result_text)
            step = AgentStep(
                thought=f"Executing plan step {i+1}: {plan_step}",
                action="llm_execution",
                action_input={"step": plan_step},
                observation=result_text,
            )
            yield step

        # Reached once the consumer has drained the stream; mirrors execute().
        self._log_execution_end(task, success=True, steps=len(step_results))
Two-phase agent: LLM plans all steps first, then executes each step sequentially.
Phase 1 — Plan: Ask the LLM to decompose the task into a JSON list of steps.
Phase 2 — Execute: Feed each step back to the LLM (with accumulated context) to
produce a result. Each step runs an inner tool-calling loop — if the LLM emits a
Thought/Action/Action Input block referencing a registered tool, the tool is
invoked and its observation is fed back for the LLM to produce a final step result.
Args:
max_plan_steps: Maximum number of planned steps allowed (default: 10).
max_tool_calls_per_step: Maximum tool calls allowed within a single plan step
(default: 3).
model: LLM model name (optional).
temperature: Sampling temperature (default: 0.1).
plan_prompt: Override the planning prompt. Must contain {task}.
execute_prompt: Override the step-execution prompt. Must contain
{task}, {plan}, {previous_results}, {step_num},
{current_step}, and {tool_section}.
All other args inherited from BaseAgent.
def __init__(
    self,
    *args: Any,
    max_plan_steps: int = 10,
    max_tool_calls_per_step: int = 3,
    model: Optional[str] = None,
    temperature: float = 0.1,
    plan_prompt: Optional[str] = None,
    execute_prompt: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Configure the plan/execute phases on top of the base agent.

    Args:
        *args: Positional arguments forwarded untouched to the base class.
        max_plan_steps: Maximum number of planned steps allowed.
        max_tool_calls_per_step: Maximum tool calls within a single plan step.
        model: LLM model name (optional).
        temperature: Sampling temperature for step execution.
        plan_prompt: Replacement planning prompt; ``None`` keeps the default.
        execute_prompt: Replacement step-execution prompt; ``None`` keeps the
            default.
        **kwargs: Keyword arguments forwarded untouched to the base class.
    """
    # Everything that is not plan/execute-specific belongs to the base class.
    super().__init__(*args, **kwargs)

    # Limits for the two phases.
    self.max_plan_steps = max_plan_steps
    self.max_tool_calls_per_step = max_tool_calls_per_step

    # LLM call configuration.
    self.model = model
    self.temperature = temperature

    # Optional prompt overrides, resolved when each prompt is built.
    self._plan_prompt = plan_prompt
    self._execute_prompt = execute_prompt
async def resume_from(
    self, execution_id: str, from_step: Optional[int] = None
) -> AgentResult:
    """Resume a Plan-Execute run from the latest (or a selected) checkpoint.

    The checkpointed plan, steps and step results are restored, then every
    plan step after the checkpointed one is executed and checkpointed again.

    Args:
        execution_id: Identifier of the run whose checkpoints are used.
        from_step: ``step_number`` of the checkpoint to resume after. When
            omitted, the most recent checkpoint for the execution is used.

    Returns:
        The :class:`AgentResult` produced after the remaining steps ran and
        the "after" behaviors were applied.

    Raises:
        ValueError: If no checkpoint manager is configured, or no matching
            checkpoint exists.
    """
    if not self.checkpoint_manager:
        raise ValueError("Checkpoint manager is not configured for this agent")

    if from_step is None:
        checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)
    else:
        checkpoints = await self.checkpoint_manager.list_by_execution(execution_id)
        checkpoint = next(
            (c for c in checkpoints if c.state.get("step_number") == from_step),
            None,
        )

    if checkpoint is None:
        if from_step is not None:
            # Name the requested step so the caller can tell "no run" from "no such step".
            raise ValueError(
                f"No checkpoint found for execution_id={execution_id} at step={from_step}"
            )
        raise ValueError(f"No checkpoint found for execution_id={execution_id}")

    state = checkpoint.state
    task = str(state.get("task", ""))
    context = state.get("context") or {}
    plan = [str(s) for s in state.get("plan") or []]
    raw_steps = state.get("steps") or []
    steps = [AgentStep(**raw) for raw in raw_steps if isinstance(raw, dict)]
    step_results = [str(r) for r in state.get("step_results") or []]

    # A missing or null step number falls back to the number of stored results;
    # the resume index is clamped so a corrupt negative value can never turn
    # into a negative index into the plan.
    saved_step = state.get("step_number")
    if saved_step is None:
        saved_step = len(step_results) - 1
    current_step = max(int(saved_step) + 1, 0)

    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        last_completed_step=current_step - 1,
        # Restore the original run context next to the resume handle so the
        # behaviors see the same metadata as they did in execute().
        metadata={"resume_agent": self, **context},
    )
    ctx = await self._apply_behaviors_before(ctx)

    for i in range(current_step, len(plan)):
        step_num = i + 1
        result_text = await self._execute_step(task, plan, step_num, step_results)
        step_results.append(result_text)
        agent_step = AgentStep(
            thought=f"Executing plan step {step_num}: {plan[i]}",
            action="llm_execution",
            action_input={"step": plan[i]},
            observation=result_text,
        )
        steps.append(agent_step)
        # Persist progress after every step so a later resume can continue from here.
        await self._save_checkpoint(
            execution_id,
            task,
            context,
            plan,
            steps,
            step_results,
            i,
        )
        ctx.last_completed_step = i

    final_output = step_results[-1] if step_results else ""
    result = AgentResult(
        output=final_output,
        steps=steps,
        success=True,
        metadata={"plan": plan, "execution_id": execution_id},
    )
    return await self._apply_behaviors_after(ctx, result)
Resume a Plan-Execute run from the latest (or a selected) checkpoint.
async def execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """Plan the task, run each planned step in order, and return the outcome.

    A checkpoint is written after planning (step number ``-1``) and after
    every executed step. The output of the last step becomes the result
    output; an empty plan yields an empty string.

    Args:
        task: The task description handed to the planner and to every step.
        context: Optional caller context; it is merged into the behavior
            metadata and stored in each checkpoint.

    Returns:
        The :class:`AgentResult` after the "after" behaviors were applied, or
        the fallback result supplied by the error behaviors.

    Raises:
        Exception: Whatever planning/execution raised, when no error behavior
            supplies a fallback.
    """
    execution_id = str(uuid4())
    run_context = context or {}
    behavior_ctx = await self._apply_behaviors_before(
        BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            metadata={"resume_agent": self, **run_context},
        )
    )
    self._log_execution_start(task)

    with self._tracer.trace(
        "plan_execute_agent.execute", input=task, metadata={"agent_id": self.agent_id}
    ) as trace:
        try:
            # ---- Phase 1: ask the LLM for the full plan up front ----
            with trace.span("plan", input=task) as plan_span:
                plan = await self._plan(task)
                plan_span.set_output({"steps": plan})
            self._logger.info(
                "Plan created",
                agent_id=self.agent_id,
                step_count=len(plan),
            )
            # Checkpoint the bare plan (-1 = no step completed yet).
            await self._save_checkpoint(execution_id, task, run_context, plan, [], [], -1)

            # ---- Phase 2: run the plan one step at a time ----
            executed: List[AgentStep] = []
            observations: List[str] = []
            for index, planned in enumerate(plan):
                step_num = index + 1
                with trace.span(f"execute_step_{step_num}", input=planned) as step_span:
                    monitor_token = None
                    if self._performance_monitor:
                        monitor_token = self._performance_monitor.start_request(
                            provider="llm_gateway",
                            model=self.model or "default",
                        )
                    try:
                        observation = await self._execute_step(
                            task, plan, step_num, observations
                        )
                        if self._performance_monitor and monitor_token:
                            self._performance_monitor.end_request(
                                request_id=monitor_token,
                                prompt_tokens=0,
                                completion_tokens=0,
                                success=True,
                            )
                    except Exception as exc:
                        # Close the monitored request as failed, then let the
                        # outer handler deal with the error.
                        if self._performance_monitor and monitor_token:
                            self._performance_monitor.end_request(
                                request_id=monitor_token,
                                prompt_tokens=0,
                                completion_tokens=0,
                                success=False,
                                error=str(exc),
                            )
                        raise

                    observations.append(observation)
                    executed.append(
                        AgentStep(
                            thought=f"Executing plan step {step_num}: {planned}",
                            action="llm_execution",
                            action_input={"step": planned},
                            observation=observation,
                        )
                    )
                    step_span.set_output(observation)
                    await self._save_checkpoint(
                        execution_id,
                        task,
                        run_context,
                        plan,
                        executed,
                        observations,
                        index,
                    )
                    behavior_ctx.last_completed_step = index

            final_output = observations[-1] if observations else ""
            if self.state_store:
                # Keep the most recent step list available for an hour.
                await self.state_store.set(
                    f"agent:{self.agent_id}:last_steps",
                    [vars(s) for s in executed],
                    ttl=3600,
                )

            outcome = AgentResult(
                output=final_output,
                steps=executed,
                success=True,
                metadata={"plan": plan, "execution_id": execution_id},
            )
            outcome = await self._apply_behaviors_after(behavior_ctx, outcome)
            self._log_execution_end(task, success=True, steps=len(executed))
            trace.set_output(final_output)
            return outcome

        except Exception as exc:
            self._log_execution_error(task, exc)
            fallback = await self._apply_behaviors_on_error(behavior_ctx, exc)
            if fallback is None:
                raise
            return fallback
Execute the task and return a result.
async def stream_execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AsyncIterator[AgentStep]:
    """Plan the task, then yield one :class:`AgentStep` per executed plan step.

    Unlike :meth:`execute`, this path writes no checkpoints and produces no
    final :class:`AgentResult`; consumers receive the steps as they finish.

    Args:
        task: The task description handed to the planner and to every step.
        context: Optional caller context, exposed to behaviors through the
            behavior context metadata.

    Yields:
        The step record for each plan step, in plan order.

    Raises:
        Exception: Whatever planning or step execution raised; the error is
            logged before it propagates.
    """
    # The caller's context used to be dropped here; expose it to behaviors the
    # same way execute() does.
    run_context = context or {}
    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        metadata={**run_context},
    )
    ctx = await self._apply_behaviors_before(ctx)
    self._log_execution_start(task)

    completed = 0
    try:
        plan = await self._plan(task)
        step_results: List[str] = []
        for i, plan_step in enumerate(plan):
            step_num = i + 1
            result_text = await self._execute_step(task, plan, step_num, step_results)
            step_results.append(result_text)
            step = AgentStep(
                thought=f"Executing plan step {step_num}: {plan_step}",
                action="llm_execution",
                action_input={"step": plan_step},
                observation=result_text,
            )
            # Record progress before handing control to the consumer, who may
            # stop iterating at any yield.
            ctx.last_completed_step = i
            completed = step_num
            yield step
    except Exception as exc:
        # Pair the start log with an error log instead of failing silently.
        self._log_execution_error(task, exc)
        raise

    self._log_execution_end(task, success=True, steps=completed)
Execute the task, yielding each step as it completes.
class ReflexionAgent(BaseAgent):
    """
    Wraps another agent and retries it with self-reflection when a critique
    rejects its output.

    After each attempt, a separate LLM call critiques the output. If the
    critique says the result is unsatisfactory, the inner agent is re-run
    with an improved prompt, up to ``max_reflections`` times.

    Args:
        inner_agent: The underlying agent to run and reflect on. Its
            ``llm_gateway`` is also used for the critique call.
        max_reflections: Maximum reflection/retry cycles (default: 2).
        model: Model for the critique call. ``None`` is passed through to the
            gateway unchanged; retries run on the inner agent with its own
            settings.
        critique_prompt: Override the critique prompt. Must contain ``{task}``
            and ``{response}``.
        reflect_prompt: Override the reflection/retry prompt. Must contain
            ``{task}``, ``{previous_output}``, and ``{critique}``.

    All other args inherited from :class:`BaseAgent`.
    """

    def __init__(
        self,
        inner_agent: BaseAgent,
        *args: Any,
        max_reflections: int = 2,
        model: Optional[str] = None,
        critique_prompt: Optional[str] = None,
        reflect_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # The wrapper shares the inner agent's gateway for its critique calls.
        super().__init__(inner_agent.llm_gateway, *args, **kwargs)
        self.inner_agent = inner_agent
        self.max_reflections = max_reflections
        self.model = model
        self._critique_prompt = critique_prompt
        self._reflect_prompt = reflect_prompt

    #: Default critique prompt. Must contain ``{task}`` and ``{response}`` if overriding.
    DEFAULT_CRITIQUE_PROMPT: str = _CRITIQUE_PROMPT
    #: Default reflection prompt. Must contain ``{task}``, ``{previous_output}``,
    #: and ``{critique}`` if overriding.
    DEFAULT_REFLECT_PROMPT: str = _REFLECT_PROMPT

    async def _critique(self, task: str, output: str) -> Optional[str]:
        """Return the critique reason, or ``None`` if the output is satisfactory.

        A reply starting with ``YES`` (case-insensitive) accepts the output.
        A leading ``NO`` verdict and its separator are stripped from the
        reason; any other reply is returned in full.
        """
        template = self._critique_prompt if self._critique_prompt is not None else _CRITIQUE_PROMPT
        prompt = template.format(task=task, response=output)
        response = await self.llm_gateway.complete(prompt, model=self.model, temperature=0.0)
        text = response.content.strip()
        verdict = text.upper()
        if verdict.startswith("YES"):
            return None
        # Strip only a leading standalone "NO" verdict. Splitting on the first
        # colon anywhere would throw away the start of free-form critiques such
        # as "The answer is wrong: ...".
        if verdict.startswith("NO") and not text[2:3].isalnum():
            reason = text[2:].lstrip(" \t\r\n:.,;-")
            return reason or text
        return text

    async def _reflect_and_retry(
        self,
        task: str,
        previous_output: str,
        critique: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """Re-run the inner agent with a prompt built from the critique.

        Args:
            task: The original task.
            previous_output: The output that was rejected.
            critique: The reason the output was rejected.
            context: Optional caller context, forwarded to the inner agent so
                that retries see the same context as the first attempt.
        """
        template = self._reflect_prompt if self._reflect_prompt is not None else _REFLECT_PROMPT
        improved_task = template.format(
            task=task, previous_output=previous_output, critique=critique
        )
        return await self.inner_agent.execute(improved_task, context)

    async def _save_checkpoint(
        self,
        execution_id: str,
        task: str,
        context: Dict[str, Any],
        reflection_round: int,
        result: AgentResult,
        critique: Optional[str] = None,
    ) -> None:
        """Persist the latest result for this reflection round (no-op without a manager)."""
        if not self.checkpoint_manager:
            return
        await self.checkpoint_manager.save(
            agent_id=self.agent_id,
            execution_id=execution_id,
            state={
                "execution_id": execution_id,
                "task": task,
                "context": context,
                "reflection_round": reflection_round,
                "critique": critique,
                "result": {
                    "output": result.output,
                    "success": result.success,
                    "error": result.error,
                    "metadata": result.metadata,
                    "steps": [dataclasses.asdict(s) for s in result.steps],
                },
            },
            metadata={"reflection_round": reflection_round},
            ttl=86400,
        )

    async def resume_from(
        self, execution_id: str, from_reflection: Optional[int] = None
    ) -> AgentResult:
        """Resume a reflexion execution from the latest (or a selected) reflection checkpoint.

        Args:
            execution_id: Identifier of the run whose checkpoints are used.
            from_reflection: ``reflection_round`` of the checkpoint to resume
                from. When omitted, the most recent checkpoint is used.

        Returns:
            The final :class:`AgentResult` after the remaining critique/retry
            rounds ran and the "after" behaviors were applied.

        Raises:
            ValueError: If no checkpoint manager is configured, or no matching
                checkpoint exists.
        """
        if not self.checkpoint_manager:
            raise ValueError("Checkpoint manager is not configured for this agent")

        if from_reflection is None:
            checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)
        else:
            checkpoints = await self.checkpoint_manager.list_by_execution(execution_id)
            checkpoint = next(
                (
                    c
                    for c in checkpoints
                    if c.state.get("reflection_round") == from_reflection
                ),
                None,
            )

        if checkpoint is None:
            if from_reflection is not None:
                raise ValueError(
                    f"No checkpoint found for execution_id={execution_id} "
                    f"at reflection_round={from_reflection}"
                )
            raise ValueError(f"No checkpoint found for execution_id={execution_id}")

        state = checkpoint.state
        task = str(state.get("task", ""))
        context = state.get("context") or {}
        # Null or negative stored rounds are treated as "no reflection done yet".
        reflection_round = max(int(state.get("reflection_round") or 0), 0)
        raw = state.get("result") or {}

        steps = [
            AgentStep(**s) for s in (raw.get("steps") or []) if isinstance(s, dict)
        ]
        result = AgentResult(
            output=str(raw.get("output", "")),
            steps=steps,
            success=bool(raw.get("success", True)),
            error=raw.get("error"),
            # A stored null must not reach the ``**result.metadata`` merge below.
            metadata=raw.get("metadata") or {},
        )

        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            last_completed_step=reflection_round,
            # Restore the original run context for the behaviors, as execute() does.
            metadata={"resume_agent": self, **context},
        )
        ctx = await self._apply_behaviors_before(ctx)

        all_steps: List[AgentStep] = list(result.steps)
        for idx in range(reflection_round, self.max_reflections):
            critique = await self._critique(task, result.output)
            if critique is None:
                break
            result = await self._reflect_and_retry(task, result.output, critique, context)
            all_steps.extend(result.steps)
            await self._save_checkpoint(
                execution_id,
                task,
                context,
                idx + 1,
                result,
                critique,
            )
            ctx.last_completed_step = idx + 1

        final = AgentResult(
            output=result.output,
            steps=all_steps,
            success=result.success,
            error=result.error,
            metadata={**(result.metadata or {}), "execution_id": execution_id},
        )
        return await self._apply_behaviors_after(ctx, final)

    # ------------------------------------------------------------------
    # BaseAgent implementation
    # ------------------------------------------------------------------

    async def execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """Run the inner agent, then critique and retry until accepted or out of rounds.

        Args:
            task: The task handed to the inner agent and to the critique.
            context: Optional caller context; forwarded to the inner agent on
                every attempt, merged into behavior metadata, and checkpointed.

        Returns:
            The last attempt's result with the steps of all attempts combined,
            or the fallback supplied by the error behaviors.
        """
        execution_id = str(uuid4())
        run_context = context or {}
        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            metadata={"resume_agent": self, **run_context},
        )
        ctx = await self._apply_behaviors_before(ctx)
        self._log_execution_start(task)

        with self._tracer.trace(
            "reflexion_agent.execute", input=task, metadata={"agent_id": self.agent_id}
        ) as trace:
            try:
                result = await self.inner_agent.execute(task, context)
                all_steps: List[AgentStep] = list(result.steps)
                # Round 0 is the un-reflected first attempt.
                await self._save_checkpoint(execution_id, task, run_context, 0, result)
                ctx.last_completed_step = 0

                for reflection_num in range(self.max_reflections):
                    with trace.span(
                        f"critique_{reflection_num + 1}", input=result.output
                    ) as cspan:
                        critique = await self._critique(task, result.output)
                        cspan.set_output({"satisfactory": critique is None, "critique": critique})

                    if critique is None:
                        self._logger.info(
                            "Reflexion: output accepted",
                            agent_id=self.agent_id,
                            reflection=reflection_num + 1,
                        )
                        break

                    self._logger.info(
                        "Reflexion: retrying",
                        agent_id=self.agent_id,
                        reflection=reflection_num + 1,
                        critique=critique,
                    )
                    if self._metrics:
                        self._metrics.increment(
                            "agent.reflexions", agent_id=self.agent_id
                        )

                    with trace.span(f"reflect_{reflection_num + 1}", input=critique) as rspan:
                        # Retries get the same context as the first attempt.
                        result = await self._reflect_and_retry(
                            task, result.output, critique, context
                        )
                        all_steps.extend(result.steps)
                        await self._save_checkpoint(
                            execution_id,
                            task,
                            run_context,
                            reflection_num + 1,
                            result,
                            critique,
                        )
                        ctx.last_completed_step = reflection_num + 1
                        rspan.set_output(result.output)

                final = AgentResult(
                    output=result.output,
                    steps=all_steps,
                    success=result.success,
                    error=result.error,
                    metadata={**result.metadata, "execution_id": execution_id},
                )
                final = await self._apply_behaviors_after(ctx, final)
                self._log_execution_end(task, success=final.success, steps=len(all_steps))
                trace.set_output(final.output)
                return final

            except Exception as exc:
                self._log_execution_error(task, exc)
                fallback = await self._apply_behaviors_on_error(ctx, exc)
                if fallback is not None:
                    return fallback
                raise

    async def stream_execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AsyncIterator[AgentStep]:
        """Yield the inner agent's steps as they complete (no critique is run)."""
        # Reflexion does full-cycle reflection, so streaming yields inner steps
        async for step in self.inner_agent.stream_execute(task, context):
            yield step
Wraps another agent and applies self-reflection on failure.
After each attempt, a separate LLM call critiques the output. If the
critique says the result is unsatisfactory, the agent reflects and retries
with an improved prompt up to max_reflections times.
Args:
inner_agent: The underlying agent to run and reflect on.
max_reflections: Maximum reflection/retry cycles (default: 2).
model: Model for the critique call. When omitted, no model is passed to the
gateway; retries run on the inner agent with its own settings.
critique_prompt: Override the critique prompt. Must contain {task}
and {response}.
reflect_prompt: Override the reflection/retry prompt. Must contain
{task}, {previous_output}, and {critique}.
All other args inherited from BaseAgent.
def __init__(
    self,
    inner_agent: BaseAgent,
    *args: Any,
    max_reflections: int = 2,
    model: Optional[str] = None,
    critique_prompt: Optional[str] = None,
    reflect_prompt: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Wrap ``inner_agent`` and store the reflection settings.

    Args:
        inner_agent: Agent whose runs are critiqued and retried.
        *args: Extra positional arguments for the parent initializer.
        max_reflections: Maximum number of critique/retry rounds.
        model: Model name stored for the critique call (may be ``None``).
        critique_prompt: Optional override for the critique prompt template.
        reflect_prompt: Optional override for the retry prompt template.
        **kwargs: Extra keyword arguments for the parent initializer.
    """
    # The inner agent's gateway is handed to the parent as its first argument.
    gateway = inner_agent.llm_gateway
    super().__init__(gateway, *args, **kwargs)

    self.inner_agent = inner_agent
    self.max_reflections, self.model = max_reflections, model
    self._critique_prompt, self._reflect_prompt = critique_prompt, reflect_prompt
async def resume_from(
    self, execution_id: str, from_reflection: Optional[int] = None
) -> AgentResult:
    """Resume a reflexion execution from the latest (or a selected) reflection checkpoint.

    Args:
        execution_id: Identifier of the run whose checkpoints are used.
        from_reflection: ``reflection_round`` of the checkpoint to resume
            from. When omitted, the most recent checkpoint is used.

    Returns:
        The final :class:`AgentResult` after the remaining critique/retry
        rounds ran and the "after" behaviors were applied.

    Raises:
        ValueError: If no checkpoint manager is configured, or no matching
            checkpoint exists.
    """
    if not self.checkpoint_manager:
        raise ValueError("Checkpoint manager is not configured for this agent")

    if from_reflection is None:
        checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)
    else:
        checkpoints = await self.checkpoint_manager.list_by_execution(execution_id)
        checkpoint = next(
            (
                c
                for c in checkpoints
                if c.state.get("reflection_round") == from_reflection
            ),
            None,
        )

    if checkpoint is None:
        if from_reflection is not None:
            # Name the requested round so the caller can tell what was missing.
            raise ValueError(
                f"No checkpoint found for execution_id={execution_id} "
                f"at reflection_round={from_reflection}"
            )
        raise ValueError(f"No checkpoint found for execution_id={execution_id}")

    state = checkpoint.state
    task = str(state.get("task", ""))
    context = state.get("context") or {}
    # Null or negative stored rounds are treated as "no reflection done yet".
    reflection_round = max(int(state.get("reflection_round") or 0), 0)
    raw = state.get("result") or {}

    steps = [
        AgentStep(**s) for s in (raw.get("steps") or []) if isinstance(s, dict)
    ]
    result = AgentResult(
        output=str(raw.get("output", "")),
        steps=steps,
        success=bool(raw.get("success", True)),
        error=raw.get("error"),
        # A stored null must not reach the ``**`` merge of the metadata below.
        metadata=raw.get("metadata") or {},
    )

    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        last_completed_step=reflection_round,
        # Restore the original run context for the behaviors, as execute() does.
        metadata={"resume_agent": self, **context},
    )
    ctx = await self._apply_behaviors_before(ctx)

    all_steps: List[AgentStep] = list(result.steps)
    for idx in range(reflection_round, self.max_reflections):
        critique = await self._critique(task, result.output)
        if critique is None:
            break
        result = await self._reflect_and_retry(task, result.output, critique)
        all_steps.extend(result.steps)
        await self._save_checkpoint(
            execution_id,
            task,
            context,
            idx + 1,
            result,
            critique,
        )
        ctx.last_completed_step = idx + 1

    final = AgentResult(
        output=result.output,
        steps=all_steps,
        success=result.success,
        error=result.error,
        metadata={**(result.metadata or {}), "execution_id": execution_id},
    )
    return await self._apply_behaviors_after(ctx, final)
Resume a reflexion execution from the latest (or a selected) reflection checkpoint.
async def execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """Run the inner agent, then critique and retry until accepted or out of rounds.

    Args:
        task: The task handed to the inner agent and to the critique.
        context: Optional caller context; passed to the inner agent's first
            run, merged into behavior metadata, and checkpointed.

    Returns:
        The last attempt's result with the steps of all attempts combined,
        or the fallback supplied by the error behaviors.
    """
    execution_id = str(uuid4())
    run_context = context or {}
    behavior_ctx = await self._apply_behaviors_before(
        BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            metadata={"resume_agent": self, **run_context},
        )
    )
    self._log_execution_start(task)

    with self._tracer.trace(
        "reflexion_agent.execute", input=task, metadata={"agent_id": self.agent_id}
    ) as trace:
        try:
            # First, un-reflected attempt (checkpointed as round 0).
            attempt = await self.inner_agent.execute(task, context)
            collected: List[AgentStep] = list(attempt.steps)
            await self._save_checkpoint(execution_id, task, run_context, 0, attempt)
            behavior_ctx.last_completed_step = 0

            for round_no in range(1, self.max_reflections + 1):
                with trace.span(f"critique_{round_no}", input=attempt.output) as cspan:
                    critique = await self._critique(task, attempt.output)
                    cspan.set_output({"satisfactory": critique is None, "critique": critique})

                if critique is None:
                    # The critic accepted the output; no further rounds needed.
                    self._logger.info(
                        "Reflexion: output accepted",
                        agent_id=self.agent_id,
                        reflection=round_no,
                    )
                    break

                self._logger.info(
                    "Reflexion: retrying",
                    agent_id=self.agent_id,
                    reflection=round_no,
                    critique=critique,
                )
                if self._metrics:
                    self._metrics.increment("agent.reflexions", agent_id=self.agent_id)

                with trace.span(f"reflect_{round_no}", input=critique) as rspan:
                    attempt = await self._reflect_and_retry(task, attempt.output, critique)
                    collected.extend(attempt.steps)
                    await self._save_checkpoint(
                        execution_id,
                        task,
                        run_context,
                        round_no,
                        attempt,
                        critique,
                    )
                    behavior_ctx.last_completed_step = round_no
                    rspan.set_output(attempt.output)

            final = AgentResult(
                output=attempt.output,
                steps=collected,
                success=attempt.success,
                error=attempt.error,
                metadata={**attempt.metadata, "execution_id": execution_id},
            )
            final = await self._apply_behaviors_after(behavior_ctx, final)
            self._log_execution_end(task, success=final.success, steps=len(collected))
            trace.set_output(final.output)
            return final

        except Exception as exc:
            self._log_execution_error(task, exc)
            fallback = await self._apply_behaviors_on_error(behavior_ctx, exc)
            if fallback is None:
                raise
            return fallback
Execute the task and return a result.
async def stream_execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AsyncIterator[AgentStep]:
    """Stream the wrapped agent's steps unchanged.

    No critique or retry happens on this path: reflection needs a complete
    result, so streaming simply relays what the inner agent yields.
    """
    inner_stream = self.inner_agent.stream_execute(task, context)
    async for inner_step in inner_stream:
        yield inner_step
Execute the task, yielding each step as it completes.
class ChainOfThoughtAgent(BaseAgent):
    """
    Prompts the LLM to reason via a structured ``<thinking>`` scratchpad before
    producing a final answer.

    **When to use:** Tasks requiring structured multi-step reasoning with no
    external tool calls — classification, analysis, summarisation, math, or
    any question answerable from the LLM's own knowledge. Single LLM call:
    lower latency and cost than :class:`ReActAgent`.

    **When NOT to use:** Tasks that require searching a database, calling an
    API, or any live data lookup — use :class:`ReActAgent` instead.

    .. note::
        Passing a ``tool_registry`` to this agent has no effect — tools are
        never invoked. If you need tool calls, use :class:`ReActAgent`.

    The scratchpad is extracted and stored in the returned :class:`AgentStep`
    as the ``thought`` field, while the final answer becomes ``observation``.

    Args:
        model: LLM model name (optional).
        temperature: Sampling temperature (default: 0.1).
        system_prompt: Override the default Chain-of-Thought prompt template.
            Must contain ``{task}`` as a placeholder where the task will be
            inserted.

    All other args inherited from :class:`BaseAgent`.
    """

    def __init__(
        self,
        *args: Any,
        model: Optional[str] = None,
        temperature: float = 0.1,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.model = model
        self.temperature = temperature
        self._system_prompt = system_prompt

    #: Default CoT prompt template used when no ``system_prompt`` is passed.
    #: Must contain ``{task}`` if overriding.
    DEFAULT_SYSTEM_PROMPT: str = _COT_PROMPT

    def _build_prompt(self, task: str) -> str:
        """Insert ``task`` into the configured (or default) prompt template."""
        template = self._system_prompt if self._system_prompt is not None else _COT_PROMPT
        return template.format(task=task)

    def _parse_response(self, text: str) -> tuple[str, str]:
        """Returns (thinking_scratchpad, final_answer)."""
        match = _THINKING_PATTERN.search(text)
        if match:
            thinking = match.group(1).strip()
            # Everything outside the scratchpad block(s) is the answer.
            answer = _THINKING_PATTERN.sub("", text).strip()
        else:
            # No <thinking> block — treat everything as the answer
            thinking = ""
            answer = text.strip()
        return thinking, answer

    async def _save_checkpoint(
        self,
        execution_id: str,
        task: str,
        context: Dict[str, Any],
        thinking: str,
        answer: str,
    ) -> None:
        """Persist the single completed step (no-op without a checkpoint manager)."""
        if not self.checkpoint_manager:
            return
        await self.checkpoint_manager.save(
            agent_id=self.agent_id,
            execution_id=execution_id,
            state={
                "execution_id": execution_id,
                "task": task,
                "context": context,
                "step_number": 0,
                "steps": [
                    {
                        "thought": thinking,
                        "action": "chain_of_thought",
                        "action_input": {"task": task},
                        "observation": answer,
                        "metadata": {},
                    }
                ],
                "thinking": thinking,
                "answer": answer,
            },
            metadata={"step_number": 0},
            ttl=86400,
        )

    async def resume_from(self, execution_id: str) -> AgentResult:
        """Resume (idempotently) from a completed Chain-of-Thought checkpoint.

        No LLM call is made: the stored scratchpad and answer are rebuilt into
        a result and passed through the before/after behaviors.

        Raises:
            ValueError: If no checkpoint manager is configured, or no
                checkpoint exists for ``execution_id``.
        """
        if not self.checkpoint_manager:
            raise ValueError("Checkpoint manager is not configured for this agent")

        checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)
        if checkpoint is None:
            raise ValueError(f"No checkpoint found for execution_id={execution_id}")

        state = checkpoint.state
        thinking = str(state.get("thinking", ""))
        answer = str(state.get("answer", ""))
        task = str(state.get("task", ""))
        context = state.get("context") or {}

        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            last_completed_step=0,
            # Restore the original run context for the behaviors, as execute() does.
            metadata={"resume_agent": self, **context},
        )
        ctx = await self._apply_behaviors_before(ctx)

        step = AgentStep(
            thought=thinking,
            action="chain_of_thought",
            action_input={"task": task},
            observation=answer,
        )
        result = AgentResult(
            output=answer,
            steps=[step],
            success=True,
            metadata={"thinking": thinking, "execution_id": execution_id},
        )
        return await self._apply_behaviors_after(ctx, result)

    # ------------------------------------------------------------------
    # BaseAgent implementation
    # ------------------------------------------------------------------

    async def execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """Make one LLM call and split the reply into scratchpad and answer.

        Args:
            task: The task inserted into the prompt template.
            context: Optional caller context; merged into behavior metadata
                and stored in the checkpoint.

        Returns:
            A single-step :class:`AgentResult`, or the fallback supplied by
            the error behaviors.
        """
        execution_id = str(uuid4())
        run_context = context or {}
        ctx = BehaviorContext(
            agent_id=self.agent_id,
            task=task,
            execution_id=execution_id,
            metadata={"resume_agent": self, **run_context},
        )
        ctx = await self._apply_behaviors_before(ctx)
        self._log_execution_start(task)

        with self._tracer.trace(
            "cot_agent.execute", input=task, metadata={"agent_id": self.agent_id}
        ) as trace:
            try:
                prompt = self._build_prompt(task)

                perf_id = None
                if self._performance_monitor:
                    perf_id = self._performance_monitor.start_request(
                        provider="llm_gateway", model=self.model or "default"
                    )

                with trace.generation("llm_call", model=self.model or "default", input=prompt) as gen:
                    # Only the LLM call itself is guarded here: a failure in the
                    # success bookkeeping below must not close the monitored
                    # request a second time as "failed".
                    try:
                        response = await self.llm_gateway.complete(
                            prompt, model=self.model, temperature=self.temperature
                        )
                    except Exception as exc:
                        if self._performance_monitor and perf_id:
                            self._performance_monitor.end_request(
                                request_id=perf_id,
                                prompt_tokens=0,
                                completion_tokens=0,
                                success=False,
                                error=str(exc),
                            )
                        raise

                    # Tolerate a gateway that reports no usage.
                    usage = response.usage or {}
                    if self._performance_monitor and perf_id:
                        self._performance_monitor.end_request(
                            request_id=perf_id,
                            prompt_tokens=usage.get("prompt_tokens", 0),
                            completion_tokens=usage.get("completion_tokens", 0),
                            success=True,
                        )
                    gen.set_output(response.content)
                    gen.set_token_usage(**usage)

                thinking, answer = self._parse_response(response.content)
                step = AgentStep(
                    thought=thinking,
                    action="chain_of_thought",
                    action_input={"task": task},
                    observation=answer,
                )

                if self.state_store:
                    await self.state_store.set(
                        f"agent:{self.agent_id}:last_steps",
                        [vars(step)],
                        ttl=3600,
                    )

                await self._save_checkpoint(
                    execution_id=execution_id,
                    task=task,
                    context=run_context,
                    thinking=thinking,
                    answer=answer,
                )
                ctx.last_completed_step = 0

                result = AgentResult(
                    output=answer,
                    steps=[step],
                    success=True,
                    metadata={"thinking": thinking, "execution_id": execution_id},
                )
                result = await self._apply_behaviors_after(ctx, result)
                self._log_execution_end(task, success=True, steps=1)
                trace.set_output(answer)
                return result

            except Exception as exc:
                self._log_execution_error(task, exc)
                fallback = await self._apply_behaviors_on_error(ctx, exc)
                if fallback is not None:
                    return fallback
                raise

    async def stream_execute(
        self, task: str, context: Optional[Dict[str, Any]] = None
    ) -> AsyncIterator[AgentStep]:
        """Run :meth:`execute` to completion, then yield its step(s)."""
        result = await self.execute(task, context)
        for step in result.steps:
            yield step
Prompts the LLM to reason via a structured <thinking> scratchpad before
producing a final answer.
When to use: Tasks requiring structured multi-step reasoning with no
external tool calls — classification, analysis, summarisation, math, or
any question answerable from the LLM's own knowledge. Single LLM call:
lower latency and cost than ReActAgent.
When NOT to use: Tasks that require searching a database, calling an
API, or any live data lookup — use ReActAgent instead.
Passing a tool_registry to this agent has no effect — tools are
never invoked. If you need tool calls, use ReActAgent.
The scratchpad is extracted and stored in the returned AgentStep
as the thought field, while the final answer becomes observation.
Args:
model: LLM model name (optional).
temperature: Sampling temperature (default: 0.1).
system_prompt: Override the default Chain-of-Thought prompt template.
Must contain {task} as a placeholder where the task will be
inserted.
All other args inherited from BaseAgent.
def __init__(
    self,
    *args: Any,
    model: Optional[str] = None,
    temperature: float = 0.1,
    system_prompt: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Store the LLM settings and forward all other arguments to the base class.

    Args:
        *args: Positional arguments for the parent initializer.
        model: LLM model name (may be ``None``).
        temperature: Sampling temperature.
        system_prompt: Optional override for the prompt template.
        **kwargs: Keyword arguments for the parent initializer.
    """
    super().__init__(*args, **kwargs)
    self.model, self.temperature = model, temperature
    # Kept private and stored exactly as given (may be None).
    self._system_prompt = system_prompt
async def resume_from(self, execution_id: str) -> AgentResult:
    """Resume (idempotently) from a completed Chain-of-Thought checkpoint.

    No LLM call is made: the stored scratchpad and answer are rebuilt into a
    single-step result and passed through the before/after behaviors.

    Args:
        execution_id: Identifier of the run whose latest checkpoint is used.

    Returns:
        The rebuilt :class:`AgentResult` after the "after" behaviors ran.

    Raises:
        ValueError: If no checkpoint manager is configured, or no checkpoint
            exists for ``execution_id``.
    """
    if not self.checkpoint_manager:
        raise ValueError("Checkpoint manager is not configured for this agent")

    checkpoint = await self.checkpoint_manager.load_latest_for_execution(execution_id)
    if checkpoint is None:
        raise ValueError(f"No checkpoint found for execution_id={execution_id}")

    state = checkpoint.state
    thinking = str(state.get("thinking", ""))
    answer = str(state.get("answer", ""))
    task = str(state.get("task", ""))
    context = state.get("context") or {}

    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        last_completed_step=0,
        # Restore the original run context next to the resume handle so the
        # behaviors see the same metadata as they did in execute().
        metadata={"resume_agent": self, **context},
    )
    ctx = await self._apply_behaviors_before(ctx)

    step = AgentStep(
        thought=thinking,
        action="chain_of_thought",
        action_input={"task": task},
        observation=answer,
    )
    result = AgentResult(
        output=answer,
        steps=[step],
        success=True,
        metadata={"thinking": thinking, "execution_id": execution_id},
    )
    return await self._apply_behaviors_after(ctx, result)
Resume (idempotently) from a completed Chain-of-Thought checkpoint.
async def execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """Make one LLM call and split the reply into scratchpad and answer.

    Args:
        task: The task inserted into the prompt template.
        context: Optional caller context; merged into behavior metadata and
            stored in the checkpoint.

    Returns:
        A single-step :class:`AgentResult`, or the fallback supplied by the
        error behaviors.

    Raises:
        Exception: Whatever the LLM call or post-processing raised, when no
            error behavior supplies a fallback.
    """
    execution_id = str(uuid4())
    run_context = context or {}
    ctx = BehaviorContext(
        agent_id=self.agent_id,
        task=task,
        execution_id=execution_id,
        metadata={"resume_agent": self, **run_context},
    )
    ctx = await self._apply_behaviors_before(ctx)
    self._log_execution_start(task)

    with self._tracer.trace(
        "cot_agent.execute", input=task, metadata={"agent_id": self.agent_id}
    ) as trace:
        try:
            prompt = self._build_prompt(task)

            perf_id = None
            if self._performance_monitor:
                perf_id = self._performance_monitor.start_request(
                    provider="llm_gateway", model=self.model or "default"
                )

            with trace.generation("llm_call", model=self.model or "default", input=prompt) as gen:
                # Only the LLM call itself is guarded here: a failure in the
                # success bookkeeping below must not close the monitored
                # request a second time as "failed".
                try:
                    response = await self.llm_gateway.complete(
                        prompt, model=self.model, temperature=self.temperature
                    )
                except Exception as exc:
                    if self._performance_monitor and perf_id:
                        self._performance_monitor.end_request(
                            request_id=perf_id,
                            prompt_tokens=0,
                            completion_tokens=0,
                            success=False,
                            error=str(exc),
                        )
                    raise

                # Tolerate a gateway that reports no usage.
                usage = response.usage or {}
                if self._performance_monitor and perf_id:
                    self._performance_monitor.end_request(
                        request_id=perf_id,
                        prompt_tokens=usage.get("prompt_tokens", 0),
                        completion_tokens=usage.get("completion_tokens", 0),
                        success=True,
                    )
                gen.set_output(response.content)
                gen.set_token_usage(**usage)

            thinking, answer = self._parse_response(response.content)
            step = AgentStep(
                thought=thinking,
                action="chain_of_thought",
                action_input={"task": task},
                observation=answer,
            )

            if self.state_store:
                # Keep the latest step available for an hour.
                await self.state_store.set(
                    f"agent:{self.agent_id}:last_steps",
                    [vars(step)],
                    ttl=3600,
                )

            await self._save_checkpoint(
                execution_id=execution_id,
                task=task,
                context=run_context,
                thinking=thinking,
                answer=answer,
            )
            ctx.last_completed_step = 0

            result = AgentResult(
                output=answer,
                steps=[step],
                success=True,
                metadata={"thinking": thinking, "execution_id": execution_id},
            )
            result = await self._apply_behaviors_after(ctx, result)
            self._log_execution_end(task, success=True, steps=1)
            trace.set_output(answer)
            return result

        except Exception as exc:
            self._log_execution_error(task, exc)
            fallback = await self._apply_behaviors_on_error(ctx, exc)
            if fallback is not None:
                return fallback
            raise
Execute the task and return a result.
async def stream_execute(
    self, task: str, context: Optional[Dict[str, Any]] = None
) -> AsyncIterator[AgentStep]:
    """Yield the steps of a completed :meth:`execute` run, in order.

    The whole run is awaited first; steps are only yielded afterwards, so
    nothing is emitted while the underlying execution is still in progress.

    Args:
        task: The task text forwarded to :meth:`execute`.
        context: Optional caller context forwarded to :meth:`execute`.
    """
    finished = await self.execute(task, context)
    for completed_step in finished.steps:
        yield completed_step
Execute the task, yielding each step as it completes.
class BaseRouter(ABC):
    """
    Common base for routers that pick a target agent for a request.

    Subclasses implement :meth:`route`; this class only wires up the shared
    observability helpers.

    Args:
        logger: Optional :class:`BasicLogger`. When omitted, one named after
            the concrete router class is created.
        metrics: Optional :class:`BasicMetricsCollector`.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        if not logger:
            logger = BasicLogger(f"gmf_forge_ai.router.{self.__class__.__name__}")
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer if tracer else get_tracer()

    @abstractmethod
    async def route(self, request: RoutingRequest) -> RoutingDecision:
        """Select a target agent for the given request."""

    def _record_decision(self, decision: RoutingDecision) -> None:
        """Log ``decision`` and, when metrics are configured, count it."""
        self._logger.info(
            "Routing decision",
            target=decision.target,
            confidence=decision.confidence,
            reasoning=decision.reasoning,
        )
        if not self._metrics:
            return
        self._metrics.increment(
            "router.decisions",
            router=self.__class__.__name__,
            target=decision.target,
        )
Abstract router — selects the appropriate agent for a given request.
Args:
logger: Optional BasicLogger.
metrics: Optional BasicMetricsCollector.
tracer: Optional TracingProvider. Falls back to get_tracer().
@dataclass
class RoutingRequest:
    """Input to a router: the text to route plus the agents it may choose from.

    ``context`` and ``metadata`` default to a fresh empty dict per instance.
    """

    # Text the routers inspect (embedded, pattern-matched or sent to an LLM).
    input: str
    # Names of the agents a router is allowed to select.
    available_agents: List[str]
    # Free-form caller-supplied mappings.
    context: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
Encapsulates a routing decision request.
@dataclass
class RoutingDecision:
    """Outcome of a routing call: the chosen agent and why it was chosen.

    ``metadata`` defaults to a fresh empty dict per instance.
    """

    # Name of the selected agent.
    target: str
    # Score reported by the router that produced the decision.
    confidence: float
    # Human-readable explanation of the choice.
    reasoning: str
    metadata: Dict[str, Any] = field(default_factory=dict)
The result of a routing decision.
class LLMRouter(BaseRouter):
    """
    Uses the :class:`UnifiedLLMGateway` to select an agent based on descriptions.

    The LLM is shown one line per available agent and is expected to answer
    with a JSON object containing ``target``, ``confidence`` and ``reasoning``.
    Any response that cannot be turned into a decision falls back to
    ``fallback_target`` (or the first available agent).

    Args:
        llm_gateway: The LLM gateway to use for routing.
        agent_descriptions: Mapping of agent name → description shown to the LLM.
        model: LLM model name (optional).
        fallback_target: Agent to use if the LLM response cannot be parsed.
        logger, metrics, tracer: Observability (optional).

    Example::

        router = LLMRouter(
            llm_gateway=gateway,
            agent_descriptions={
                "search_agent": "Searches the web for information.",
                "code_agent": "Writes and explains code.",
            },
            fallback_target="search_agent",
        )
    """

    def __init__(
        self,
        llm_gateway: Any,
        agent_descriptions: Optional[Dict[str, str]] = None,
        model: Optional[str] = None,
        fallback_target: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self.llm_gateway = llm_gateway
        self.agent_descriptions: Dict[str, str] = agent_descriptions or {}
        self.model = model
        self.fallback_target = fallback_target

    async def route(self, request: RoutingRequest) -> RoutingDecision:
        """Ask the LLM to pick one of ``request.available_agents``.

        Returns:
            A decision whose target is always either an available agent or
            the result of :meth:`_fallback`. Confidence is ``0.0`` whenever
            the response could not be parsed.
        """
        # One "- name: description" line per available agent.
        desc_block = "\n".join(
            f"- {name}: {self.agent_descriptions.get(name, 'No description provided.')}"
            for name in dict.fromkeys(request.available_agents)
        )
        prompt = _ROUTER_PROMPT.format(agent_descriptions=desc_block, input=request.input)

        with self._tracer.trace("llm_router.route", input=request.input) as trace:
            response = await self.llm_gateway.complete(prompt, model=self.model, temperature=0.0)
            raw = response.content.strip()

            # Extract JSON even if wrapped in markdown
            match = re.search(r"\{.*\}", raw, re.DOTALL)
            decision: RoutingDecision
            if match:
                try:
                    data = json.loads(match.group())
                    target = data.get("target", "")
                    if target not in request.available_agents:
                        target = self._fallback(request)
                    decision = RoutingDecision(
                        target=target,
                        confidence=float(data.get("confidence", 0.5)),
                        reasoning=str(data.get("reasoning", "")),
                    )
                except (ValueError, TypeError):
                    # ValueError covers json.JSONDecodeError and float("abc");
                    # TypeError covers float(None) / float([...]) when the LLM
                    # emits a null or structured "confidence".
                    decision = RoutingDecision(
                        target=self._fallback(request),
                        confidence=0.0,
                        reasoning="LLM response could not be parsed.",
                    )
            else:
                decision = RoutingDecision(
                    target=self._fallback(request),
                    confidence=0.0,
                    reasoning="No JSON found in LLM response.",
                )

            self._record_decision(decision)
            trace.set_output({"target": decision.target, "confidence": decision.confidence})
            return decision

    def _fallback(self, request: RoutingRequest) -> str:
        """Return ``fallback_target`` if available, else the first agent, else ``""``."""
        if self.fallback_target and self.fallback_target in request.available_agents:
            return self.fallback_target
        return request.available_agents[0] if request.available_agents else ""
Uses the UnifiedLLMGateway to select an agent based on descriptions.
Args: llm_gateway: The LLM gateway to use for routing. agent_descriptions: Mapping of agent name → description shown to the LLM. model: LLM model name (optional). fallback_target: Agent to use if the LLM response cannot be parsed. logger, metrics, tracer: Observability (optional).
Example::
router = LLMRouter(
llm_gateway=gateway,
agent_descriptions={
"search_agent": "Searches the web for information.",
"code_agent": "Writes and explains code.",
},
fallback_target="search_agent",
)
def __init__(
    self,
    llm_gateway: Any,
    agent_descriptions: Optional[Dict[str, str]] = None,
    model: Optional[str] = None,
    fallback_target: Optional[str] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the gateway and routing configuration.

    Args:
        llm_gateway: Object whose ``complete`` coroutine is used for routing.
        agent_descriptions: Agent name → description; empty when omitted.
        model: Optional model name passed through to the gateway.
        fallback_target: Agent used when the LLM answer is unusable.
        logger, metrics, tracer: Forwarded to the base router.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)
    self.model = model
    self.fallback_target = fallback_target
    self.llm_gateway = llm_gateway
    descriptions: Dict[str, str] = agent_descriptions if agent_descriptions else {}
    self.agent_descriptions = descriptions
async def route(self, request: RoutingRequest) -> RoutingDecision:
    """Ask the LLM to pick one of ``request.available_agents``.

    The prompt lists one ``- name: description`` line per available agent.
    The first ``{...}`` span in the response is parsed as JSON and read for
    ``target``, ``confidence`` and ``reasoning``.

    Returns:
        A decision whose target is always either an available agent or the
        result of ``_fallback``. Confidence is ``0.0`` whenever the response
        could not be parsed.
    """
    # Build description block from provided map + request.available_agents
    descriptions = {
        name: self.agent_descriptions.get(name, "No description provided.")
        for name in request.available_agents
    }
    desc_block = "\n".join(f"- {name}: {desc}" for name, desc in descriptions.items())
    prompt = _ROUTER_PROMPT.format(agent_descriptions=desc_block, input=request.input)

    with self._tracer.trace("llm_router.route", input=request.input) as trace:
        response = await self.llm_gateway.complete(prompt, model=self.model, temperature=0.0)
        raw = response.content.strip()

        # Extract JSON even if wrapped in markdown
        match = re.search(r"\{.*\}", raw, re.DOTALL)
        decision: RoutingDecision
        if match:
            try:
                data = json.loads(match.group())
                target = data.get("target", "")
                if target not in request.available_agents:
                    target = self._fallback(request)
                decision = RoutingDecision(
                    target=target,
                    confidence=float(data.get("confidence", 0.5)),
                    reasoning=str(data.get("reasoning", "")),
                )
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError and float("abc");
                # TypeError covers float(None) / float([...]) when the LLM
                # emits a null or structured "confidence".
                decision = RoutingDecision(
                    target=self._fallback(request),
                    confidence=0.0,
                    reasoning="LLM response could not be parsed.",
                )
        else:
            decision = RoutingDecision(
                target=self._fallback(request),
                confidence=0.0,
                reasoning="No JSON found in LLM response.",
            )

        self._record_decision(decision)
        trace.set_output({"target": decision.target, "confidence": decision.confidence})
        return decision
Select a target agent for the given request.
class SemanticRouter(BaseRouter):
    """
    Routes by measuring cosine similarity between the embedded input and
    pre-computed route descriptor embeddings.

    The caller provides an ``embed_fn`` — any async function that accepts a
    string and returns ``List[float]``. This keeps the router independent of
    any specific embedding provider.

    Args:
        embed_fn: ``async (text: str) -> List[float]``.
        route_embeddings: Mapping of agent name → pre-computed embedding vector.
            If not provided upfront, call :meth:`add_route` to register routes.
        route_descriptions: Optional text descriptions (stored for reference).
        fallback_target: Agent to use if similarity is below threshold.
        similarity_threshold: Minimum cosine similarity to accept a route (default 0.0).
        logger, metrics, tracer: Observability (optional).

    Example::

        router = SemanticRouter(embed_fn=my_embed)
        await router.add_route("search_agent", "web search and information retrieval")
        await router.add_route("code_agent", "code generation debugging and review")
        decision = await router.route(request)
    """

    def __init__(
        self,
        embed_fn: Any,
        route_embeddings: Optional[Dict[str, List[float]]] = None,
        route_descriptions: Optional[Dict[str, str]] = None,
        fallback_target: Optional[str] = None,
        similarity_threshold: float = 0.0,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._embed_fn = embed_fn
        self._route_embeddings: Dict[str, List[float]] = route_embeddings or {}
        self._route_descriptions: Dict[str, str] = route_descriptions or {}
        self.fallback_target = fallback_target
        self.similarity_threshold = similarity_threshold

    async def add_route(self, agent_name: str, description: str) -> None:
        """Embed ``description`` and store it as the routing vector for ``agent_name``."""
        embedding = await self._embed_fn(description)
        self._route_embeddings[agent_name] = embedding
        self._route_descriptions[agent_name] = description

    async def route(self, request: RoutingRequest) -> RoutingDecision:
        """Pick the available agent whose route embedding is closest to the input.

        Only agents that are both in ``request.available_agents`` and have a
        registered embedding are considered. If none qualify, or the best
        score is below ``similarity_threshold``, the fallback is used.

        Returns:
            The routing decision. ``confidence`` is the best cosine score
            rounded to 4 places, or ``0.0`` when there were no candidates.
        """
        candidates = [a for a in request.available_agents if a in self._route_embeddings]

        with self._tracer.trace("semantic_router.route", input=request.input) as trace:
            if not candidates:
                fallback = self.fallback_target or (
                    request.available_agents[0] if request.available_agents else ""
                )
                decision = RoutingDecision(
                    target=fallback,
                    confidence=0.0,
                    reasoning="No route embeddings available for candidates.",
                )
                self._record_decision(decision)
                trace.set_output({"target": decision.target})
                return decision

            input_embedding: List[float] = await self._embed_fn(request.input)

            # Start from -inf with a real candidate so a best score of exactly
            # -1.0 can never leave the target as an empty string.
            best_target = candidates[0]
            best_score = float("-inf")
            for agent_name in candidates:
                score = _cosine(input_embedding, self._route_embeddings[agent_name])
                if score > best_score:
                    best_score = score
                    best_target = agent_name

            if best_score < self.similarity_threshold:
                # Explain the fallback instead of claiming the fallback agent
                # was the closest match.
                reasoning = (
                    f"Best cosine similarity ({best_score:.4f}) for '{best_target}' "
                    f"is below threshold ({self.similarity_threshold}); "
                    "using fallback target."
                )
                best_target = self.fallback_target or candidates[0]
            else:
                desc = self._route_descriptions.get(best_target, "")
                reasoning = f"Highest cosine similarity ({best_score:.4f}) to '{desc}'"

            decision = RoutingDecision(
                target=best_target,
                confidence=round(best_score, 4),
                reasoning=reasoning,
            )
            self._record_decision(decision)
            trace.set_output({"target": decision.target, "confidence": decision.confidence})
            return decision
Routes by measuring cosine similarity between the embedded input and pre-computed route descriptor embeddings.
The caller provides an embed_fn — any async function that accepts a
string and returns List[float]. This keeps the router independent of
any specific embedding provider.
Args:
embed_fn: async (text: str) -> List[float].
route_embeddings: Mapping of agent name → pre-computed embedding vector.
If not provided upfront, call add_route() to register routes.
route_descriptions: Optional text descriptions (stored for reference).
fallback_target: Agent to use if similarity is below threshold.
similarity_threshold: Minimum cosine similarity to accept a route (default 0.0).
logger, metrics, tracer: Observability (optional).
Example::
router = SemanticRouter(embed_fn=my_embed)
await router.add_route("search_agent", "web search and information retrieval")
await router.add_route("code_agent", "code generation debugging and review")
decision = await router.route(request)
def __init__(
    self,
    embed_fn: Any,
    route_embeddings: Optional[Dict[str, List[float]]] = None,
    route_descriptions: Optional[Dict[str, str]] = None,
    fallback_target: Optional[str] = None,
    similarity_threshold: float = 0.0,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the embedding function, known routes and fallback settings.

    Args:
        embed_fn: Awaitable callable used to embed descriptions and inputs.
        route_embeddings: Agent name → embedding vector; empty when omitted.
        route_descriptions: Agent name → description; empty when omitted.
        fallback_target: Agent used when no route is accepted.
        similarity_threshold: Minimum score for a route to be accepted.
        logger, metrics, tracer: Forwarded to the base router.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)
    self.fallback_target = fallback_target
    self.similarity_threshold = similarity_threshold
    self._embed_fn = embed_fn
    embeddings: Dict[str, List[float]] = route_embeddings if route_embeddings else {}
    descriptions: Dict[str, str] = route_descriptions if route_descriptions else {}
    self._route_embeddings = embeddings
    self._route_descriptions = descriptions
async def add_route(self, agent_name: str, description: str) -> None:
    """Register (or replace) the route for ``agent_name``.

    ``description`` is embedded with the configured ``embed_fn``; nothing is
    stored if that call raises.
    """
    vector = await self._embed_fn(description)
    self._route_embeddings[agent_name] = vector
    self._route_descriptions[agent_name] = description
Embed description and store it as the routing vector for agent_name.
async def route(self, request: RoutingRequest) -> RoutingDecision:
    """Pick the available agent whose route embedding is closest to the input.

    Only agents that are both in ``request.available_agents`` and have a
    registered embedding are considered. If none qualify, or the best score
    is below ``similarity_threshold``, the fallback is used.

    Returns:
        The routing decision. ``confidence`` is the best cosine score rounded
        to 4 places, or ``0.0`` when there were no candidates.
    """
    candidates = [a for a in request.available_agents if a in self._route_embeddings]

    with self._tracer.trace("semantic_router.route", input=request.input) as trace:
        if not candidates:
            fallback = self.fallback_target or (
                request.available_agents[0] if request.available_agents else ""
            )
            decision = RoutingDecision(
                target=fallback,
                confidence=0.0,
                reasoning="No route embeddings available for candidates.",
            )
            self._record_decision(decision)
            trace.set_output({"target": decision.target})
            return decision

        input_embedding: List[float] = await self._embed_fn(request.input)

        # Start from -inf with a real candidate so a best score of exactly
        # -1.0 can never leave the target as an empty string.
        best_target = candidates[0]
        best_score = float("-inf")
        for agent_name in candidates:
            score = _cosine(input_embedding, self._route_embeddings[agent_name])
            if score > best_score:
                best_score = score
                best_target = agent_name

        if best_score < self.similarity_threshold:
            # Explain the fallback instead of claiming the fallback agent was
            # the closest match.
            reasoning = (
                f"Best cosine similarity ({best_score:.4f}) for '{best_target}' "
                f"is below threshold ({self.similarity_threshold}); "
                "using fallback target."
            )
            best_target = self.fallback_target or candidates[0]
        else:
            desc = self._route_descriptions.get(best_target, "")
            reasoning = f"Highest cosine similarity ({best_score:.4f}) to '{desc}'"

        decision = RoutingDecision(
            target=best_target,
            confidence=round(best_score, 4),
            reasoning=reasoning,
        )
        self._record_decision(decision)
        trace.set_output({"target": decision.target, "confidence": decision.confidence})
        return decision
Select a target agent for the given request.
class RuleBasedRouter(BaseRouter):
    """
    Routes based on an ordered list of ``(condition_fn, target_agent)`` rules.

    Rules are evaluated in order; the first rule whose condition returns ``True``
    wins. Rules whose target is not in the request's available agents are
    skipped without evaluating their condition. Falls back to
    ``fallback_target`` if no rule matches.

    Args:
        rules: Ordered list of ``(condition_fn, agent_name)`` tuples.
            Conditions may be sync or async callables accepting a
            :class:`RoutingRequest` and returning ``bool``.
        fallback_target: Agent to use if no rule matches.
        logger, metrics, tracer: Observability (optional).

    Example::

        router = RuleBasedRouter(
            rules=[
                (lambda r: "code" in r.input.lower(), "code_agent"),
                (lambda r: "search" in r.input.lower(), "search_agent"),
            ],
            fallback_target="general_agent",
        )
    """

    def __init__(
        self,
        rules: Optional[List[Tuple[ConditionFn, str]]] = None,
        fallback_target: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._rules: List[Tuple[ConditionFn, str]] = rules or []
        self.fallback_target = fallback_target

    def add_rule(self, condition: ConditionFn, target: str) -> None:
        """Append a rule to the end of the list."""
        self._rules.append((condition, target))

    async def route(self, request: RoutingRequest) -> RoutingDecision:
        """Return the decision for the first matching rule, else the fallback.

        Returns:
            A decision with confidence ``1.0`` for a matched rule, or ``0.0``
            when the fallback (``fallback_target``, else the first available
            agent, else ``""``) is used.
        """
        # Local import kept from the original; the unused asyncio import is gone.
        import inspect

        with self._tracer.trace("rule_router.route", input=request.input) as trace:
            for i, (condition, target) in enumerate(self._rules):
                if target not in request.available_agents:
                    continue
                result = condition(request)
                # Conditions may be coroutines; await them before testing.
                if inspect.isawaitable(result):
                    matched: bool = await result
                else:
                    matched = bool(result)

                if matched:
                    decision = RoutingDecision(
                        target=target,
                        confidence=1.0,
                        reasoning=f"Rule {i} matched.",
                    )
                    self._record_decision(decision)
                    trace.set_output({"target": target, "rule_index": i})
                    return decision

            # No rule matched — use fallback
            fallback = self.fallback_target or (
                request.available_agents[0] if request.available_agents else ""
            )
            decision = RoutingDecision(
                target=fallback,
                confidence=0.0,
                reasoning="No rule matched; using fallback target.",
            )
            self._record_decision(decision)
            trace.set_output({"target": fallback, "fallback": True})
            return decision
Routes based on an ordered list of (condition_fn, target_agent) rules.
Rules are evaluated in order; the first rule whose condition returns True
wins. Falls back to fallback_target if no rule matches.
Args:
rules: Ordered list of (condition_fn, agent_name) tuples.
Conditions may be sync or async callables accepting a
RoutingRequest and returning bool.
fallback_target: Agent to use if no rule matches.
logger, metrics, tracer: Observability (optional).
Example::
router = RuleBasedRouter(
rules=[
(lambda r: "code" in r.input.lower(), "code_agent"),
(lambda r: "search" in r.input.lower(), "search_agent"),
],
fallback_target="general_agent",
)
def __init__(
    self,
    rules: Optional[List[Tuple[ConditionFn, str]]] = None,
    fallback_target: Optional[str] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the ordered rule list and the fallback agent.

    Args:
        rules: Ordered ``(condition_fn, agent_name)`` pairs; empty when omitted.
        fallback_target: Agent used when no rule matches.
        logger, metrics, tracer: Forwarded to the base router.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)
    self.fallback_target = fallback_target
    rule_list: List[Tuple[ConditionFn, str]] = rules if rules else []
    self._rules = rule_list
def add_rule(self, condition: ConditionFn, target: str) -> None:
    """Register ``condition`` → ``target`` as the lowest-priority rule."""
    new_rule = (condition, target)
    self._rules.append(new_rule)
Append a rule to the end of the list.
async def route(self, request: RoutingRequest) -> RoutingDecision:
    """Return the decision for the first matching rule, else the fallback.

    Rules are tried in registration order. A rule whose target is not in
    ``request.available_agents`` is skipped without calling its condition.
    Conditions may return a plain value or an awaitable.

    Returns:
        A decision with confidence ``1.0`` for a matched rule, or ``0.0`` when
        the fallback (``fallback_target``, else the first available agent,
        else ``""``) is used.
    """
    # Local import kept from the original; the unused asyncio import is gone.
    import inspect

    with self._tracer.trace("rule_router.route", input=request.input) as trace:
        for i, (condition, target) in enumerate(self._rules):
            if target not in request.available_agents:
                continue
            result = condition(request)
            # Conditions may be coroutines; await them before testing.
            if inspect.isawaitable(result):
                matched: bool = await result
            else:
                matched = bool(result)

            if matched:
                decision = RoutingDecision(
                    target=target,
                    confidence=1.0,
                    reasoning=f"Rule {i} matched.",
                )
                self._record_decision(decision)
                trace.set_output({"target": target, "rule_index": i})
                return decision

        # No rule matched — use fallback
        fallback = self.fallback_target or (
            request.available_agents[0] if request.available_agents else ""
        )
        decision = RoutingDecision(
            target=fallback,
            confidence=0.0,
            reasoning="No rule matched; using fallback target.",
        )
        self._record_decision(decision)
        trace.set_output({"target": fallback, "fallback": True})
        return decision
Select a target agent for the given request.
class LoadBalancingRouter(BaseRouter):
    """
    Routes in round-robin order across the available agent list.

    Call counts are tracked in an :class:`InMemoryStateStore` so the counter
    survives across multiple ``route()`` calls within the same process.

    Args:
        agent_weights: Optional mapping of agent name → relative weight for
            weighted round-robin. All weights default to 1 if not specified.
            A weight of 0 (or less) removes the agent from the rotation; if
            that would leave nothing to choose from, the weights are ignored
            for that call.
        logger, metrics, tracer: Observability (optional).

    Example::

        router = LoadBalancingRouter()
        # or with weights:
        router = LoadBalancingRouter(agent_weights={"heavy_agent": 2, "light_agent": 1})
    """

    def __init__(
        self,
        agent_weights: Optional[Dict[str, int]] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._weights = agent_weights or {}
        self._store = InMemoryStateStore()
        # Not read by this class; kept so the attribute stays available.
        self._call_total: int = 0

    def _expand_agents(self, agents: List[str]) -> List[str]:
        """Expand agents by weight into a weighted list for round-robin."""
        expanded: List[str] = []
        for agent in agents:
            weight = self._weights.get(agent, 1)
            expanded.extend([agent] * weight)
        return expanded

    async def route(self, request: RoutingRequest) -> RoutingDecision:
        """Select the next agent in (weighted) round-robin order.

        Returns:
            A decision with confidence ``1.0`` and the 1-based call number in
            ``metadata["call_count"]``. With no available agents the target
            is ``""`` with confidence ``0.0``, and nothing is recorded.
        """
        agents = request.available_agents
        if not agents:
            return RoutingDecision(
                target="", confidence=0.0, reasoning="No agents available."
            )

        with self._tracer.trace("lb_router.route", input=request.input) as trace:
            # Load counter from store (survives if store is shared)
            counter: int = await self._store.get(_COUNTER_KEY) or 0
            # If every available agent has weight <= 0 the expansion is empty;
            # fall back to the plain list rather than dividing by zero.
            expanded = self._expand_agents(agents) or list(agents)
            target = expanded[counter % len(expanded)]
            counter += 1
            await self._store.set(_COUNTER_KEY, counter)

            if self._metrics:
                self._metrics.increment("router.lb.calls", target=target)

            decision = RoutingDecision(
                target=target,
                confidence=1.0,
                reasoning=f"Round-robin selection (call #{counter}).",
                metadata={"call_count": counter},
            )
            self._record_decision(decision)
            trace.set_output({"target": target, "call_count": counter})
            return decision
Routes in round-robin order across the available agent list.
Call counts are tracked in an InMemoryStateStore so the counter
survives across multiple route() calls within the same process.
Args: agent_weights: Optional mapping of agent name → relative weight for weighted round-robin. All weights default to 1 if not specified. logger, metrics, tracer: Observability (optional).
Example::
router = LoadBalancingRouter()
# or with weights:
router = LoadBalancingRouter(agent_weights={"heavy_agent": 2, "light_agent": 1})
def __init__(
    self,
    agent_weights: Optional[Dict[str, int]] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Set up the weights and a private in-memory store.

    Args:
        agent_weights: Agent name → relative weight; empty when omitted.
        logger, metrics, tracer: Forwarded to the base router.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)
    self._call_total: int = 0
    self._store = InMemoryStateStore()
    self._weights = agent_weights if agent_weights else {}
async def route(self, request: RoutingRequest) -> RoutingDecision:
    """Select the next agent in (weighted) round-robin order.

    The call counter is read from and written back to the router's store on
    every call; the chosen index is the counter modulo the weighted list.

    Returns:
        A decision with confidence ``1.0`` and the 1-based call number in
        ``metadata["call_count"]``. With no available agents the target is
        ``""`` with confidence ``0.0``, and nothing is recorded.
    """
    agents = request.available_agents
    if not agents:
        return RoutingDecision(
            target="", confidence=0.0, reasoning="No agents available."
        )

    with self._tracer.trace("lb_router.route", input=request.input) as trace:
        # Load counter from store (survives if store is shared)
        counter: int = await self._store.get(_COUNTER_KEY) or 0
        # If every available agent has weight <= 0 the expansion is empty;
        # fall back to the plain list rather than dividing by zero.
        expanded = self._expand_agents(agents) or list(agents)
        target = expanded[counter % len(expanded)]
        counter += 1
        await self._store.set(_COUNTER_KEY, counter)

        if self._metrics:
            self._metrics.increment("router.lb.calls", target=target)

        decision = RoutingDecision(
            target=target,
            confidence=1.0,
            reasoning=f"Round-robin selection (call #{counter}).",
            metadata={"call_count": counter},
        )
        self._record_decision(decision)
        trace.set_output({"target": target, "call_count": counter})
        return decision
Select a target agent for the given request.
class BaseWorkflow(ABC):
    """
    Common base for workflow engines.

    Subclasses implement :meth:`run`; this class provides the shared logger,
    metrics collector, tracer and per-node logging helpers.

    Args:
        logger: Optional :class:`BasicLogger`. When omitted, one named after
            the concrete workflow class is created.
        metrics: Optional :class:`BasicMetricsCollector`.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        if not logger:
            logger = BasicLogger(f"gmf_forge_ai.workflow.{self.__class__.__name__}")
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer if tracer else get_tracer()

    @abstractmethod
    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Execute the workflow and return the aggregated result."""

    def _log_node_start(self, node_id: str) -> None:
        """Log that ``node_id`` has started."""
        self._logger.info("Workflow node started", node_id=node_id)

    def _log_node_end(self, node_id: str, success: bool) -> None:
        """Log that ``node_id`` finished and count it when metrics are set."""
        self._logger.info("Workflow node finished", node_id=node_id, success=success)
        if not self._metrics:
            return
        # The success flag is passed as a string label.
        self._metrics.increment(
            "workflow.nodes_executed", node_id=node_id, success=str(success)
        )
Abstract base class for all workflow engines.
Args:
logger: Optional BasicLogger.
metrics: Optional BasicMetricsCollector.
tracer: Optional TracingProvider. Falls back to get_tracer().
@dataclass
class WorkflowNode:
    """One executable step of a workflow: an agent plus its input/output wiring.

    ``inputs_map``, ``outputs_map`` and ``metadata`` each default to a fresh
    empty dict per instance.
    """

    # Identifier used to reference this node from edges and results.
    node_id: str
    # The agent executed when the node runs.
    agent: "BaseAgent"
    inputs_map: Dict[str, str] = field(default_factory=dict)
    """Maps node input keys to keys from the initial_input dict or prior node outputs."""
    outputs_map: Dict[str, str] = field(default_factory=dict)
    """Renames node output keys before storing in the accumulated output dict."""
    metadata: Dict[str, Any] = field(default_factory=dict)
A single node in a workflow graph.
@dataclass
class WorkflowEdge:
    """A dependency from one workflow node to another, optionally guarded."""

    source: str
    """node_id of the source node."""
    target: str
    """node_id of the target node."""
    condition: Optional[Callable[["AgentResult"], bool]] = None
    """Optional guard — edge is only traversed if this returns True."""
A directed edge between two workflow nodes.
@dataclass
class WorkflowResult:
    """Aggregated outcome of one workflow run.

    Every field has a default, so ``WorkflowResult()`` is an empty successful
    result and ``WorkflowResult(success=False, error=...)`` a bare failure.
    """

    outputs: Dict[str, "AgentResult"] = field(default_factory=dict)
    """Keyed by node_id."""
    # Output selected by the engine as the workflow's overall answer.
    final_output: str = ""
    success: bool = True
    # Error text for a failed run; ``None`` otherwise.
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
The aggregated result of a workflow run.
class DAGWorkflow(BaseWorkflow):
    """
    Executes workflow nodes level by level in topological order.

    The ordering is computed by this class itself (in-degree counting); there
    is no ``networkx`` dependency. Nodes with no mutual dependency are
    executed concurrently with ``asyncio.gather``.

    Note:
        :attr:`WorkflowEdge.condition` is not consulted by this class; every
        edge is treated as an unconditional dependency. Edges that reference
        an unknown node id are ignored.

    Args:
        nodes: List of :class:`WorkflowNode` objects.
        edges: List of :class:`WorkflowEdge` objects.
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = DAGWorkflow(
            nodes=[search_node, summarise_node],
            edges=[WorkflowEdge(source="search", target="summarise")],
        )
        result = await wf.run({"query": "AI trends 2026"})
    """

    def __init__(
        self,
        nodes: Optional[List[WorkflowNode]] = None,
        edges: Optional[List[WorkflowEdge]] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._nodes: Dict[str, WorkflowNode] = {n.node_id: n for n in (nodes or [])}
        self._edges: List[WorkflowEdge] = edges or []

    def add_node(self, node: WorkflowNode) -> None:
        """Register ``node``, replacing any node with the same ``node_id``."""
        self._nodes[node.node_id] = node

    def add_edge(self, edge: WorkflowEdge) -> None:
        """Append ``edge`` to the dependency list."""
        self._edges.append(edge)

    # ------------------------------------------------------------------
    # Topological sort (no networkx dependency — implemented manually)
    # ------------------------------------------------------------------

    def _build_adjacency(self) -> tuple[Dict[str, List[str]], Dict[str, int]]:
        """Returns (successors dict, in-degree dict)."""
        successors: Dict[str, List[str]] = {nid: [] for nid in self._nodes}
        in_degree: Dict[str, int] = {nid: 0 for nid in self._nodes}
        for edge in self._edges:
            # Only edges whose endpoints are both registered nodes count.
            if edge.source in successors and edge.target in self._nodes:
                successors[edge.source].append(edge.target)
                in_degree[edge.target] += 1
        return successors, in_degree

    def _topological_levels(self) -> List[List[str]]:
        """
        Returns node IDs grouped into levels. All nodes in a level can run in
        parallel because none depend on each other.

        Raises:
            ValueError: If some nodes can never become ready (a cycle).
        """
        successors, in_degree = self._build_adjacency()
        levels: List[List[str]] = []
        ready: List[str] = [nid for nid, deg in in_degree.items() if deg == 0]

        while ready:
            levels.append(list(ready))
            next_ready: List[str] = []
            for nid in ready:
                for successor in successors[nid]:
                    in_degree[successor] -= 1
                    if in_degree[successor] == 0:
                        next_ready.append(successor)
            ready = next_ready

        # Nodes on a cycle never reach in-degree 0, so they are never scheduled.
        if sum(len(lvl) for lvl in levels) < len(self._nodes):
            raise ValueError("DAGWorkflow: cycle detected in the workflow graph.")

        return levels

    # ------------------------------------------------------------------
    # Input resolution
    # ------------------------------------------------------------------

    def _resolve_inputs(
        self,
        node: WorkflowNode,
        initial_input: Dict[str, Any],
        all_outputs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build the input dict for a node from initial_input and prior outputs.

        Without an ``inputs_map`` the node sees everything (prior outputs win
        on key clashes). With one, only the mapped keys are provided, and a
        missing source key resolves to ``None``.
        """
        merged = {**initial_input, **all_outputs}
        if not node.inputs_map:
            return merged
        return {
            local_key: merged.get(source_key)
            for local_key, source_key in node.inputs_map.items()
        }

    # ------------------------------------------------------------------
    # BaseWorkflow implementation
    # ------------------------------------------------------------------

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Run every node, level by level, and aggregate the results.

        Args:
            initial_input: Values available to all nodes; ``task`` or
                ``query`` is used as the task text when present.

        Returns:
            A successful :class:`WorkflowResult` whose ``final_output`` is the
            output of the last node in the last level, or a failed result
            (``success=False``, ``error`` set) if the graph has a cycle or any
            node raises. Exceptions are not propagated to the caller.
        """
        self._logger.info("DAGWorkflow started", node_count=len(self._nodes))

        if self._metrics:
            self._metrics.increment("workflow.runs", workflow="dag")

        with self._tracer.trace(
            "dag_workflow.run",
            input=str(initial_input),
            metadata={"node_count": len(self._nodes)},
        ) as trace:
            try:
                levels = self._topological_levels()
                all_outputs: Dict[str, Any] = {}
                node_results: Dict[str, Any] = {}

                # Defined once, outside the level loop; it reads all_outputs,
                # which is only updated between levels.
                async def _run_node(node_id: str) -> tuple[str, Any]:
                    node = self._nodes[node_id]
                    task_input = self._resolve_inputs(node, initial_input, all_outputs)
                    # Fall back to the whole input dict when neither key holds
                    # a usable value (an inputs_map can resolve them to None).
                    task_text = (
                        task_input.get("task")
                        or task_input.get("query")
                        or str(task_input)
                    )
                    self._log_node_start(node_id)

                    with trace.span(f"node.{node_id}", input=task_text) as span:
                        try:
                            result = await node.agent.execute(task_text, context=task_input)
                            self._log_node_end(node_id, success=result.success)
                            span.set_output(result.output)
                            return node_id, result
                        except Exception as exc:
                            self._log_node_end(node_id, success=False)
                            span.set_error(exc)
                            raise

                for level in levels:
                    # Run all nodes in this level concurrently
                    results = await asyncio.gather(
                        *[_run_node(nid) for nid in level], return_exceptions=False
                    )
                    for nid, result in results:
                        node_results[nid] = result
                        # Apply outputs_map renaming
                        node = self._nodes[nid]
                        if node.outputs_map:
                            for out_key, mapped_key in node.outputs_map.items():
                                all_outputs[mapped_key] = getattr(result, out_key, None)
                        else:
                            all_outputs[nid] = result.output

                # Final output = last level's last node output
                last_level = levels[-1] if levels else []
                final_output = (
                    node_results[last_level[-1]].output if last_level else ""
                )
                wf_result = WorkflowResult(
                    outputs=node_results,
                    final_output=final_output,
                    success=True,
                )
                self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
                trace.set_output(final_output)
                return wf_result

            except Exception as exc:
                self._logger.error("DAGWorkflow failed", error=str(exc))
                trace.set_error(exc)
                return WorkflowResult(success=False, error=str(exc))
Executes workflow nodes in topological order using networkx.
Nodes with no mutual dependency are executed concurrently with
asyncio.gather.
Args:
nodes: List of WorkflowNode objects.
edges: List of WorkflowEdge objects.
logger, metrics, tracer: Observability (optional).
Example::
wf = DAGWorkflow(
nodes=[search_node, summarise_node],
edges=[WorkflowEdge(source="search", target="summarise")],
)
result = await wf.run({"query": "AI trends 2026"})
def __init__(
    self,
    nodes: Optional[List[WorkflowNode]] = None,
    edges: Optional[List[WorkflowEdge]] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the graph definition and hand observability to the base class.

    Args:
        nodes: Workflow nodes; they are indexed by ``node_id``, so when two
            nodes share an id the later one replaces the earlier one.
        edges: Edges between nodes; ``None`` (or an empty list) means no edges.
        logger, metrics, tracer: Forwarded unchanged to the base initializer.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)

    # Index nodes by id for direct lookup during execution.
    node_index: Dict[str, WorkflowNode] = {}
    for node in nodes or []:
        node_index[node.node_id] = node
    self._nodes: Dict[str, WorkflowNode] = node_index

    self._edges: List[WorkflowEdge] = edges or []
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """Execute the DAG level by level and return the aggregated result.

    Nodes are grouped into topological levels. All nodes of one level are
    awaited together with ``asyncio.gather`` before the next level starts.
    Each node's agent receives a task string (the ``"task"`` entry of its
    resolved input, else ``"query"``, else the stringified input) and the
    resolved input dict as ``context``.

    Args:
        initial_input: Seed values made available to every node's input
            resolution.

    Returns:
        On success, a ``WorkflowResult`` whose ``outputs`` maps node id to
        the agent result and whose ``final_output`` is the output of the
        last node of the last level (``""`` when the graph has no nodes).
        On any exception, a ``WorkflowResult`` with ``success=False``, the
        error message, and the results of the levels that had already
        completed.
    """
    self._logger.info("DAGWorkflow started", node_count=len(self._nodes))

    if self._metrics:
        self._metrics.increment("workflow.runs", workflow="dag")

    # Created before the try block so the failure path can still report the
    # results gathered so far (mirrors StateMachineWorkflow.run).
    all_outputs: Dict[str, Any] = {}
    node_results: Dict[str, Any] = {}

    with self._tracer.trace(
        "dag_workflow.run",
        input=str(initial_input),
        metadata={"node_count": len(self._nodes)},
    ) as trace:

        # Defined once instead of once per level; it only closes over
        # names that stay bound for the whole run.
        async def _run_node(node_id: str) -> tuple[str, Any]:
            node = self._nodes[node_id]
            task_input = self._resolve_inputs(node, initial_input, all_outputs)
            task_text = task_input.get("task") or task_input.get("query", str(task_input))
            self._log_node_start(node_id)

            with trace.span(f"node.{node_id}", input=task_text) as span:
                try:
                    result = await node.agent.execute(task_text, context=task_input)
                    self._log_node_end(node_id, success=result.success)
                    span.set_output(result.output)
                    return node_id, result
                except Exception as exc:
                    self._log_node_end(node_id, success=False)
                    span.set_error(exc)
                    raise

        try:
            levels = self._topological_levels()

            for level in levels:
                # Run all nodes in this level concurrently.
                results = await asyncio.gather(
                    *[_run_node(nid) for nid in level], return_exceptions=False
                )
                for nid, result in results:
                    node_results[nid] = result
                    # Apply outputs_map renaming: result attribute -> shared key.
                    node = self._nodes[nid]
                    if node.outputs_map:
                        for out_key, mapped_key in node.outputs_map.items():
                            all_outputs[mapped_key] = getattr(result, out_key, None)
                    else:
                        all_outputs[nid] = result.output

            # Final output = last level's last node output.
            last_level = levels[-1] if levels else []
            final_output = node_results[last_level[-1]].output if last_level else ""
            wf_result = WorkflowResult(
                outputs=node_results,
                final_output=final_output,
                success=True,
            )
            self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
            trace.set_output(final_output)
            return wf_result

        except Exception as exc:
            self._logger.error("DAGWorkflow failed", error=str(exc))
            trace.set_error(exc)
            # Keep what finished before the failure so callers can inspect it.
            return WorkflowResult(outputs=node_results, success=False, error=str(exc))
Execute the workflow and return the aggregated result.
class StateMachineWorkflow(BaseWorkflow):
    """
    Drives execution through an explicit state/transition table.

    Each state maps to a :class:`BaseAgent`. On completion, transition
    rules are evaluated in order; the first matching rule determines the
    next state. ``None`` as a guard means "always transition here". When no
    rule matches, the workflow moves to an internal terminal sentinel and ends.

    Args:
        states: Mapping of state name → agent to run in that state.
        transitions: Mapping of state name → list of ``(guard_fn | None, next_state)``
            tuples evaluated in order after the state's agent finishes.
        initial_state: Name of the starting state.
        terminal_states: Set of state names that end the workflow.
        state_store: Optional store to persist current state across restarts.
        max_transitions: Safety cap on total state transitions (default: 50).
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = StateMachineWorkflow(
            states={"gather": search_agent, "summarise": summarise_agent},
            transitions={
                "gather": [(None, "summarise")],
                "summarise": [(None, "END")],
            },
            initial_state="gather",
            terminal_states={"END"},
        )
        result = await wf.run({"query": "AI news"})
    """

    # Sentinel state entered when no transition rule matches. It always ends
    # the run, whether reached in-process or restored from the state store.
    _TERMINAL = "__terminal__"

    def __init__(
        self,
        states: Optional[Dict[str, BaseAgent]] = None,
        transitions: Optional[Dict[str, List[TransitionRule]]] = None,
        initial_state: str = "",
        terminal_states: Optional[set] = None,
        state_store: Optional[BaseStateStore] = None,
        max_transitions: int = 50,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._states: Dict[str, BaseAgent] = states or {}
        self._transitions: Dict[str, List[TransitionRule]] = transitions or {}
        self.initial_state = initial_state
        self.terminal_states: set = terminal_states or set()
        self._state_store = state_store
        self.max_transitions = max_transitions

    def _next_state(self, current: str, result: AgentResult) -> Optional[str]:
        """Return the target of the first rule of ``current`` whose guard passes.

        A guard of ``None`` always passes. Returns ``None`` when ``current``
        has no rules or none of them matches.
        """
        for guard, next_state in self._transitions.get(current, []):
            if guard is None or guard(result):
                return next_state
        return None

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Run agents state by state until a terminal state is reached.

        The starting state is ``initial_state`` unless a state store is
        configured and holds a persisted state, which then takes precedence.
        The loop stops when the current state is terminal, when no agent is
        registered for it, or after ``max_transitions`` iterations.

        Args:
            initial_input: Initial context; copied, then extended with
                ``"last_output"`` after every state.

        Returns:
            A ``WorkflowResult`` with the per-state results and the output of
            the last executed state. If an agent (or a guard) raises, a
            result with ``success=False`` and the error message is returned.
        """
        self._logger.info(
            "StateMachineWorkflow started", initial_state=self.initial_state
        )

        current_state = self.initial_state
        # Resume from persisted state if available.
        if self._state_store:
            persisted = await self._state_store.get("__sm_current_state__")
            if persisted:
                current_state = persisted

        node_results: Dict[str, AgentResult] = {}
        context = dict(initial_input)
        final_output = ""

        with self._tracer.trace(
            "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
        ) as trace:
            for _ in range(self.max_transitions):
                # The sentinel is checked explicitly: it is never a key of
                # ``_states``, so falling through would log a bogus
                # "No agent for state" error on every normal completion.
                if (
                    current_state is None
                    or current_state == self._TERMINAL
                    or current_state in self.terminal_states
                ):
                    break

                agent = self._states.get(current_state)
                if agent is None:
                    self._logger.error("No agent for state", state=current_state)
                    break

                self._log_node_start(current_state)
                with trace.span(f"state.{current_state}", input=str(context)) as span:
                    try:
                        task = context.get("task") or context.get("query", str(context))
                        result = await agent.execute(task, context=context)
                        self._log_node_end(current_state, success=result.success)
                        node_results[current_state] = result
                        final_output = result.output
                        context["last_output"] = result.output
                        span.set_output(result.output)

                        # Evaluate the guards exactly once, so the persisted
                        # state and the in-memory state cannot diverge.
                        next_state = (
                            self._next_state(current_state, result) or self._TERMINAL
                        )

                        if self._state_store:
                            await self._state_store.set(
                                "__sm_current_state__", next_state
                            )

                        if self._metrics:
                            self._metrics.increment(
                                "workflow.state_transitions",
                                from_state=current_state,
                            )

                        current_state = next_state

                    except Exception as exc:
                        self._log_node_end(current_state, success=False)
                        span.set_error(exc)
                        self._logger.error(
                            "StateMachineWorkflow state error",
                            state=current_state,
                            error=str(exc),
                        )
                        trace.set_error(exc)
                        return WorkflowResult(
                            outputs=node_results, success=False, error=str(exc)
                        )

            trace.set_output(final_output)
            return WorkflowResult(outputs=node_results, final_output=final_output, success=True)
Drives execution through an explicit state/transition table.
Each state maps to a BaseAgent. On completion, transition
rules are evaluated in order; the first matching rule determines the
next state. None as a guard means "always transition here".
Args:
states: Mapping of state name → agent to run in that state.
transitions: Mapping of state name → list of (guard_fn | None, next_state)
tuples evaluated in order after the state's agent finishes.
initial_state: Name of the starting state.
terminal_states: Set of state names that end the workflow.
state_store: Optional store to persist current state across restarts.
max_transitions: Safety cap on total state transitions (default: 50).
logger, metrics, tracer: Observability (optional).
Example::
wf = StateMachineWorkflow(
states={"gather": search_agent, "summarise": summarise_agent},
transitions={
"gather": [(None, "summarise")],
"summarise": [(None, "END")],
},
initial_state="gather",
terminal_states={"END"},
)
result = await wf.run({"query": "AI news"})
def __init__(
    self,
    states: Optional[Dict[str, BaseAgent]] = None,
    transitions: Optional[Dict[str, List[TransitionRule]]] = None,
    initial_state: str = "",
    terminal_states: Optional[set] = None,
    state_store: Optional[BaseStateStore] = None,
    max_transitions: int = 50,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Record the state table and run limits.

    Args:
        states: State name → agent executed in that state.
        transitions: State name → ordered ``(guard, next_state)`` rules.
        initial_state: State the run starts in.
        terminal_states: State names that end the run.
        state_store: Optional store used to persist the current state.
        max_transitions: Upper bound on loop iterations per run.
        logger, metrics, tracer: Forwarded unchanged to the base initializer.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)

    # Public configuration.
    self.initial_state = initial_state
    self.max_transitions = max_transitions
    self.terminal_states: set = terminal_states if terminal_states else set()

    # Internal tables; falsy arguments are replaced by fresh empty containers.
    self._states: Dict[str, BaseAgent] = states if states else {}
    self._transitions: Dict[str, List[TransitionRule]] = (
        transitions if transitions else {}
    )
    self._state_store = state_store
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """Run agents state by state until a terminal state is reached.

    The starting state is ``initial_state`` unless a state store is
    configured and holds a persisted state, which then takes precedence.
    The loop stops when the current state is terminal, when no agent is
    registered for it, or after ``max_transitions`` iterations.

    Args:
        initial_input: Initial context; copied, then extended with
            ``"last_output"`` after every state.

    Returns:
        A ``WorkflowResult`` with the per-state results and the output of
        the last executed state. If an agent (or a guard) raises, a result
        with ``success=False`` and the error message is returned.
    """
    self._logger.info(
        "StateMachineWorkflow started", initial_state=self.initial_state
    )

    current_state = self.initial_state
    # Resume from persisted state if available.
    if self._state_store:
        persisted = await self._state_store.get("__sm_current_state__")
        if persisted:
            current_state = persisted

    node_results: Dict[str, AgentResult] = {}
    context = dict(initial_input)
    final_output = ""

    with self._tracer.trace(
        "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
    ) as trace:
        for _ in range(self.max_transitions):
            # The sentinel is checked explicitly: it is never a key of
            # ``_states``, so falling through would log a bogus
            # "No agent for state" error on every normal completion.
            if (
                current_state is None
                or current_state == self._TERMINAL
                or current_state in self.terminal_states
            ):
                break

            agent = self._states.get(current_state)
            if agent is None:
                self._logger.error("No agent for state", state=current_state)
                break

            self._log_node_start(current_state)
            with trace.span(f"state.{current_state}", input=str(context)) as span:
                try:
                    task = context.get("task") or context.get("query", str(context))
                    result = await agent.execute(task, context=context)
                    self._log_node_end(current_state, success=result.success)
                    node_results[current_state] = result
                    final_output = result.output
                    context["last_output"] = result.output
                    span.set_output(result.output)

                    # Evaluate the guards exactly once, so the persisted
                    # state and the in-memory state cannot diverge.
                    next_state = (
                        self._next_state(current_state, result) or self._TERMINAL
                    )

                    if self._state_store:
                        await self._state_store.set(
                            "__sm_current_state__", next_state
                        )

                    if self._metrics:
                        self._metrics.increment(
                            "workflow.state_transitions",
                            from_state=current_state,
                        )

                    current_state = next_state

                except Exception as exc:
                    self._log_node_end(current_state, success=False)
                    span.set_error(exc)
                    self._logger.error(
                        "StateMachineWorkflow state error",
                        state=current_state,
                        error=str(exc),
                    )
                    trace.set_error(exc)
                    return WorkflowResult(
                        outputs=node_results, success=False, error=str(exc)
                    )

        trace.set_output(final_output)
        return WorkflowResult(outputs=node_results, final_output=final_output, success=True)
Execute the workflow and return the aggregated result.
class EventDrivenWorkflow(BaseWorkflow):
    """
    Processes events through registered async handler functions.

    Handlers are registered per event name. When an event is emitted, all
    registered handlers for that name run concurrently. If a handler returns
    a new :class:`WorkflowEvent`, it is automatically enqueued for further
    processing (event chaining).

    The workflow terminates when:
    - The event queue is empty and no handlers are running, OR
    - ``max_events`` total events have been processed, OR
    - A handler returns an event named ``terminal_event``.

    Args:
        max_events: Safety cap on total events processed (default: 100).
        terminal_event: Event name that signals completion (optional).
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = EventDrivenWorkflow(terminal_event="done")
        wf.on("query_received", handle_query)
        wf.on("results_ready", handle_results)
        await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
        result = await wf.run({})
    """

    def __init__(
        self,
        max_events: int = 100,
        terminal_event: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self.max_events = max_events
        self.terminal_event = terminal_event
        self._handlers: Dict[str, List[EventHandler]] = {}
        self._queue: asyncio.Queue = asyncio.Queue()
        # Every event dequeued by run(), in processing order.
        self._processed: List[WorkflowEvent] = []

    def on(self, event_name: str, handler: EventHandler) -> None:
        """Register an async handler for a given event name."""
        self._handlers.setdefault(event_name, []).append(handler)

    async def emit(self, event: WorkflowEvent) -> None:
        """Enqueue an event for processing."""
        await self._queue.put(event)

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Drain the event queue, dispatching each event to its handlers.

        Args:
            initial_input: If it has an ``"event"`` entry that is a dict or a
                :class:`WorkflowEvent`, that event is enqueued before
                processing starts.

        Returns:
            A successful ``WorkflowResult`` whose metadata carries
            ``events_processed``. ``final_output`` is the ``"output"`` entry
            of the terminal event's payload when a handler returned the
            terminal event, otherwise ``""``.

        Raises:
            Exception: Whatever a handler raises propagates to the caller.
        """
        self._logger.info("EventDrivenWorkflow started")

        # Optionally seed with an initial event from input.
        if "event" in initial_input:
            event_data = initial_input["event"]
            if isinstance(event_data, dict):
                await self.emit(WorkflowEvent(**event_data))
            elif isinstance(event_data, WorkflowEvent):
                await self.emit(event_data)

        total_processed = 0
        final_output = ""

        with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
            while total_processed < self.max_events:
                try:
                    event: WorkflowEvent = self._queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

                # One task_done() per successful get(), on every exit path
                # (no handler, terminal return, handler exception), so the
                # queue's unfinished-task counter stays balanced.
                try:
                    self._logger.info(
                        "Processing event",
                        event_name=event.name,
                        source=event.source,
                    )
                    self._processed.append(event)
                    total_processed += 1

                    if self._metrics:
                        self._metrics.increment(
                            "workflow.events_processed", event=event.name
                        )

                    with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
                        handlers = self._handlers.get(event.name, [])
                        if not handlers:
                            self._logger.warning(
                                "No handler for event", event_name=event.name
                            )
                            span.set_output("no_handler")
                            continue

                        follow_ups = await asyncio.gather(
                            *[handler(event) for handler in handlers],
                            return_exceptions=False,
                        )

                        for follow_up in follow_ups:
                            if not isinstance(follow_up, WorkflowEvent):
                                continue
                            follow_up.source = event.name

                            if self.terminal_event and follow_up.name == self.terminal_event:
                                # The terminal event ends the run here; it is
                                # deliberately not enqueued, otherwise it
                                # would linger and be picked up by the next
                                # run() on this instance.
                                final_output = str(follow_up.payload.get("output", ""))
                                span.set_output(final_output)
                                trace.set_output(final_output)
                                return WorkflowResult(
                                    final_output=final_output,
                                    success=True,
                                    metadata={"events_processed": total_processed},
                                )

                            await self.emit(follow_up)

                        span.set_output(f"processed {len(handlers)} handler(s)")
                finally:
                    self._queue.task_done()

            self._logger.info(
                "EventDrivenWorkflow completed", events_processed=total_processed
            )
            trace.set_output(final_output)
            return WorkflowResult(
                final_output=final_output,
                success=True,
                metadata={"events_processed": total_processed},
            )
Processes events through registered async handler functions.
Handlers are registered per event name. When an event is emitted, all
registered handlers for that name run concurrently. If a handler returns
a new WorkflowEvent, it is automatically enqueued for further
processing (event chaining).
The workflow terminates when:
- The event queue is empty and no handlers are running, OR
- max_events total events have been processed, OR
- A terminal_event is emitted.
Args:
max_events: Safety cap on total events processed (default: 100).
terminal_event: Event name that signals completion (optional).
logger, metrics, tracer: Observability (optional).
Example::
wf = EventDrivenWorkflow(terminal_event="done")
wf.on("query_received", handle_query)
wf.on("results_ready", handle_results)
await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
result = await wf.run({})
def __init__(
    self,
    max_events: int = 100,
    terminal_event: Optional[str] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Set the processing limits and create empty handler/queue state.

    Args:
        max_events: Maximum number of events a single run will process.
        terminal_event: Name of the event that ends a run, if any.
        logger, metrics, tracer: Forwarded unchanged to the base initializer.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)

    # Limits supplied by the caller.
    self.terminal_event = terminal_event
    self.max_events = max_events

    # Mutable run state, all starting out empty.
    self._processed: List[WorkflowEvent] = []
    self._handlers: Dict[str, List[EventHandler]] = {}
    self._queue: asyncio.Queue = asyncio.Queue()
def on(self, event_name: str, handler: EventHandler) -> None:
    """Add ``handler`` to the handlers registered for ``event_name``.

    Handlers for the same name are kept in registration order.
    """
    if event_name not in self._handlers:
        self._handlers[event_name] = []
    self._handlers[event_name].append(handler)
Register an async handler for a given event name.
async def emit(self, event: WorkflowEvent) -> None:
    """Put ``event`` on the internal queue so a later ``run`` can process it."""
    pending = self._queue
    await pending.put(event)
Enqueue an event for processing.
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """Drain the event queue, dispatching each event to its handlers.

    Args:
        initial_input: If it has an ``"event"`` entry that is a dict or a
            :class:`WorkflowEvent`, that event is enqueued before
            processing starts.

    Returns:
        A successful ``WorkflowResult`` whose metadata carries
        ``events_processed``. ``final_output`` is the ``"output"`` entry of
        the terminal event's payload when a handler returned the terminal
        event, otherwise ``""``.

    Raises:
        Exception: Whatever a handler raises propagates to the caller.
    """
    self._logger.info("EventDrivenWorkflow started")

    # Optionally seed with an initial event from input.
    if "event" in initial_input:
        event_data = initial_input["event"]
        if isinstance(event_data, dict):
            await self.emit(WorkflowEvent(**event_data))
        elif isinstance(event_data, WorkflowEvent):
            await self.emit(event_data)

    total_processed = 0
    final_output = ""

    with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
        while total_processed < self.max_events:
            try:
                event: WorkflowEvent = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break

            # One task_done() per successful get(), on every exit path
            # (no handler, terminal return, handler exception), so the
            # queue's unfinished-task counter stays balanced.
            try:
                self._logger.info(
                    "Processing event",
                    event_name=event.name,
                    source=event.source,
                )
                self._processed.append(event)
                total_processed += 1

                if self._metrics:
                    self._metrics.increment(
                        "workflow.events_processed", event=event.name
                    )

                with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
                    handlers = self._handlers.get(event.name, [])
                    if not handlers:
                        self._logger.warning(
                            "No handler for event", event_name=event.name
                        )
                        span.set_output("no_handler")
                        continue

                    follow_ups = await asyncio.gather(
                        *[handler(event) for handler in handlers],
                        return_exceptions=False,
                    )

                    for follow_up in follow_ups:
                        if not isinstance(follow_up, WorkflowEvent):
                            continue
                        follow_up.source = event.name

                        if self.terminal_event and follow_up.name == self.terminal_event:
                            # The terminal event ends the run here; it is
                            # deliberately not enqueued, otherwise it would
                            # linger and be picked up by the next run() on
                            # this instance.
                            final_output = str(follow_up.payload.get("output", ""))
                            span.set_output(final_output)
                            trace.set_output(final_output)
                            return WorkflowResult(
                                final_output=final_output,
                                success=True,
                                metadata={"events_processed": total_processed},
                            )

                        await self.emit(follow_up)

                    span.set_output(f"processed {len(handlers)} handler(s)")
            finally:
                self._queue.task_done()

        self._logger.info(
            "EventDrivenWorkflow completed", events_processed=total_processed
        )
        trace.set_output(final_output)
        return WorkflowResult(
            final_output=final_output,
            success=True,
            metadata={"events_processed": total_processed},
        )
Execute the workflow and return the aggregated result.
@dataclass
class WorkflowEvent:
    """A named message placed on the workflow event bus.

    Attributes:
        name: Event name; handlers are looked up by this value.
        payload: Arbitrary data carried by the event. Each instance gets its
            own empty dict when none is supplied.
        source: Origin label, ``"external"`` unless set otherwise.
    """

    # Required: identifies which handlers receive the event.
    name: str
    # Optional data; default_factory avoids sharing one dict across instances.
    payload: Dict[str, Any] = field(default_factory=dict)
    # Origin label of the event.
    source: str = "external"
An event emitted into the workflow event bus.
class BaseOrchestrator(ABC):
    """
    Common base for multi-agent orchestrators.

    Holds the shared observability handles and an optional router, and
    offers small logging helpers to subclasses.

    Args:
        router: Optional :class:`BaseRouter`, exposed as ``self.router``.
        logger: Optional :class:`BasicLogger`. When omitted, a logger named
            after the concrete class is created.
        metrics: Optional :class:`BasicMetricsCollector`.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        router: Optional["BaseRouter"] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        if not logger:
            logger = BasicLogger(
                f"gmf_forge_ai.orchestrator.{self.__class__.__name__}"
            )
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer if tracer else get_tracer()
        self.router: Optional["BaseRouter"] = router

    @abstractmethod
    async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
        """Run the orchestration for ``task`` and return its aggregated result."""

    def _log_start(self, task: str) -> None:
        """Log the start of a run and count it when metrics are configured."""
        orchestrator_name = self.__class__.__name__
        self._logger.info("Orchestrator started", orchestrator=orchestrator_name, task=task)
        if not self._metrics:
            return
        self._metrics.increment("orchestrator.runs", orchestrator=orchestrator_name)

    def _log_agent_dispatch(self, agent_id: str, task: str) -> None:
        """Log a hand-off to an agent; the task is cut to its first 100 characters."""
        task_preview = task[:100]
        self._logger.info("Dispatching to agent", agent_id=agent_id, task=task_preview)

    def _log_finished(self, rounds: int, success: bool) -> None:
        """Log the end of a run and report ``rounds`` when metrics are configured."""
        self._logger.info(
            "Orchestrator finished",
            orchestrator=self.__class__.__name__,
            rounds=rounds,
            success=success,
        )
        if not self._metrics:
            return
        self._metrics.increment("orchestrator.rounds", count=rounds)
Abstract base class for multi-agent orchestrators.
Args:
logger: Optional BasicLogger.
metrics: Optional BasicMetricsCollector.
tracer: Optional TracingProvider. Falls back to get_tracer().
@abstractmethod
async def run(
    self,
    task: str,
    context: Optional[Dict[str, Any]] = None,
) -> OrchestratorResult:
    """Run the orchestration for ``task`` and return its aggregated result.

    Abstract: every concrete orchestrator must override this coroutine.
    """
Execute the multi-agent orchestration and return the aggregated result.
@dataclass
class OrchestratorResult:
    """Outcome of one multi-agent orchestration run.

    Only ``final_output`` is required; every other field has a default.
    """

    # Final answer produced by the orchestrator.
    final_output: str
    # Agent results keyed by a string (one entry per key).
    agent_outputs: Dict[str, AgentResult] = field(default_factory=dict)
    # (str, str, AgentResult) triples; SupervisorOrchestrator stores
    # (agent name, subtask, result) here.
    subtask_outputs: List[Tuple[str, str, AgentResult]] = field(default_factory=list)
    # Number of orchestration rounds performed.
    rounds: int = 0
    # False when the run failed.
    success: bool = True
    # Error message for a failed run, otherwise None.
    error: Optional[str] = None
    # Free-form extra information.
    metadata: Dict[str, Any] = field(default_factory=dict)
The aggregated result of a multi-agent orchestration run.
class SupervisorOrchestrator(BaseOrchestrator):
    """
    LLM-based supervisor that plans subtask assignment and synthesizes results.

    Phase 1 — Plan: The task is decomposed into subtasks and each subtask is
              assigned to a worker agent (by the LLM, or by the router when
              one is injected).
    Phase 2 — Execute: The assigned agents run one after another, in plan order.
    Phase 3 — Synthesize: The supervisor LLM merges all results into a final answer.

    Args:
        supervisor_gateway: The LLM gateway used by the supervisor.
        agents: Mapping of agent name → :class:`BaseAgent`.
        agent_descriptions: Optional human-readable descriptions for the supervisor prompt.
        supervisor_model: LLM model for supervisor calls (optional).
        router: Optional :class:`BaseRouter` used to pick the agent for each subtask.
        logger, metrics, tracer: Observability (optional).
    """

    def __init__(
        self,
        supervisor_gateway: Any,
        agents: Optional[Dict[str, BaseAgent]] = None,
        agent_descriptions: Optional[Dict[str, str]] = None,
        supervisor_model: Optional[str] = None,
        router: Optional[BaseRouter] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
        self.supervisor_gateway = supervisor_gateway
        self.agents: Dict[str, BaseAgent] = agents or {}
        self.agent_descriptions: Dict[str, str] = agent_descriptions or {}
        self.supervisor_model = supervisor_model

    @staticmethod
    def _extract_json_list(raw: str) -> Optional[list]:
        """Return the JSON array embedded in ``raw``, or ``None`` if there is none.

        The span from the first ``[`` to the last ``]`` is parsed, which
        tolerates prose around the array in the LLM reply.
        """
        match = re.search(r"\[.*\]", raw, re.DOTALL)
        if not match:
            return None
        try:
            parsed = json.loads(match.group())
        except (json.JSONDecodeError, ValueError):
            return None
        return parsed if isinstance(parsed, list) else None

    async def _decompose(self, task: str) -> List[str]:
        """Use LLM to decompose a task into subtask strings (no agent assignment).

        Falls back to ``[task]`` when the reply holds no non-empty JSON array
        made up only of strings.
        """
        prompt = _SUPERVISOR_DECOMPOSE_PROMPT.format(task=task)
        response = await self.supervisor_gateway.complete(
            prompt, model=self.supervisor_model, temperature=0.0
        )
        subtasks = self._extract_json_list(response.content.strip())
        # An empty array is treated like an unusable reply: with no subtasks
        # nothing would run, so use the whole-task fallback instead.
        if subtasks and all(isinstance(s, str) for s in subtasks):
            return subtasks
        return [task]  # fallback: treat whole task as one subtask

    async def _plan(self, task: str) -> List[Dict[str, str]]:
        """Decompose task and assign agents.

        When a ``router`` is injected: the LLM only decomposes the task into
        subtasks; the router selects the best agent for each subtask.

        When no router is provided: the LLM both decomposes and assigns agents
        (original behaviour).

        Returns:
            A list of ``{"agent": ..., "subtask": ...}`` dicts. If the LLM
            reply yields no usable entry, the whole task is assigned to the
            first registered agent.
        """
        if self.router:
            subtasks = await self._decompose(task)
            available = list(self.agents.keys())
            plan: List[Dict[str, str]] = []
            for subtask in subtasks:
                decision = await self.router.route(
                    RoutingRequest(input=subtask, available_agents=available)
                )
                # Routers may answer with a name we do not know; fall back to
                # the first registered agent in that case.
                agent_name = decision.target if decision.target in self.agents else available[0]
                plan.append({"agent": agent_name, "subtask": subtask})
                self._logger.info(
                    "Router assigned subtask",
                    agent=agent_name,
                    confidence=decision.confidence,
                    subtask=subtask[:80],
                )
            return plan

        # --- Original LLM-based plan (no router) ---
        desc_block = "\n".join(
            f"- {name}: {self.agent_descriptions.get(name, 'General agent.')}"
            for name in self.agents
        )
        prompt = _SUPERVISOR_PLAN_PROMPT.format(
            agent_descriptions=desc_block, task=task
        )
        response = await self.supervisor_gateway.complete(
            prompt, model=self.supervisor_model, temperature=0.0
        )
        llm_plan = self._extract_json_list(response.content.strip())
        if llm_plan:
            # Only dict entries are usable: a bare string would pass a plain
            # ``"agent" in p`` substring test and break later on
            # ``assignment["agent"]``; a number would raise TypeError here.
            valid = [
                p for p in llm_plan
                if isinstance(p, dict) and "agent" in p and "subtask" in p
            ]
            if valid:
                return valid
        # Fallback: assign whole task to first agent
        first_agent = next(iter(self.agents), "")
        return [{"agent": first_agent, "subtask": task}]

    async def _synthesize(self, task: str, subtask_results: List[Tuple[str, str, AgentResult]]) -> str:
        """Ask the supervisor LLM to merge the per-agent outputs into one answer."""
        results_block = "\n\n".join(
            f"[{agent}]: {result.output}" for agent, _subtask, result in subtask_results
        )
        prompt = _SUPERVISOR_SYNTHESIS_PROMPT.format(
            task=task, agent_results=results_block
        )
        response = await self.supervisor_gateway.complete(
            prompt, model=self.supervisor_model, temperature=0.1
        )
        return response.content.strip()

    async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
        """Plan, execute and synthesize; return the aggregated result.

        Args:
            task: The overall task to solve.
            context: Optional dict passed unchanged to every agent as ``context``.

        Returns:
            An ``OrchestratorResult``. ``agent_outputs`` keeps the latest
            result per agent name, ``subtask_outputs`` keeps every
            ``(agent, subtask, result)`` triple in execution order. Any
            exception is caught and reported as ``success=False``.
        """
        self._log_start(task)

        with self._tracer.trace(
            "supervisor.run", input=task, metadata={"agent_count": len(self.agents)}
        ) as trace:
            try:
                # Phase 1: Plan
                with trace.span("supervisor.plan", input=task) as plan_span:
                    plan = await self._plan(task)
                    plan_span.set_output({"plan": plan})

                # Phase 2: Execute subtasks
                agent_outputs: Dict[str, AgentResult] = {}
                subtask_list: List[Tuple[str, str, AgentResult]] = []
                for assignment in plan:
                    agent_name = assignment["agent"]
                    subtask = assignment["subtask"]
                    agent = self.agents.get(agent_name)
                    if agent is None:
                        self._logger.warning("Unknown agent in plan", agent=agent_name)
                        continue

                    self._log_agent_dispatch(agent_name, subtask)
                    with trace.span(f"agent.{agent_name}", input=subtask) as aspan:
                        result = await agent.execute(subtask, context=context)
                        agent_outputs[agent_name] = result
                        subtask_list.append((agent_name, subtask, result))
                        aspan.set_output(result.output)

                # Phase 3: Synthesize
                with trace.span("supervisor.synthesize") as syn_span:
                    final_output = await self._synthesize(task, subtask_list)
                    syn_span.set_output(final_output)

                self._log_finished(rounds=1, success=True)
                trace.set_output(final_output)
                return OrchestratorResult(
                    final_output=final_output,
                    agent_outputs=agent_outputs,
                    subtask_outputs=subtask_list,
                    rounds=1,
                    success=True,
                )

            except Exception as exc:
                self._logger.error("SupervisorOrchestrator failed", error=str(exc))
                trace.set_error(exc)
                return OrchestratorResult(
                    final_output="",
                    rounds=1,
                    success=False,
                    error=str(exc),
                )
LLM-based supervisor that plans subtask assignment and synthesizes results.
Phase 1 — Plan: The LLM supervisor decomposes the task and assigns each subtask to the most appropriate worker agent.
Phase 2 — Execute: The assigned worker agents run their subtasks one after another, in plan order.
Phase 3 — Synthesize: The supervisor LLM merges all results into a final answer.
Args:
supervisor_gateway: The LLM gateway used by the supervisor.
agents: Mapping of agent name → BaseAgent.
agent_descriptions: Optional human-readable descriptions for the supervisor prompt.
supervisor_model: LLM model for supervisor calls (optional).
logger, metrics, tracer: Observability (optional).
def __init__(
    self,
    supervisor_gateway: Any,
    agents: Optional[Dict[str, BaseAgent]] = None,
    agent_descriptions: Optional[Dict[str, str]] = None,
    supervisor_model: Optional[str] = None,
    router: Optional[BaseRouter] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Keep the supervisor gateway, the worker agents and their descriptions.

    Args:
        supervisor_gateway: LLM gateway used for supervisor calls.
        agents: Agent name → agent; a falsy value becomes an empty dict.
        agent_descriptions: Agent name → description text; a falsy value
            becomes an empty dict.
        supervisor_model: Model passed to the gateway, if any.
        router, logger, metrics, tracer: Forwarded unchanged to the base
            initializer.
    """
    super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)

    # Supervisor LLM access.
    self.supervisor_model = supervisor_model
    self.supervisor_gateway = supervisor_gateway

    # Worker pool and the descriptions shown to the supervisor.
    self.agent_descriptions: Dict[str, str] = (
        agent_descriptions if agent_descriptions else {}
    )
    self.agents: Dict[str, BaseAgent] = agents if agents else {}
async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
    """Plan the task, run the assigned agents in plan order, then synthesize.

    Plan entries naming an agent that is not registered are skipped with a
    warning. Any exception raised along the way is caught and reported as a
    result with ``success=False`` and an empty ``final_output``.

    Args:
        task: The overall task to solve.
        context: Optional dict passed unchanged to every agent as ``context``.

    Returns:
        An ``OrchestratorResult`` holding the synthesized answer, the latest
        result per agent name, and every ``(agent, subtask, result)`` triple.
    """
    self._log_start(task)

    trace_metadata = {"agent_count": len(self.agents)}
    with self._tracer.trace("supervisor.run", input=task, metadata=trace_metadata) as trace:
        try:
            # Phase 1: ask for the list of agent/subtask assignments.
            with trace.span("supervisor.plan", input=task) as plan_span:
                assignments = await self._plan(task)
                plan_span.set_output({"plan": assignments})

            # Phase 2: execute each assignment in order.
            latest_by_agent: Dict[str, AgentResult] = {}
            completed: List[Tuple[str, str, AgentResult]] = []
            for entry in assignments:
                worker_name = entry["agent"]
                piece = entry["subtask"]
                worker = self.agents.get(worker_name)
                if worker is not None:
                    self._log_agent_dispatch(worker_name, piece)
                    with trace.span(f"agent.{worker_name}", input=piece) as worker_span:
                        outcome = await worker.execute(piece, context=context)
                        latest_by_agent[worker_name] = outcome
                        completed.append((worker_name, piece, outcome))
                        worker_span.set_output(outcome.output)
                else:
                    self._logger.warning("Unknown agent in plan", agent=worker_name)

            # Phase 3: merge everything into the final answer.
            with trace.span("supervisor.synthesize") as synth_span:
                answer = await self._synthesize(task, completed)
                synth_span.set_output(answer)

            self._log_finished(rounds=1, success=True)
            trace.set_output(answer)
            return OrchestratorResult(
                final_output=answer,
                agent_outputs=latest_by_agent,
                subtask_outputs=completed,
                rounds=1,
                success=True,
            )

        except Exception as exc:
            message = str(exc)
            self._logger.error("SupervisorOrchestrator failed", error=message)
            trace.set_error(exc)
            return OrchestratorResult(
                final_output="",
                rounds=1,
                success=False,
                error=str(exc),
            )
Execute the multi-agent orchestration and return the aggregated result.
class PipelineOrchestrator(BaseOrchestrator):
    """
    Runs a sequence of agents where the output of each becomes the input to the next.

    The first agent in the pipeline receives the original ``task``. Each
    subsequent agent receives the previous agent's ``output`` as its task.

    Args:
        agents: Ordered list of :class:`BaseAgent` instances.
        pass_context: If True (default), each agent receives the accumulated
            pipeline context: ``original_task``, the caller's context and every
            prior ``step_N_output``. If False, only the caller's context is
            forwarded.
        router: Optional :class:`BaseRouter`. When set, it is consulted at every
            step and may pick any agent from ``agents`` for the current task;
            the positional agent is used when the routed target is not found.
        logger, metrics, tracer: Observability (optional).

    Example::

        pipeline = PipelineOrchestrator(
            agents=[search_agent, summarise_agent, translate_agent]
        )
        result = await pipeline.run("Summarise the latest AI news in Spanish")
    """

    def __init__(
        self,
        agents: Optional[List[BaseAgent]] = None,
        pass_context: bool = True,
        router: Optional[BaseRouter] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        """Store the ordered agents and the context-forwarding mode."""
        super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
        self.agents: List[BaseAgent] = agents or []
        self.pass_context = pass_context

    async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
        """Run one step per configured agent, chaining each output into the next.

        Args:
            task: Input for the first step.
            context: Optional caller context merged into the pipeline context.

        Returns:
            An ``OrchestratorResult`` whose ``final_output`` is the last step's
            output. If a step raises, ``success`` is False and the outputs
            collected so far are returned with the exception text in ``error``.
        """
        self._log_start(task)

        agent_outputs: Dict[str, AgentResult] = {}
        current_task = task
        # Merge caller-supplied context (user_assertion, etc.) so it flows to agents;
        # step outputs are accumulated on top under step_N_output keys.
        pipeline_context: Dict[str, Any] = {"original_task": task, **(context or {})}

        with self._tracer.trace(
            "pipeline.run",
            input=task,
            metadata={"pipeline_length": len(self.agents)},
        ) as trace:
            try:
                # The candidate pool is the same at every step: build it once.
                available = [a.agent_id for a in self.agents] if self.router else []

                for i, agent in enumerate(self.agents):
                    # If a router is provided, let it pick the best agent from
                    # the pool for the current task rather than using fixed order.
                    if self.router:
                        decision = await self.router.route(
                            RoutingRequest(input=current_task, available_agents=available)
                        )
                        agent = next(
                            (a for a in self.agents if a.agent_id == decision.target),
                            self.agents[i],  # fallback to positional agent
                        )
                        self._logger.info(
                            "Router selected pipeline agent",
                            step=i,
                            agent=agent.agent_id,
                            confidence=decision.confidence,
                        )
                    agent_id = agent.agent_id or f"agent_{i}"
                    self._log_agent_dispatch(agent_id, current_task)

                    with trace.span(f"pipeline.step_{i}.{agent_id}", input=current_task) as span:
                        # pass_context=True: full accumulated pipeline context
                        # pass_context=False: only the caller's original context
                        exec_context = pipeline_context if self.pass_context else (context or None)
                        result = await agent.execute(current_task, context=exec_context)
                        agent_outputs[f"step_{i}_{agent_id}"] = result
                        pipeline_context[f"step_{i}_output"] = result.output
                        span.set_output(result.output)

                    if not result.success:
                        # NOTE(review): a failed step is only logged; the pipeline
                        # continues and still reports success=True. Confirm intended.
                        self._logger.warning(
                            "Pipeline step failed",
                            step=i,
                            agent=agent_id,
                            error=result.error,
                        )

                    # Output of this step becomes the task for the next
                    current_task = result.output

                self._log_finished(rounds=len(self.agents), success=True)
                trace.set_output(current_task)
                return OrchestratorResult(
                    final_output=current_task,
                    agent_outputs=agent_outputs,
                    rounds=len(self.agents),
                    success=True,
                )

            except Exception as exc:
                # Top-level boundary: return the partial pipeline state.
                self._logger.error("PipelineOrchestrator failed", error=str(exc))
                trace.set_error(exc)
                return OrchestratorResult(
                    final_output=current_task,
                    agent_outputs=agent_outputs,
                    rounds=len(agent_outputs),
                    success=False,
                    error=str(exc),
                )
Runs a sequence of agents where the output of each becomes the input to the next.
The first agent in the pipeline receives the original task. Each
subsequent agent receives the previous agent's output as its task.
Args:
agents: Ordered list of BaseAgent instances.
pass_context: If True (default), each agent receives the accumulated pipeline context
(original task, caller context and all prior step outputs); if False, only the caller's context is forwarded.
logger, metrics, tracer: Observability (optional).
Example::
pipeline = PipelineOrchestrator(
agents=[search_agent, summarise_agent, translate_agent]
)
result = await pipeline.run("Summarise the latest AI news in Spanish")
def __init__(
    self,
    agents: Optional[List[BaseAgent]] = None,
    pass_context: bool = True,
    router: Optional[BaseRouter] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the ordered agents and the context-forwarding mode.

    ``router``, ``logger``, ``metrics`` and ``tracer`` go to the base-class
    initialiser unchanged. A missing or empty ``agents`` list is replaced by
    a fresh empty list.
    """
    super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
    self.pass_context = pass_context
    self.agents: List[BaseAgent] = agents if agents else []
async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
    """Run one step per configured agent, chaining each output into the next.

    Args:
        task: Input for the first step.
        context: Optional caller context merged into the pipeline context.

    Returns:
        An ``OrchestratorResult`` whose ``final_output`` is the last step's
        output. If a step raises, ``success`` is False and the outputs
        collected so far are returned with the exception text in ``error``.
    """
    self._log_start(task)

    agent_outputs: Dict[str, AgentResult] = {}
    current_task = task
    # Merge caller-supplied context (user_assertion, etc.) so it flows to agents;
    # step outputs are accumulated on top under step_N_output keys.
    pipeline_context: Dict[str, Any] = {"original_task": task, **(context or {})}

    with self._tracer.trace(
        "pipeline.run",
        input=task,
        metadata={"pipeline_length": len(self.agents)},
    ) as trace:
        try:
            # The candidate pool is the same at every step: build it once.
            available = [a.agent_id for a in self.agents] if self.router else []

            for i, agent in enumerate(self.agents):
                # If a router is provided, let it pick the best agent from
                # the pool for the current task rather than using fixed order.
                if self.router:
                    decision = await self.router.route(
                        RoutingRequest(input=current_task, available_agents=available)
                    )
                    agent = next(
                        (a for a in self.agents if a.agent_id == decision.target),
                        self.agents[i],  # fallback to positional agent
                    )
                    self._logger.info(
                        "Router selected pipeline agent",
                        step=i,
                        agent=agent.agent_id,
                        confidence=decision.confidence,
                    )
                agent_id = agent.agent_id or f"agent_{i}"
                self._log_agent_dispatch(agent_id, current_task)

                with trace.span(f"pipeline.step_{i}.{agent_id}", input=current_task) as span:
                    # pass_context=True: full accumulated pipeline context
                    # pass_context=False: only the caller's original context
                    exec_context = pipeline_context if self.pass_context else (context or None)
                    result = await agent.execute(current_task, context=exec_context)
                    agent_outputs[f"step_{i}_{agent_id}"] = result
                    pipeline_context[f"step_{i}_output"] = result.output
                    span.set_output(result.output)

                if not result.success:
                    # NOTE(review): a failed step is only logged; the pipeline
                    # continues and still reports success=True. Confirm intended.
                    self._logger.warning(
                        "Pipeline step failed",
                        step=i,
                        agent=agent_id,
                        error=result.error,
                    )

                # Output of this step becomes the task for the next
                current_task = result.output

            self._log_finished(rounds=len(self.agents), success=True)
            trace.set_output(current_task)
            return OrchestratorResult(
                final_output=current_task,
                agent_outputs=agent_outputs,
                rounds=len(self.agents),
                success=True,
            )

        except Exception as exc:
            # Top-level boundary: return the partial pipeline state.
            self._logger.error("PipelineOrchestrator failed", error=str(exc))
            trace.set_error(exc)
            return OrchestratorResult(
                final_output=current_task,
                agent_outputs=agent_outputs,
                rounds=len(agent_outputs),
                success=False,
                error=str(exc),
            )
Execute the multi-agent orchestration and return the aggregated result.
class DebateOrchestrator(BaseOrchestrator):
    """
    Runs structured multi-agent debate to improve answer quality.

    Each agent generates an initial position. Then for ``debate_rounds``
    rounds, each agent critiques other agents' positions and refines its own.
    A final synthesis LLM call, made through ``synthesis_gateway`` or else the
    first agent's gateway, produces the final answer.

    Args:
        agents: List of debating :class:`BaseAgent` instances.
        debate_rounds: Number of critique/refine cycles (default: 1).
        synthesis_gateway: Optional separate LLM gateway for the final synthesis.
            If None, uses the first agent's gateway.
        synthesis_model: Model for synthesis (optional).
        router: Optional :class:`BaseRouter`, passed to the base class.
        logger, metrics, tracer: Observability (optional).

    Example::

        debate = DebateOrchestrator(
            agents=[agent_a, agent_b, agent_c],
            debate_rounds=2,
        )
        result = await debate.run("What is the best AI architecture for RAG?")
    """

    def __init__(
        self,
        agents: Optional[List[BaseAgent]] = None,
        debate_rounds: int = 1,
        synthesis_gateway: Optional[Any] = None,
        synthesis_model: Optional[str] = None,
        router: Optional[BaseRouter] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        """Store the debating agents, round count and synthesis settings."""
        super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
        self.agents: List[BaseAgent] = agents or []
        self.debate_rounds = debate_rounds
        self._synthesis_gateway = synthesis_gateway
        self.synthesis_model = synthesis_model

    def _synthesis_gw(self) -> Any:
        """Return the gateway used for the final synthesis call.

        Raises:
            ValueError: If neither a synthesis gateway nor any agent is set.
        """
        # Identity check: a configured gateway must be used even if the
        # object happens to evaluate as falsy.
        if self._synthesis_gateway is not None:
            return self._synthesis_gateway
        if self.agents:
            return self.agents[0].llm_gateway
        raise ValueError("DebateOrchestrator: no agents or synthesis_gateway provided.")

    async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
        """Collect initial positions, run the critique rounds, then synthesise.

        Args:
            task: The question or task under debate.
            context: Optional dict passed unchanged to every ``agent.execute`` call.

        Returns:
            An ``OrchestratorResult`` with the synthesised ``final_output``. If
            any phase raises, ``success`` is False, ``error`` is the exception
            text, and the agent results collected so far are still included.
        """
        self._log_start(task)

        # Declared before the try block so the except branch can return the
        # results gathered before a failure.
        agent_outputs: Dict[str, AgentResult] = {}

        with self._tracer.trace(
            "debate.run",
            input=task,
            metadata={"agents": len(self.agents), "rounds": self.debate_rounds},
        ) as trace:
            try:
                # Phase 1: Initial positions
                positions: Dict[str, str] = {}

                with trace.span("debate.initial_positions") as init_span:
                    for agent in self.agents:
                        prompt = _DEBATE_INITIAL_PROMPT.format(task=task)
                        result = await agent.execute(prompt, context=context)
                        positions[agent.agent_id] = result.output
                        agent_outputs[agent.agent_id] = result

                    init_span.set_output({"agents": list(positions.keys())})

                # Phase 2: Critique rounds
                for round_num in range(self.debate_rounds):
                    self._logger.info("Debate round", round=round_num + 1)
                    if self._metrics:
                        self._metrics.increment("orchestrator.debate_rounds")

                    with trace.span(f"debate.round_{round_num + 1}") as round_span:
                        # Everyone critiques the previous round's positions;
                        # the new ones replace them only after the round ends.
                        new_positions: Dict[str, str] = {}
                        for agent in self.agents:
                            others = "\n\n".join(
                                f"[{aid}]: {pos}"
                                for aid, pos in positions.items()
                                if aid != agent.agent_id
                            )
                            critique_prompt = _DEBATE_CRITIQUE_PROMPT.format(
                                task=task,
                                positions=others,
                                own_position=positions.get(agent.agent_id, ""),
                            )
                            result = await agent.execute(critique_prompt, context=context)
                            new_positions[agent.agent_id] = result.output
                            agent_outputs[agent.agent_id] = result

                        positions = new_positions
                        round_span.set_output({"positions_updated": len(positions)})

                # Phase 3: Synthesis
                with trace.span("debate.synthesis") as syn_span:
                    all_positions = "\n\n".join(
                        f"[{aid}] (round {self.debate_rounds}):\n{pos}"
                        for aid, pos in positions.items()
                    )
                    synthesis_prompt = _DEBATE_SYNTHESIS_PROMPT.format(
                        task=task,
                        rounds=self.debate_rounds,
                        all_positions=all_positions,
                    )
                    synthesis_response = await self._synthesis_gw().complete(
                        synthesis_prompt, model=self.synthesis_model, temperature=0.1
                    )
                    final_output = synthesis_response.content.strip()
                    syn_span.set_output(final_output)

                self._log_finished(rounds=self.debate_rounds, success=True)
                trace.set_output(final_output)
                return OrchestratorResult(
                    final_output=final_output,
                    agent_outputs=agent_outputs,
                    rounds=self.debate_rounds,
                    success=True,
                )

            except Exception as exc:
                # Top-level boundary: report the failure as a result, not a raise.
                self._logger.error("DebateOrchestrator failed", error=str(exc))
                trace.set_error(exc)
                return OrchestratorResult(
                    final_output="",
                    agent_outputs=agent_outputs,
                    rounds=self.debate_rounds,
                    success=False,
                    error=str(exc),
                )
Runs structured multi-agent debate to improve answer quality.
Each agent generates an initial position. Then for debate_rounds
rounds, each agent critiques other agents' positions and refines its own.
A final synthesis LLM call, made through synthesis_gateway or else the first agent's gateway, produces the final answer.
Args:
agents: List of debating BaseAgent instances.
debate_rounds: Number of critique/refine cycles (default: 1).
synthesis_gateway: Optional separate LLM gateway for the final synthesis.
If None, uses the first agent's gateway.
synthesis_model: Model for synthesis (optional).
logger, metrics, tracer: Observability (optional).
Example::
debate = DebateOrchestrator(
agents=[agent_a, agent_b, agent_c],
debate_rounds=2,
)
result = await debate.run("What is the best AI architecture for RAG?")
def __init__(
    self,
    agents: Optional[List[BaseAgent]] = None,
    debate_rounds: int = 1,
    synthesis_gateway: Optional[Any] = None,
    synthesis_model: Optional[str] = None,
    router: Optional[BaseRouter] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the debating agents, round count and synthesis settings.

    ``router``, ``logger``, ``metrics`` and ``tracer`` go to the base-class
    initialiser unchanged. A missing or empty ``agents`` list is replaced by
    a fresh empty list.
    """
    super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
    self._synthesis_gateway = synthesis_gateway
    self.synthesis_model = synthesis_model
    self.debate_rounds = debate_rounds
    self.agents: List[BaseAgent] = agents if agents else []
async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
    """Collect initial positions, run the critique rounds, then synthesise.

    Args:
        task: The question or task under debate.
        context: Optional dict passed unchanged to every ``agent.execute`` call.

    Returns:
        An ``OrchestratorResult`` with the synthesised ``final_output``. If
        any phase raises, ``success`` is False, ``error`` is the exception
        text, and the agent results collected so far are still included.
    """
    self._log_start(task)

    # Declared before the try block so the except branch can return the
    # results gathered before a failure.
    agent_outputs: Dict[str, AgentResult] = {}

    with self._tracer.trace(
        "debate.run",
        input=task,
        metadata={"agents": len(self.agents), "rounds": self.debate_rounds},
    ) as trace:
        try:
            # Phase 1: Initial positions
            positions: Dict[str, str] = {}

            with trace.span("debate.initial_positions") as init_span:
                for agent in self.agents:
                    prompt = _DEBATE_INITIAL_PROMPT.format(task=task)
                    result = await agent.execute(prompt, context=context)
                    positions[agent.agent_id] = result.output
                    agent_outputs[agent.agent_id] = result

                init_span.set_output({"agents": list(positions.keys())})

            # Phase 2: Critique rounds
            for round_num in range(self.debate_rounds):
                self._logger.info("Debate round", round=round_num + 1)
                if self._metrics:
                    self._metrics.increment("orchestrator.debate_rounds")

                with trace.span(f"debate.round_{round_num + 1}") as round_span:
                    # Everyone critiques the previous round's positions;
                    # the new ones replace them only after the round ends.
                    new_positions: Dict[str, str] = {}
                    for agent in self.agents:
                        others = "\n\n".join(
                            f"[{aid}]: {pos}"
                            for aid, pos in positions.items()
                            if aid != agent.agent_id
                        )
                        critique_prompt = _DEBATE_CRITIQUE_PROMPT.format(
                            task=task,
                            positions=others,
                            own_position=positions.get(agent.agent_id, ""),
                        )
                        result = await agent.execute(critique_prompt, context=context)
                        new_positions[agent.agent_id] = result.output
                        agent_outputs[agent.agent_id] = result

                    positions = new_positions
                    round_span.set_output({"positions_updated": len(positions)})

            # Phase 3: Synthesis
            with trace.span("debate.synthesis") as syn_span:
                all_positions = "\n\n".join(
                    f"[{aid}] (round {self.debate_rounds}):\n{pos}"
                    for aid, pos in positions.items()
                )
                synthesis_prompt = _DEBATE_SYNTHESIS_PROMPT.format(
                    task=task,
                    rounds=self.debate_rounds,
                    all_positions=all_positions,
                )
                synthesis_response = await self._synthesis_gw().complete(
                    synthesis_prompt, model=self.synthesis_model, temperature=0.1
                )
                final_output = synthesis_response.content.strip()
                syn_span.set_output(final_output)

            self._log_finished(rounds=self.debate_rounds, success=True)
            trace.set_output(final_output)
            return OrchestratorResult(
                final_output=final_output,
                agent_outputs=agent_outputs,
                rounds=self.debate_rounds,
                success=True,
            )

        except Exception as exc:
            # Top-level boundary: report the failure as a result, not a raise.
            self._logger.error("DebateOrchestrator failed", error=str(exc))
            trace.set_error(exc)
            return OrchestratorResult(
                final_output="",
                agent_outputs=agent_outputs,
                rounds=self.debate_rounds,
                success=False,
                error=str(exc),
            )
Execute the multi-agent orchestration and return the aggregated result.
class SwarmOrchestrator(BaseOrchestrator):
    """
    Dynamically routes work to the best available agent each round.

    Each round:
    1. A coordinator LLM call decides whether the task is done or describes the
       next sub-task.
    2. The :class:`BaseRouter` selects the agent best suited for that sub-task
       (without a router, agents are used round-robin).
    3. The selected agent executes the sub-task.
    4. Results accumulate until ``DONE`` is signalled or ``max_rounds`` is reached.

    Args:
        coordinator_gateway: LLM gateway used by the coordinator.
        agents: Mapping of agent name → :class:`BaseAgent`.
        router: Optional :class:`BaseRouter` that selects agents each round.
        max_rounds: Maximum dispatch rounds (default: 10).
        coordinator_model: Model for coordinator calls (optional).
        logger, metrics, tracer: Observability (optional).
    """

    def __init__(
        self,
        coordinator_gateway: Any,
        agents: Optional[Dict[str, BaseAgent]] = None,
        router: Optional[BaseRouter] = None,
        max_rounds: int = 10,
        coordinator_model: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        """Store the coordinator gateway, agent registry and round limit."""
        super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
        self.coordinator_gateway = coordinator_gateway
        self.agents: Dict[str, BaseAgent] = agents or {}
        self.max_rounds = max_rounds
        self.coordinator_model = coordinator_model

    async def _next_step(self, task: str, history: List[str]) -> Optional[str]:
        """Ask the coordinator what to do next.

        The prompt is built from the task, the numbered round history and the
        registered agent names.

        Returns:
            None when the coordinator's reply starts with ``DONE``
            (case-insensitive); otherwise the stripped reply text.
        """
        history_text = "\n".join(
            f"Round {i+1}: {h}" for i, h in enumerate(history)
        ) or "Nothing yet."
        agents_list = ", ".join(self.agents) if self.agents else "(none)"
        prompt = _SWARM_CONTINUE_PROMPT.format(
            task=task, history=history_text, agents=agents_list
        )
        response = await self.coordinator_gateway.complete(
            prompt, model=self.coordinator_model, temperature=0.0
        )
        text = response.content.strip()
        if text.upper().startswith("DONE"):
            return None  # Task complete
        return text

    async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
        """Dispatch coordinator-chosen sub-tasks until DONE or ``max_rounds``.

        Args:
            task: The overall task.
            context: Optional dict passed unchanged to every ``agent.execute`` call.

        Returns:
            An ``OrchestratorResult``. ``final_output`` is the coordinator's
            summary when it signals DONE after at least one dispatch, otherwise
            the last agent output (empty string if nothing ran). ``rounds`` is
            the number of completed dispatches. If a round raises, ``success``
            is False and the partial state is returned with the error text.
        """
        self._log_start(task)

        agent_outputs: Dict[str, AgentResult] = {}
        history: List[str] = []
        final_output = ""
        available = list(self.agents)

        with self._tracer.trace(
            "swarm.run", input=task, metadata={"max_rounds": self.max_rounds}
        ) as trace:
            try:
                for round_num in range(self.max_rounds):
                    with trace.span(f"swarm.round_{round_num + 1}") as round_span:
                        # Coordinator decides next step
                        next_task = await self._next_step(task, history)
                        if next_task is None:
                            # DONE: ask the coordinator to summarise the history,
                            # provided at least one dispatch happened.
                            if history:
                                done_response = await self.coordinator_gateway.complete(
                                    "Summarise the work done:\n" + "\n".join(history),
                                    model=self.coordinator_model,
                                    temperature=0.1,
                                )
                                final_output = done_response.content.strip()
                            round_span.set_output("DONE")
                            break

                        # Route to best agent
                        if self.router:
                            decision = await self.router.route(
                                RoutingRequest(input=next_task, available_agents=available)
                            )
                            agent_name = decision.target
                        elif available:
                            # No router: cycle through the agents in order.
                            agent_name = available[round_num % len(available)]
                        else:
                            # Without this guard the modulo above fails with an
                            # opaque "division by zero" error message.
                            raise ValueError(
                                "SwarmOrchestrator: no agents available and no router provided."
                            )

                        agent = self.agents.get(agent_name)
                        if agent is None:
                            # NOTE(review): history is not updated here, so the
                            # coordinator gets the same prompt next round.
                            self._logger.warning("Swarm: agent not found", agent=agent_name)
                            continue

                        # Extract just the task description from structured AGENT:/TASK: format
                        agent_task = next_task
                        for line in next_task.splitlines():
                            stripped = line.strip()
                            if stripped.upper().startswith("TASK:"):
                                agent_task = stripped[5:].strip()
                                break

                        self._log_agent_dispatch(agent_name, agent_task)
                        result = await agent.execute(agent_task, context=context)
                        # History and span output are truncated; the full result
                        # is kept in agent_outputs.
                        history.append(f"[{agent_name}]: {result.output[:300]}")
                        agent_outputs[f"round_{round_num + 1}_{agent_name}"] = result
                        final_output = result.output
                        round_span.set_output(result.output[:200])

                        if self._metrics:
                            self._metrics.increment(
                                "orchestrator.swarm_dispatches", agent=agent_name
                            )

                self._log_finished(rounds=len(history), success=True)
                trace.set_output(final_output)
                return OrchestratorResult(
                    final_output=final_output,
                    agent_outputs=agent_outputs,
                    rounds=len(history),
                    success=True,
                )

            except Exception as exc:
                # Top-level boundary: return the partial swarm state.
                self._logger.error("SwarmOrchestrator failed", error=str(exc))
                trace.set_error(exc)
                return OrchestratorResult(
                    final_output=final_output,
                    agent_outputs=agent_outputs,
                    rounds=len(history),
                    success=False,
                    error=str(exc),
                )
Dynamically routes work to the best available agent each round.
Each round:
- A coordinator LLM call decides whether the task is done or describes the next sub-task.
- The BaseRouter selects the agent best suited for that sub-task.
- The selected agent executes the sub-task.
- Results accumulate until DONE is signalled or max_rounds is reached.
Args:
coordinator_gateway: LLM gateway used by the coordinator.
agents: Mapping of agent name → BaseAgent.
router: Optional BaseRouter that selects agents each round; without one, agents are used round-robin.
max_rounds: Maximum dispatch rounds (default: 10).
coordinator_model: Model for coordinator calls (optional).
logger, metrics, tracer: Observability (optional).
def __init__(
    self,
    coordinator_gateway: Any,
    agents: Optional[Dict[str, BaseAgent]] = None,
    router: Optional[BaseRouter] = None,
    max_rounds: int = 10,
    coordinator_model: Optional[str] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the coordinator gateway, agent registry and round limit.

    ``router``, ``logger``, ``metrics`` and ``tracer`` go to the base-class
    initialiser unchanged. A missing or empty ``agents`` mapping is replaced
    by a fresh empty dict.
    """
    super().__init__(router=router, logger=logger, metrics=metrics, tracer=tracer)
    self.coordinator_gateway = coordinator_gateway
    self.coordinator_model = coordinator_model
    self.max_rounds = max_rounds
    self.agents: Dict[str, BaseAgent] = agents if agents else {}
async def run(self, task: str, context: Optional[Dict[str, Any]] = None) -> OrchestratorResult:
    """Dispatch coordinator-chosen sub-tasks until DONE or ``max_rounds``.

    Args:
        task: The overall task.
        context: Optional dict passed unchanged to every ``agent.execute`` call.

    Returns:
        An ``OrchestratorResult``. ``final_output`` is the coordinator's
        summary when it signals DONE after at least one dispatch, otherwise
        the last agent output (empty string if nothing ran). ``rounds`` is
        the number of completed dispatches. If a round raises, ``success``
        is False and the partial state is returned with the error text.
    """
    self._log_start(task)

    agent_outputs: Dict[str, AgentResult] = {}
    history: List[str] = []
    final_output = ""
    available = list(self.agents)

    with self._tracer.trace(
        "swarm.run", input=task, metadata={"max_rounds": self.max_rounds}
    ) as trace:
        try:
            for round_num in range(self.max_rounds):
                with trace.span(f"swarm.round_{round_num + 1}") as round_span:
                    # Coordinator decides next step
                    next_task = await self._next_step(task, history)
                    if next_task is None:
                        # DONE: ask the coordinator to summarise the history,
                        # provided at least one dispatch happened.
                        if history:
                            done_response = await self.coordinator_gateway.complete(
                                "Summarise the work done:\n" + "\n".join(history),
                                model=self.coordinator_model,
                                temperature=0.1,
                            )
                            final_output = done_response.content.strip()
                        round_span.set_output("DONE")
                        break

                    # Route to best agent
                    if self.router:
                        decision = await self.router.route(
                            RoutingRequest(input=next_task, available_agents=available)
                        )
                        agent_name = decision.target
                    elif available:
                        # No router: cycle through the agents in order.
                        agent_name = available[round_num % len(available)]
                    else:
                        # Without this guard the modulo above fails with an
                        # opaque "division by zero" error message.
                        raise ValueError(
                            "SwarmOrchestrator: no agents available and no router provided."
                        )

                    agent = self.agents.get(agent_name)
                    if agent is None:
                        # NOTE(review): history is not updated here, so the
                        # coordinator gets the same prompt next round.
                        self._logger.warning("Swarm: agent not found", agent=agent_name)
                        continue

                    # Extract just the task description from structured AGENT:/TASK: format
                    agent_task = next_task
                    for line in next_task.splitlines():
                        stripped = line.strip()
                        if stripped.upper().startswith("TASK:"):
                            agent_task = stripped[5:].strip()
                            break

                    self._log_agent_dispatch(agent_name, agent_task)
                    result = await agent.execute(agent_task, context=context)
                    # History and span output are truncated; the full result
                    # is kept in agent_outputs.
                    history.append(f"[{agent_name}]: {result.output[:300]}")
                    agent_outputs[f"round_{round_num + 1}_{agent_name}"] = result
                    final_output = result.output
                    round_span.set_output(result.output[:200])

                    if self._metrics:
                        self._metrics.increment(
                            "orchestrator.swarm_dispatches", agent=agent_name
                        )

            self._log_finished(rounds=len(history), success=True)
            trace.set_output(final_output)
            return OrchestratorResult(
                final_output=final_output,
                agent_outputs=agent_outputs,
                rounds=len(history),
                success=True,
            )

        except Exception as exc:
            # Top-level boundary: return the partial swarm state.
            self._logger.error("SwarmOrchestrator failed", error=str(exc))
            trace.set_error(exc)
            return OrchestratorResult(
                final_output=final_output,
                agent_outputs=agent_outputs,
                rounds=len(history),
                success=False,
                error=str(exc),
            )
Execute the multi-agent orchestration and return the aggregated result.