gmf_forge_ai_orchestration.behaviors
Agent behaviors — composable hooks applied around every agent execution.
"""Agent behaviors — composable hooks applied around every agent execution."""

# Submodule import order is kept as-is in case any module relies on it.
from gmf_forge_ai_orchestration.behaviors.base import BaseBehavior, BehaviorContext
from gmf_forge_ai_orchestration.behaviors.retry import RetryBehavior
from gmf_forge_ai_orchestration.behaviors.guardrail import (
    GuardrailBehavior,
    GuardrailRule,
    GuardrailViolationError,
)
from gmf_forge_ai_orchestration.behaviors.human_in_loop import (
    HumanInLoopBehavior,
    HumanApprovalRequired,
    PendingApproval,
)
from gmf_forge_ai_orchestration.behaviors.circuit_breaker import (
    CircuitBreakerBehavior,
    CircuitState,
    CircuitOpenError,
)
from gmf_forge_ai_orchestration.behaviors.rate_limit import (
    RateLimitBehavior,
    RateLimitExceededError,
)
from gmf_forge_ai_orchestration.behaviors.audit import AuditBehavior
from gmf_forge_ai_orchestration.behaviors.agent_discovery import AgentDiscoveryBehavior
from gmf_forge_ai_orchestration.behaviors.obo_token import (
    OBOTokenBehavior,
    OBOTokenProvider,
    EntraOBOProvider,
    OktaOBOProvider,
    OBOTokenError,
)

# Public API, grouped by the submodule each name comes from.
__all__ = [
    # base
    "BaseBehavior",
    "BehaviorContext",
    # retry
    "RetryBehavior",
    # guardrail
    "GuardrailBehavior",
    "GuardrailRule",
    "GuardrailViolationError",
    # human_in_loop
    "HumanInLoopBehavior",
    "HumanApprovalRequired",
    "PendingApproval",
    # circuit_breaker
    "CircuitBreakerBehavior",
    "CircuitState",
    "CircuitOpenError",
    # rate_limit
    "RateLimitBehavior",
    "RateLimitExceededError",
    # audit / discovery
    "AuditBehavior",
    "AgentDiscoveryBehavior",
    # obo_token
    "OBOTokenBehavior",
    "OBOTokenProvider",
    "EntraOBOProvider",
    "OktaOBOProvider",
    "OBOTokenError",
]
class BaseBehavior(ABC):
    """
    Base class for pluggable hooks that wrap an agent run.

    The hooks form a pipeline around each execution: ``before_execute`` runs
    first, then the agent, then ``after_execute`` on success or ``on_error``
    if the agent raised. Every hook has a pass-through default, so a subclass
    only overrides the hooks it actually needs.
    """

    # ------------------------------------------------------------------
    # Pass-through hooks
    # ------------------------------------------------------------------

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Pre-run hook: return the context to use (the default returns it unchanged)."""
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        """Post-success hook: return the result to use (the default returns it unchanged)."""
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """
        Failure hook.

        A non-``None`` return value is a fallback that suppresses the error.
        Returning ``None`` (the default) or raising lets the error propagate.
        """
        return None
Composable agent behavior.
Behaviors are applied as a pipeline around every agent execution:
before_execute → agent runs → after_execute (or on_error).
Subclasses override only the hooks they need.
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Hook invoked just before the agent runs its task.

    Overrides may return an adjusted context. This default hands back the
    very same object.
    """
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Hook invoked after the agent finished without raising.

    Overrides may return a transformed result. This default returns ``result``
    untouched.
    """
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """
    Hook invoked when the agent raised ``error``.

    Return a fallback value to swallow the error. Return ``None`` (what this
    default does) or raise to let it propagate.
    """
    return None
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
@dataclass
class BehaviorContext:
    """State for one agent execution, passed to every behavior hook."""

    agent_id: str
    task: str
    attempt: int = 1
    # Identifier of the execution; empty string when none was assigned.
    execution_id: str = ""
    # Index of the last finished step; -1 means none yet (RetryBehavior only
    # attempts a checkpoint resume when this is >= 0).
    last_completed_step: int = -1
    # Free-form extras for behaviors (e.g. RetryBehavior reads "resume_agent").
    metadata: Dict[str, Any] = field(default_factory=dict)
Execution context passed to every behavior hook.
class RetryBehavior(BaseBehavior):
    """
    Retries the agent on failure with exponential backoff.

    Before retry number ``k`` (1-based) the behavior sleeps
    ``min(base_delay * backoff_factor ** (k - 1), max_delay)`` seconds.
    If the context carries checkpoint data (non-empty ``execution_id``,
    ``last_completed_step >= 0`` and a ``metadata["resume_agent"]`` object with
    a ``resume_from`` method), the retry resumes from that checkpoint and its
    result is returned. Otherwise ``RETRY_SENTINEL`` is returned (presumably
    telling the pipeline to re-run the agent — confirm against the caller).

    Args:
        max_retries: Maximum number of retry attempts (default: 3).
        base_delay: Initial delay in seconds before the first retry (default: 1.0).
        backoff_factor: Multiplier applied to delay after each failure (default: 2.0).
        max_delay: Upper bound on delay between retries in seconds (default: 60.0).
        exceptions_to_retry: Tuple of exception types that trigger a retry.
            Defaults to ``(Exception,)`` — all exceptions.
        logger: Optional ``BasicLogger`` for retry / exhaustion / resume events.

    Note:
        The retry budget lives on the instance (``self._remaining``), so
        executions that share one instance concurrently also share the budget.

    Example::

        agent = ReActAgent(
            llm_gateway=gateway,
            behaviors=[RetryBehavior(max_retries=3, base_delay=0.5)],
        )
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        backoff_factor: float = 2.0,
        max_delay: float = 60.0,
        exceptions_to_retry: Tuple[Type[Exception], ...] = (Exception,),
        logger: Optional[BasicLogger] = None,
    ) -> None:
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay
        self.exceptions_to_retry = exceptions_to_retry
        self._logger = logger

        # Tracks remaining retries across on_error calls for a single execution
        self._remaining: int = max_retries

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        # Refill the budget only when a fresh execution starts (first attempt).
        # The hook also runs on later attempts (the context carries an attempt
        # counter). Refilling unconditionally would restore the budget on every
        # retry, and a persistently failing agent would be retried forever.
        if context.attempt <= 1:
            self._remaining = self.max_retries
        return context

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        if not isinstance(error, self.exceptions_to_retry):
            raise error

        if self._remaining <= 0:
            if self._logger:
                self._logger.info(
                    "RetryBehavior: all retries exhausted — propagating error",
                    agent_id=context.agent_id,
                    max_retries=self.max_retries,
                    error=str(error),
                    error_type=type(error).__name__,
                )
            raise error

        attempt = self.max_retries - self._remaining + 1
        try:
            delay = min(
                self.base_delay * (self.backoff_factor ** (attempt - 1)),
                self.max_delay,
            )
        except OverflowError:
            # The exponential term exceeds float range for large attempt counts;
            # the cap is what would apply anyway.
            delay = self.max_delay
        self._remaining -= 1

        if self._logger:
            self._logger.info(
                "RetryBehavior: scheduling retry",
                agent_id=context.agent_id,
                attempt=attempt,
                retries_remaining=self._remaining,
                delay_s=round(delay, 3),
                error=str(error),
                error_type=type(error).__name__,
            )

        await asyncio.sleep(delay)

        # Prefer resume-from-checkpoint when execution metadata is available.
        resume_agent = context.metadata.get("resume_agent")
        if (
            context.execution_id
            and context.last_completed_step >= 0
            and resume_agent is not None
            and hasattr(resume_agent, "resume_from")
        ):
            if self._logger:
                self._logger.info(
                    "RetryBehavior: resuming from checkpoint",
                    agent_id=context.agent_id,
                    execution_id=context.execution_id,
                    last_completed_step=context.last_completed_step,
                )
            return await resume_agent.resume_from(context.execution_id)

        return RETRY_SENTINEL
Retries the agent on failure with exponential backoff.
Args:
max_retries: Maximum number of retry attempts (default: 3).
base_delay: Initial delay in seconds before the first retry (default: 1.0).
backoff_factor: Multiplier applied to delay after each failure (default: 2.0).
max_delay: Upper bound on delay between retries in seconds (default: 60.0).
exceptions_to_retry: Tuple of exception types that trigger a retry.
Defaults to (Exception,) — all exceptions.
Example::
agent = ReActAgent(
llm_gateway=gateway,
behaviors=[RetryBehavior(max_retries=3, base_delay=0.5)],
)
def __init__(
    self,
    max_retries: int = 3,
    base_delay: float = 1.0,
    backoff_factor: float = 2.0,
    max_delay: float = 60.0,
    exceptions_to_retry: Tuple[Type[Exception], ...] = (Exception,),
    logger: Optional[BasicLogger] = None,
) -> None:
    """Store the backoff configuration and prime the per-execution retry budget."""
    # Backoff schedule.
    self.base_delay = base_delay
    self.backoff_factor = backoff_factor
    self.max_delay = max_delay
    # Retry policy and optional logging sink.
    self.max_retries = max_retries
    self.exceptions_to_retry = exceptions_to_retry
    self._logger = logger
    # Budget consumed by on_error and refilled by before_execute.
    self._remaining: int = max_retries
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Refill the retry budget when a new execution starts.

    The refill happens only on the first attempt (``context.attempt <= 1``).
    This hook also runs on later attempts (the context carries an attempt
    counter). Refilling on every call would restore the budget on each retry,
    and a persistently failing agent would be retried forever.
    """
    if context.attempt <= 1:
        self._remaining = self.max_retries
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """
    Decide whether a failed execution is retried.

    Errors not listed in ``exceptions_to_retry`` are re-raised immediately, and
    so is any error once the budget is spent. Otherwise the behavior consumes
    one retry, sleeps with capped exponential backoff, then either resumes from
    a checkpoint (when ``execution_id``, ``last_completed_step >= 0`` and
    ``metadata["resume_agent"].resume_from`` are available) or returns
    ``RETRY_SENTINEL``.
    """
    if not isinstance(error, self.exceptions_to_retry):
        raise error

    if self._remaining <= 0:
        if self._logger:
            self._logger.info(
                "RetryBehavior: all retries exhausted — propagating error",
                agent_id=context.agent_id,
                max_retries=self.max_retries,
                error=str(error),
                error_type=type(error).__name__,
            )
        raise error

    attempt = self.max_retries - self._remaining + 1
    try:
        delay = min(self.base_delay * (self.backoff_factor ** (attempt - 1)), self.max_delay)
    except OverflowError:
        # backoff_factor ** (attempt - 1) exceeds float range for large attempt
        # counts; the cap is what would apply anyway.
        delay = self.max_delay
    self._remaining -= 1

    if self._logger:
        self._logger.info(
            "RetryBehavior: scheduling retry",
            agent_id=context.agent_id,
            attempt=attempt,
            retries_remaining=self._remaining,
            delay_s=round(delay, 3),
            error=str(error),
            error_type=type(error).__name__,
        )

    await asyncio.sleep(delay)

    # Prefer resume-from-checkpoint when execution metadata is available.
    resume_agent = context.metadata.get("resume_agent")
    if (
        context.execution_id
        and context.last_completed_step >= 0
        and resume_agent is not None
        and hasattr(resume_agent, "resume_from")
    ):
        if self._logger:
            self._logger.info(
                "RetryBehavior: resuming from checkpoint",
                agent_id=context.agent_id,
                execution_id=context.execution_id,
                last_completed_step=context.last_completed_step,
            )
        return await resume_agent.resume_from(context.execution_id)

    return RETRY_SENTINEL
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class GuardrailBehavior(BaseBehavior):
    """
    Validates task inputs and agent outputs against configurable rules.

    Raises :class:`GuardrailViolationError` on the first rule that fails.
    Each rule may:

    * cap the text length (``max_length``),
    * block terms, matched case-insensitively as substrings (``blocked_words``),
    * forbid a regex, searched case-insensitively (``pattern``).

    Args:
        rules: List of :class:`GuardrailRule` to enforce.

    Example::

        behavior = GuardrailBehavior(rules=[
            GuardrailRule(
                name="no_pii",
                blocked_words=["ssn", "social security"],
                apply_to="both",
            ),
            GuardrailRule(name="max_input", max_length=4000, apply_to="input"),
        ])
    """

    def __init__(self, rules: Optional[List[GuardrailRule]] = None) -> None:
        self._rules = rules or []
        # (rule, compiled regex or None) pairs; regexes are compiled once here.
        self._compiled: List[tuple] = []
        for rule in self._rules:
            compiled_pattern: Optional[Pattern] = (
                re.compile(rule.pattern, re.IGNORECASE) if rule.pattern else None
            )
            self._compiled.append((rule, compiled_pattern))

    def _validate(self, text: str, target: str) -> None:
        """Run all rules applicable to ``target`` (``"input"`` or ``"output"``)."""
        # casefold() is the Unicode-aware caseless form (stronger than lower()).
        folded = text.casefold()
        for rule, compiled_pattern in self._compiled:
            if rule.apply_to not in (target, "both"):
                continue

            # Compare against None so that max_length=0 ("must be empty") is
            # enforced instead of being treated as "no limit".
            if rule.max_length is not None and len(text) > rule.max_length:
                raise GuardrailViolationError(
                    rule.name,
                    f"Length {len(text)} exceeds max {rule.max_length}",
                )

            for word in rule.blocked_words:
                if word.casefold() in folded:
                    raise GuardrailViolationError(
                        rule.name, f"Blocked term detected: '{word}'"
                    )

            if compiled_pattern is not None and compiled_pattern.search(text):
                raise GuardrailViolationError(
                    rule.name, f"Pattern '{rule.pattern}' matched"
                )

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        self._validate(context.task, "input")
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        # Extract string output for validation. ``result.output`` may be None or
        # a non-string payload; normalize it so _validate's string operations
        # do not raise AttributeError.
        if hasattr(result, "output"):
            raw_output = result.output
            output_text = "" if raw_output is None else str(raw_output)
        else:
            output_text = str(result)
        self._validate(output_text, "output")
        return result
Validates task inputs and agent outputs against configurable rules.
Raises GuardrailViolationError if any rule is violated.
Args:
rules: List of GuardrailRule to enforce.
Example::
behavior = GuardrailBehavior(rules=[
GuardrailRule(
name="no_pii",
blocked_words=["ssn", "social security"],
apply_to="both",
),
GuardrailRule(name="max_input", max_length=4000, apply_to="input"),
])
def __init__(self, rules: Optional[List[GuardrailRule]] = None) -> None:
    """Keep the rules and pre-compile each rule's regex once (case-insensitive)."""
    self._rules = rules or []
    # (rule, compiled regex or None) pairs, consumed in order by _validate.
    self._compiled: List[tuple] = [
        (rule, re.compile(rule.pattern, re.IGNORECASE) if rule.pattern else None)
        for rule in self._rules
    ]
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Check the task text against input-side rules; raises on the first violation."""
    task_text = context.task
    self._validate(task_text, "input")
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Check the agent's output against output-side rules, then return ``result``.

    If ``result`` has an ``output`` attribute, that value is validated. A
    ``None`` output is treated as the empty string and any other non-string
    output is converted with ``str()``, so ``_validate``'s string operations
    do not crash. Objects without ``output`` are validated as ``str(result)``.
    """
    if hasattr(result, "output"):
        raw_output = result.output
        output_text = "" if raw_output is None else str(raw_output)
    else:
        output_text = str(result)
    self._validate(output_text, "output")
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
@dataclass
class GuardrailRule:
    """
    A single validation rule.

    Raises:
        ValueError: If ``apply_to`` is not ``"input"``, ``"output"`` or ``"both"``.
    """

    name: str
    pattern: Optional[str] = None  # compiled as regex
    blocked_words: List[str] = field(default_factory=list)
    max_length: Optional[int] = None
    apply_to: str = "both"  # "input" | "output" | "both"

    def __post_init__(self) -> None:
        # Any other value would match neither target and silently disable the
        # rule (e.g. a typo like "inputs"), so fail fast at construction.
        if self.apply_to not in ("input", "output", "both"):
            raise ValueError(
                f"GuardrailRule '{self.name}': apply_to must be 'input', 'output' "
                f"or 'both', got {self.apply_to!r}"
            )
A single validation rule.
class GuardrailViolationError(Exception):
    """Signals that text failed a guardrail rule; ``rule_name`` identifies which one."""

    def __init__(self, rule_name: str, message: str) -> None:
        self.rule_name = rule_name
        formatted = f"Guardrail '{rule_name}' violated: {message}"
        super().__init__(formatted)
Raised when a guardrail rule is violated.
class HumanInLoopBehavior(BaseBehavior):
    """
    Gates the agent's result behind a human decision.

    Exactly one of two modes must be configured:

    **Sync mode** (``review_callback``): the callback is awaited until the
    human decides. It returns ``True`` to approve (the result is passed on) or
    ``False`` to reject (:class:`HumanApprovalRequired` is raised). This mode
    suits interactive CLIs or short approval windows.

        ``async (context, result) -> bool``

    **Async / offline mode** (``submit_callback``): the callback files the
    approval request (e.g. a REST call or a queue message) and returns an
    opaque ``approval_id``. The behavior then raises :class:`PendingApproval`
    at once, so the agent call ends without blocking. The caller must persist
    the pending state and resume when the decision arrives.

        ``async (context, result) -> str`` (returns ``approval_id``)

    Args:
        review_callback: Sync-mode callback; mutually exclusive with
            ``submit_callback``.
        submit_callback: Async-mode callback; mutually exclusive with
            ``review_callback``.

    Raises:
        ValueError: If neither or both callbacks are supplied.

    Example — sync mode::

        async def my_reviewer(ctx, result):
            answer = input(f"Approve '{ctx.task}'? [y/n]: ")
            return answer.strip().lower() == "y"

        behavior = HumanInLoopBehavior(review_callback=my_reviewer)

    Example — async (offline) mode::

        async def submit_for_approval(ctx, result):
            resp = await http_client.post("/approvals", json={
                "task": ctx.task,
                "output": result.output,
            })
            return resp.json()["approval_id"]

        behavior = HumanInLoopBehavior(submit_callback=submit_for_approval)

        try:
            result = await agent.execute(task)
        except PendingApproval as pa:
            store.save(pa.approval_id, pending_result=pa.result, task=pa.task)
            # ... decision arrives later; caller resumes from the store
    """

    def __init__(
        self,
        review_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[bool]]] = None,
        submit_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[str]]] = None,
    ) -> None:
        supplied = sum(cb is not None for cb in (review_callback, submit_callback))
        if supplied == 0:
            raise ValueError(
                "HumanInLoopBehavior requires either 'review_callback' (sync) "
                "or 'submit_callback' (async/offline)."
            )
        if supplied == 2:
            raise ValueError(
                "HumanInLoopBehavior accepts only one of 'review_callback' or "
                "'submit_callback', not both."
            )
        self._review_callback = review_callback
        self._submit_callback = submit_callback

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        if self._submit_callback is None:
            # Sync mode: wait for the reviewer's verdict.
            approved = await self._review_callback(context, result)  # type: ignore[misc]
            if approved:
                return result
            raise HumanApprovalRequired(
                f"Human reviewer rejected the result for task: '{context.task}'"
            )

        # Offline mode: file the request, then end the call with PendingApproval.
        approval_id = await self._submit_callback(context, result)
        raise PendingApproval(
            approval_id=approval_id,
            task=context.task,
            result=result,
        )
Pauses or suspends execution after the agent produces a result and waits for human approval before the result is returned.
Two modes:
Sync mode — pass review_callback.
The callback blocks until the human decides, returning True to
approve or False to reject. Suitable for interactive CLIs or
short-lived approval windows.
async (context, result) -> bool
Async (offline) mode — pass submit_callback.
The callback fires the approval request (e.g. POST to a REST API or
queue) and returns an opaque approval_id string. The behavior
immediately raises PendingApproval so the agent call
terminates without blocking. The caller is responsible for persisting
the pending state and resuming once the human's decision arrives via
webhook or callback.
async (context, result) -> str (returns approval_id)
Args:
review_callback: Sync-mode callback. Mutually exclusive with
submit_callback.
submit_callback: Async-mode callback. Mutually exclusive with
review_callback.
Raises: ValueError: If neither or both callbacks are supplied.
Example — sync mode::
async def my_reviewer(ctx, result):
answer = input(f"Approve '{ctx.task}'? [y/n]: ")
return answer.strip().lower() == "y"
behavior = HumanInLoopBehavior(review_callback=my_reviewer)
Example — async (offline) mode::
async def submit_for_approval(ctx, result):
resp = await http_client.post("/approvals", json={
"task": ctx.task,
"output": result.output,
})
return resp.json()["approval_id"]
behavior = HumanInLoopBehavior(submit_callback=submit_for_approval)
# Caller handles PendingApproval:
try:
result = await agent.execute(task)
except PendingApproval as pa:
store.save(pa.approval_id, pending_result=pa.result, task=pa.task)
# ... webhook delivers decision later, caller resumes from store
def __init__(
    self,
    review_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[bool]]] = None,
    submit_callback: Optional[Callable[[BehaviorContext, Any], Awaitable[str]]] = None,
) -> None:
    """
    Select the approval mode: exactly one callback must be given.

    Raises:
        ValueError: If neither or both callbacks are supplied.
    """
    supplied = sum(cb is not None for cb in (review_callback, submit_callback))
    if supplied == 0:
        raise ValueError(
            "HumanInLoopBehavior requires either 'review_callback' (sync) "
            "or 'submit_callback' (async/offline)."
        )
    if supplied == 2:
        raise ValueError(
            "HumanInLoopBehavior accepts only one of 'review_callback' or "
            "'submit_callback', not both."
        )
    self._review_callback = review_callback
    self._submit_callback = submit_callback
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """
    Gate ``result`` on a human decision.

    In sync mode, returns ``result`` if approved and raises
    ``HumanApprovalRequired`` if rejected. In offline mode, submits the
    request and always raises ``PendingApproval``.
    """
    if self._submit_callback is None:
        # Sync mode: wait for the reviewer's verdict.
        approved = await self._review_callback(context, result)  # type: ignore[misc]
        if approved:
            return result
        raise HumanApprovalRequired(
            f"Human reviewer rejected the result for task: '{context.task}'"
        )

    # Offline mode: file the request, then end the call with PendingApproval.
    approval_id = await self._submit_callback(context, result)
    raise PendingApproval(
        approval_id=approval_id,
        task=context.task,
        result=result,
    )
Called after the agent successfully executes.
May mutate and return a modified result.
class HumanApprovalRequired(Exception):
    """Signals that the human reviewer declined to approve the agent's result."""
Raised when a human reviewer rejects the agent result.
class PendingApproval(Exception):
    """
    Signals (async/offline mode) that an approval request was submitted and the
    agent's result is waiting for a human decision.

    Callers are expected to catch it, persist state under ``approval_id``, and
    resume once the decision is delivered.

    Attributes:
        approval_id: Opaque identifier returned by ``submit_callback``.
        task: The task string the agent was given.
        result: The agent result to deliver once approved.
    """

    def __init__(self, approval_id: str, task: str, result: Any) -> None:
        self.approval_id, self.task, self.result = approval_id, task, result
        super().__init__(
            f"Approval pending for task: '{task}' (approval_id={approval_id})"
        )
Raised in async (offline) mode to signal that an approval request has been submitted and the agent result is awaiting a human decision.
The caller should catch this, persist state keyed on approval_id, and
resume processing when the approval webhook/callback delivers the decision.
Attributes:
approval_id: Opaque identifier returned by the submit_callback.
task: The original task string passed to the agent.
result: The AgentResult produced by the agent, ready to deliver
once approved.
class CircuitBreakerBehavior(BaseBehavior):
    """
    Prevents repeated calls to a failing agent by tripping an open circuit.

    State transitions::

        CLOSED    ──(failures >= threshold)──► OPEN
        OPEN      ──(recovery_timeout elapsed)──► HALF_OPEN
        HALF_OPEN ──(success)──► CLOSED
        HALF_OPEN ──(failure)──► OPEN

    NOTE(review): ``before_execute`` only rejects calls in OPEN. While the
    breaker is HALF_OPEN, every arriving call is let through, not just one
    probe.

    Args:
        failure_threshold: Consecutive failures before opening (default: 5).
        recovery_timeout: Seconds to wait in OPEN state before probing (default: 60).
        metrics: Optional :class:`BasicMetricsCollector` for state/failure tracking.
            The ``circuit_breaker.state`` gauge is one-hot per ``state`` tag.
        logger: Optional :class:`BasicLogger`.

    Example::

        from gmf_forge_ai_shared_core.observability import BasicMetricsCollector, BasicLogger

        behavior = CircuitBreakerBehavior(
            failure_threshold=3,
            recovery_timeout=30,
            metrics=BasicMetricsCollector(),
            logger=BasicLogger("circuit_breaker"),
        )
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        metrics: Optional[Any] = None,
        logger: Optional[Any] = None,
    ) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._metrics = metrics
        self._logger = logger

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._opened_at: Optional[float] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _transition(self, new_state: CircuitState) -> None:
        """Move to ``new_state`` (no-op if already there), logging and publishing metrics."""
        if self._state == new_state:
            return
        old_state = self._state
        self._state = new_state
        if self._logger:
            self._logger.info(
                "Circuit state transition",
                from_state=old_state.value,
                to_state=new_state.value,
            )
        if self._metrics:
            # One-hot state gauge. Clear the series for the state being left;
            # otherwise every state ever entered keeps reporting 1.0.
            self._metrics.gauge("circuit_breaker.state", 0.0, state=old_state.value)
            self._metrics.gauge("circuit_breaker.state", 1.0, state=new_state.value)

    # ------------------------------------------------------------------
    # BaseBehavior hooks
    # ------------------------------------------------------------------

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - (self._opened_at or 0)
            if elapsed >= self.recovery_timeout:
                self._transition(CircuitState.HALF_OPEN)
            else:
                raise CircuitOpenError(
                    f"Circuit is OPEN for agent '{context.agent_id}'. "
                    f"Retry in {self.recovery_timeout - elapsed:.1f}s."
                )
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        # Success — reset failures, close circuit if probing
        self._failure_count = 0
        if self._state == CircuitState.HALF_OPEN:
            self._transition(CircuitState.CLOSED)
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        self._failure_count += 1
        if self._metrics:
            self._metrics.increment("circuit_breaker.failures", agent_id=context.agent_id)

        # A failed probe re-opens immediately; otherwise open once the streak hits the threshold.
        if self._state == CircuitState.HALF_OPEN or self._failure_count >= self.failure_threshold:
            self._opened_at = time.monotonic()
            self._transition(CircuitState.OPEN)

        raise error
Prevents repeated calls to a failing agent by tripping an open circuit.
State transitions::
CLOSED ──(failures >= threshold)──► OPEN
OPEN ──(recovery_timeout elapsed)──► HALF_OPEN
HALF_OPEN ──(success)──► CLOSED
HALF_OPEN ──(failure)──► OPEN
Args:
failure_threshold: Consecutive failures before opening (default: 5).
recovery_timeout: Seconds to wait in OPEN state before probing (default: 60).
metrics: Optional BasicMetricsCollector for state/failure tracking.
logger: Optional BasicLogger.
Example::
from gmf_forge_ai_shared_core.observability import BasicMetricsCollector, BasicLogger
behavior = CircuitBreakerBehavior(
failure_threshold=3,
recovery_timeout=30,
metrics=BasicMetricsCollector(),
logger=BasicLogger("circuit_breaker"),
)
def __init__(
    self,
    failure_threshold: int = 5,
    recovery_timeout: float = 60.0,
    metrics: Optional[Any] = None,
    logger: Optional[Any] = None,
) -> None:
    """Configure thresholds and observability sinks; the breaker starts CLOSED."""
    # Configuration.
    self.failure_threshold = failure_threshold
    self.recovery_timeout = recovery_timeout
    # Optional observability sinks (the class calls .gauge/.increment and .info on them).
    self._metrics = metrics
    self._logger = logger
    # Mutable breaker state.
    self._state = CircuitState.CLOSED
    self._failure_count = 0
    self._opened_at: Optional[float] = None  # time.monotonic() of the most recent trip
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Reject calls while the circuit is OPEN and the cool-down has not elapsed.

    Once ``recovery_timeout`` seconds have passed since the trip, the breaker
    moves to HALF_OPEN and the call proceeds.
    """
    if self._state != CircuitState.OPEN:
        return context

    elapsed = time.monotonic() - (self._opened_at or 0)
    if elapsed < self.recovery_timeout:
        raise CircuitOpenError(
            f"Circuit is OPEN for agent '{context.agent_id}'. "
            f"Retry in {self.recovery_timeout - elapsed:.1f}s."
        )
    self._transition(CircuitState.HALF_OPEN)
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """A success clears the failure streak and closes a HALF_OPEN breaker."""
    was_probing = self._state == CircuitState.HALF_OPEN
    self._failure_count = 0
    if was_probing:
        self._transition(CircuitState.CLOSED)
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """Count the failure, trip the breaker if warranted, then always re-raise ``error``."""
    self._failure_count += 1
    if self._metrics:
        self._metrics.increment("circuit_breaker.failures", agent_id=context.agent_id)

    probe_failed = self._state == CircuitState.HALF_OPEN
    threshold_hit = self._failure_count >= self.failure_threshold
    if probe_failed or threshold_hit:
        self._opened_at = time.monotonic()
        self._transition(CircuitState.OPEN)

    raise error
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker."""

    CLOSED = "closed"  # healthy: calls flow through normally
    OPEN = "open"  # tripped: calls are rejected
    HALF_OPEN = "half_open"  # cool-down over: calls let through as recovery probes
class CircuitOpenError(Exception):
    """Signals that a call was refused because the circuit breaker is OPEN."""
Raised when a call is attempted while the circuit is OPEN.
class RateLimitBehavior(BaseBehavior):
    """
    Enforces a maximum call rate using a token bucket algorithm.

    Each call consumes one token; tokens refill at ``calls_per_second`` up to
    ``burst``.

    Args:
        calls_per_second: Maximum calls allowed per second (default: 1.0).
            Must be > 0.
        burst: Maximum burst capacity — tokens that can accumulate while idle.
            Must be >= 1, since a call needs a whole token. Defaults to
            ``max(1, calls_per_second)``.
        wait_on_limit: If ``True`` (default), sleep until a token is available.
            If ``False``, raise :class:`RateLimitExceededError` immediately.

    Raises:
        ValueError: If ``calls_per_second <= 0`` or ``burst < 1``.

    Example::

        behavior = RateLimitBehavior(calls_per_second=2.0, burst=5)
    """

    def __init__(
        self,
        calls_per_second: float = 1.0,
        burst: Optional[float] = None,
        wait_on_limit: bool = True,
    ) -> None:
        if calls_per_second <= 0:
            raise ValueError(
                f"calls_per_second must be > 0, got {calls_per_second!r}"
            )
        if burst is None:
            # The bucket must hold at least one token. A default of
            # calls_per_second alone caps it below 1 for rates < 1/s, and
            # before_execute would then wait forever.
            burst = max(1.0, calls_per_second)
        elif burst < 1.0:
            raise ValueError(
                f"burst must be >= 1 (each call consumes one token), got {burst!r}"
            )

        self.calls_per_second = calls_per_second
        self.burst = burst
        self.wait_on_limit = wait_on_limit

        self._tokens: float = self.burst
        self._last_refill: float = time.monotonic()
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Token bucket
    # ------------------------------------------------------------------

    def _refill(self) -> None:
        """Add tokens for the time elapsed since the last refill, capped at ``burst``."""
        now = time.monotonic()
        elapsed = now - self._last_refill
        self._tokens = min(self.burst, self._tokens + elapsed * self.calls_per_second)
        self._last_refill = now

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        # The lock is held while sleeping, so waiting callers are served one at a time.
        async with self._lock:
            while True:
                self._refill()
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return context

                if not self.wait_on_limit:
                    raise RateLimitExceededError(
                        f"Rate limit of {self.calls_per_second} call/s exceeded "
                        f"for agent '{context.agent_id}'."
                    )

                wait = (1.0 - self._tokens) / self.calls_per_second
                await asyncio.sleep(wait)
Enforces a maximum call rate using a token bucket algorithm.
Args:
calls_per_second: Maximum calls allowed per second (default: 1.0).
burst: Maximum burst capacity — tokens that can accumulate while idle
(default: equals calls_per_second).
wait_on_limit: If True (default), sleep until a token is available.
If False, raise RateLimitExceededError immediately.
Example::
behavior = RateLimitBehavior(calls_per_second=2.0, burst=5)
def __init__(
    self,
    calls_per_second: float = 1.0,
    burst: Optional[float] = None,
    wait_on_limit: bool = True,
) -> None:
    """
    Configure the token bucket; it starts full (``burst`` tokens).

    Raises:
        ValueError: If ``calls_per_second <= 0`` or an explicit ``burst < 1``.
    """
    if calls_per_second <= 0:
        raise ValueError(
            f"calls_per_second must be > 0, got {calls_per_second!r}"
        )
    if burst is None:
        # Each call consumes one whole token, so the bucket must hold at least
        # one. A default of calls_per_second alone caps it below 1 for rates
        # < 1/s, and the wait loop would never be satisfied.
        burst = max(1.0, calls_per_second)
    elif burst < 1.0:
        raise ValueError(
            f"burst must be >= 1 (each call consumes one token), got {burst!r}"
        )

    self.calls_per_second = calls_per_second
    self.burst = burst
    self.wait_on_limit = wait_on_limit

    self._tokens: float = self.burst
    self._last_refill: float = time.monotonic()
    self._lock = asyncio.Lock()
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """
    Take one token from the bucket before the agent runs.

    If no token is available, either sleep until one accrues
    (``wait_on_limit=True``) or raise ``RateLimitExceededError``. The lock is
    held while sleeping, so waiting callers are served one at a time.
    """
    async with self._lock:
        while True:
            self._refill()
            if self._tokens < 1.0:
                if not self.wait_on_limit:
                    raise RateLimitExceededError(
                        f"Rate limit of {self.calls_per_second} call/s exceeded "
                        f"for agent '{context.agent_id}'."
                    )
                # Sleep just long enough for the deficit to refill.
                await asyncio.sleep((1.0 - self._tokens) / self.calls_per_second)
                continue
            self._tokens -= 1.0
            return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
class RateLimitExceededError(Exception):
    """Signals that no rate-limit token was available and waiting is disabled (wait_on_limit=False)."""
Raised when the rate limit is exceeded and wait_on_limit=False.
class AuditBehavior(BaseBehavior):
    """
    Logs every agent input, output, and error using shared-core observability.

    Uses :class:`BasicLogger` for structured log lines and
    :class:`BasicMetricsCollector` to track invocation counts.

    Args:
        logger: :class:`BasicLogger` instance. A default named
            ``"gmf_forge_ai.audit"`` is created if not provided.
        metrics: Optional :class:`BasicMetricsCollector` for counting events.

    Example::

        from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector

        behavior = AuditBehavior(
            logger=BasicLogger("my_app.audit"),
            metrics=BasicMetricsCollector(),
        )
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
    ) -> None:
        self._logger = logger or BasicLogger("gmf_forge_ai.audit")
        self._metrics = metrics
        # Stores start timestamps keyed by agent_id+task to compute duration.
        # Concurrent runs of the same task on the same agent share one key.
        self._start_times: dict = {}

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        key = f"{context.agent_id}:{context.task}"
        self._start_times[key] = time.monotonic()
        self._logger.info(
            "Agent task started",
            agent_id=context.agent_id,
            task=context.task,
            attempt=context.attempt,
        )
        if self._metrics:
            self._metrics.increment("behavior.audit.invocations", agent_id=context.agent_id)
        return context

    async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
        key = f"{context.agent_id}:{context.task}"
        duration_ms = round((time.monotonic() - self._start_times.pop(key, time.monotonic())) * 1000)
        # ``result.output`` may be None or a non-string payload; slicing those for
        # the preview would raise after an otherwise successful run.
        if hasattr(result, "output"):
            raw_output = result.output
            output = "" if raw_output is None else str(raw_output)
        else:
            output = str(result)
        self._logger.info(
            "Agent task completed",
            agent_id=context.agent_id,
            task=context.task,
            duration_ms=duration_ms,
            success=True,
            output_preview=output[:200],
        )
        if self._metrics:
            self._metrics.histogram("behavior.audit.duration_ms", duration_ms, agent_id=context.agent_id)
        return result

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        key = f"{context.agent_id}:{context.task}"
        duration_ms = round((time.monotonic() - self._start_times.pop(key, time.monotonic())) * 1000)
        self._logger.error(
            "Agent task failed",
            agent_id=context.agent_id,
            task=context.task,
            duration_ms=duration_ms,
            error=str(error),
            error_type=type(error).__name__,
        )
        if self._metrics:
            self._metrics.increment("behavior.audit.errors", agent_id=context.agent_id)
        # Never supplies a fallback, so the error keeps propagating.
        return None
Logs every agent input, output, and error using shared-core observability.
Uses BasicLogger for structured log lines and
BasicMetricsCollector to track invocation counts.
Args:
logger: BasicLogger instance. A default named
"gmf_forge_ai.audit" is created if not provided.
metrics: Optional BasicMetricsCollector for counting events.
Example::
from gmf_forge_ai_shared_core.observability import BasicLogger, BasicMetricsCollector
behavior = AuditBehavior(
logger=BasicLogger("my_app.audit"),
metrics=BasicMetricsCollector(),
)
def __init__(
    self,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
) -> None:
    """Keep the given logger (or build the default audit logger) and the optional metrics sink."""
    if logger:
        self._logger = logger
    else:
        self._logger = BasicLogger("gmf_forge_ai.audit")
    self._metrics = metrics
    # monotonic start stamp per "agent_id:task" key; the completion and
    # error hooks pop it to compute a duration
    self._start_times: dict = {}
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Stamp the start time for this agent/task, log the start and count the invocation.

    Returns the context unchanged.
    """
    agent = context.agent_id
    self._start_times[f"{agent}:{context.task}"] = time.monotonic()
    self._logger.info(
        "Agent task started",
        agent_id=agent,
        task=context.task,
        attempt=context.attempt,
    )
    if self._metrics:
        self._metrics.increment("behavior.audit.invocations", agent_id=agent)
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
async def after_execute(self, context: BehaviorContext, result: Any) -> Any:
    """Log the successful completion of a task and record its duration.

    Args:
        context: Context of the finished execution.
        result: The agent result; returned unchanged.

    Returns:
        ``result``, unmodified.
    """
    key = f"{context.agent_id}:{context.task}"
    # Without a recorded start stamp, the pop default makes the duration ~0 ms.
    duration_ms = round((time.monotonic() - self._start_times.pop(key, time.monotonic())) * 1000)
    # ``result.output`` may be None or a non-string object: coerce to str
    # before slicing so a successful run cannot fail here with TypeError.
    raw_output = result.output if hasattr(result, "output") else result
    output_preview = str(raw_output)[:200]
    self._logger.info(
        "Agent task completed",
        agent_id=context.agent_id,
        task=context.task,
        duration_ms=duration_ms,
        success=True,
        output_preview=output_preview,
    )
    if self._metrics:
        self._metrics.histogram("behavior.audit.duration_ms", duration_ms, agent_id=context.agent_id)
    return result
Called after the agent successfully executes.
May mutate and return a modified result.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """Log the failure with its elapsed time and bump the error counter.

    Returns ``None``, so the original error keeps propagating.
    """
    agent = context.agent_id
    finished = time.monotonic()
    started = self._start_times.pop(f"{agent}:{context.task}", time.monotonic())
    elapsed_ms = round((finished - started) * 1000)
    self._logger.error(
        "Agent task failed",
        agent_id=agent,
        task=context.task,
        duration_ms=elapsed_ms,
        error=str(error),
        error_type=type(error).__name__,
    )
    if self._metrics:
        self._metrics.increment("behavior.audit.errors", agent_id=agent)
    return None
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class AgentDiscoveryBehavior:
    """Refreshes supervisor agent registry on a fixed interval.

    This behavior runs as an asyncio background task. Each refresh fetches the
    latest ``(agents, descriptions)`` from ``discovery_fn`` and updates the
    provided ``SupervisorOrchestrator`` instance whenever the agent names or
    the agent descriptions changed. The assignments happen with no ``await``
    in between, so other coroutines never observe a half-applied refresh.
    """

    def __init__(
        self,
        discovery_fn: DiscoveryFn,
        interval_seconds: float = 30.0,
        logger: Optional[BasicLogger] = None,
    ) -> None:
        if interval_seconds <= 0:
            raise ValueError("interval_seconds must be greater than 0")

        self.discovery_fn = discovery_fn
        self.interval_seconds = interval_seconds
        self._logger = logger or BasicLogger("gmf_forge_ai.behavior.AgentDiscoveryBehavior")

        # Background refresher state, populated by start().
        self._task: Optional[asyncio.Task] = None
        self._supervisor: Optional[SupervisorOrchestrator] = None

    async def refresh_once(self) -> None:
        """Run one discovery cycle and update the supervisor in-place.

        The supervisor's ``agents`` / ``agent_descriptions`` (and the router's
        ``agent_descriptions`` / ``fallback_target`` when the router has them)
        are replaced whenever the discovered agent names or descriptions
        differ from the current ones; otherwise nothing is touched.

        Raises:
            RuntimeError: If :meth:`start` has not been called yet.
        """
        if self._supervisor is None:
            raise RuntimeError("AgentDiscoveryBehavior.start() must be called first")

        new_agents, new_descriptions = await self.discovery_fn()

        current_agents = set(self._supervisor.agents.keys())
        discovered_agents = set(new_agents.keys())
        added = sorted(discovered_agents - current_agents)
        removed = sorted(current_agents - discovered_agents)
        # Description edits must propagate too; comparing names alone would
        # leave stale descriptions on the supervisor and the router.
        descriptions_changed = (
            new_descriptions != getattr(self._supervisor, "agent_descriptions", None)
        )

        if not added and not removed and not descriptions_changed:
            return

        self._supervisor.agents = new_agents
        self._supervisor.agent_descriptions = new_descriptions

        router = self._supervisor.router
        if router is not None and hasattr(router, "agent_descriptions"):
            router.agent_descriptions = new_descriptions
        if router is not None and hasattr(router, "fallback_target"):
            if new_agents:
                # Keep the current fallback if it still exists; otherwise use
                # the first discovered agent.
                fallback = getattr(router, "fallback_target", None)
                if fallback not in new_agents:
                    router.fallback_target = next(iter(new_agents))
            else:
                router.fallback_target = None

        self._logger.info(
            "Agent registry updated via periodic discovery",
            added=added,
            removed=removed,
            descriptions_changed=descriptions_changed,
            total_agents=len(new_agents),
        )

    async def start(self, supervisor: SupervisorOrchestrator) -> None:
        """Start periodic discovery for the given supervisor.

        No-op while a refresh task is already running.
        """
        if self._task is not None and not self._task.done():
            return

        self._supervisor = supervisor

        async def _runner() -> None:
            # Sleeps before each refresh, so nothing is fetched immediately on start.
            while True:
                try:
                    await asyncio.sleep(self.interval_seconds)
                    await self.refresh_once()
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    # Best-effort: log and try again on the next interval.
                    self._logger.warning(
                        "Error during periodic agent discovery",
                        error=str(exc),
                    )

        self._task = asyncio.create_task(_runner())

    async def stop(self) -> None:
        """Stop periodic discovery and wait for task cancellation."""
        if self._task is None:
            return

        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        finally:
            self._task = None
Refreshes supervisor agent registry on a fixed interval.
This behavior runs as an asyncio background task. Each refresh fetches the
latest (agents, descriptions) from discovery_fn and atomically
updates the provided SupervisorOrchestrator instance.
def __init__(
    self,
    discovery_fn: DiscoveryFn,
    interval_seconds: float = 30.0,
    logger: Optional[BasicLogger] = None,
) -> None:
    """Validate the refresh interval and set up an idle (not yet started) behavior.

    Raises:
        ValueError: If ``interval_seconds`` is not greater than 0.
    """
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be greater than 0")

    # background refresher state; filled in by start()
    self._task: Optional[asyncio.Task] = None
    self._supervisor: Optional[SupervisorOrchestrator] = None

    self.discovery_fn = discovery_fn
    self.interval_seconds = interval_seconds
    self._logger = (
        logger
        if logger
        else BasicLogger("gmf_forge_ai.behavior.AgentDiscoveryBehavior")
    )
async def refresh_once(self) -> None:
    """Run one discovery cycle and update the supervisor in-place.

    The supervisor's ``agents`` / ``agent_descriptions`` (and the router's
    ``agent_descriptions`` / ``fallback_target`` when the router has them)
    are replaced whenever the discovered agent names or descriptions differ
    from the current ones; otherwise nothing is touched.

    Raises:
        RuntimeError: If ``start()`` has not been called yet.
    """
    if self._supervisor is None:
        raise RuntimeError("AgentDiscoveryBehavior.start() must be called first")

    new_agents, new_descriptions = await self.discovery_fn()

    current_agents = set(self._supervisor.agents.keys())
    discovered_agents = set(new_agents.keys())
    added = sorted(discovered_agents - current_agents)
    removed = sorted(current_agents - discovered_agents)
    # Description edits must propagate too; comparing names alone would
    # leave stale descriptions on the supervisor and the router.
    descriptions_changed = (
        new_descriptions != getattr(self._supervisor, "agent_descriptions", None)
    )

    if not added and not removed and not descriptions_changed:
        return

    self._supervisor.agents = new_agents
    self._supervisor.agent_descriptions = new_descriptions

    router = self._supervisor.router
    if router is not None and hasattr(router, "agent_descriptions"):
        router.agent_descriptions = new_descriptions
    if router is not None and hasattr(router, "fallback_target"):
        if new_agents:
            # Keep the current fallback if it still exists; otherwise use the
            # first discovered agent.
            fallback = getattr(router, "fallback_target", None)
            if fallback not in new_agents:
                router.fallback_target = next(iter(new_agents))
        else:
            router.fallback_target = None

    self._logger.info(
        "Agent registry updated via periodic discovery",
        added=added,
        removed=removed,
        descriptions_changed=descriptions_changed,
        total_agents=len(new_agents),
    )
Run one discovery cycle and update the supervisor in-place.
async def start(self, supervisor: SupervisorOrchestrator) -> None:
    """Bind *supervisor* and launch the periodic refresh loop.

    Calling this while a refresh task is still running does nothing.
    """
    already_running = self._task is not None and not self._task.done()
    if already_running:
        return

    self._supervisor = supervisor

    async def _poll_forever() -> None:
        # Each cycle waits first, so no discovery happens right at start-up.
        while True:
            try:
                await asyncio.sleep(self.interval_seconds)
                await self.refresh_once()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # Failures are logged; the next tick tries again.
                self._logger.warning(
                    "Error during periodic agent discovery",
                    error=str(exc),
                )

    self._task = asyncio.create_task(_poll_forever())
Start periodic discovery for the given supervisor.
async def stop(self) -> None:
    """Cancel the refresh task (if any), wait for it to wind down, then forget it."""
    task = self._task
    if task is None:
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        pass
    finally:
        self._task = None
Stop periodic discovery and wait for task cancellation.
class OBOTokenBehavior(BaseBehavior):
    """Exchange a user's access token for a downstream OBO token before execution.

    Reads ``context.metadata["user_assertion"]`` (the user's raw access token,
    injected by the FastAPI layer), calls :meth:`OBOTokenProvider.exchange`,
    and stores the result in ``context.metadata[token_metadata_key]``.

    The exchanged token is also published through the ``current_obo_token``
    context variable. With the default ``token_metadata_key`` the agent and
    its tools can read it as ``context.metadata["obo_token"]``.

    Args:
        provider: An :class:`OBOTokenProvider` instance (Entra or Okta).
        token_metadata_key: Key under which the exchanged token is stored in
            ``BehaviorContext.metadata``. Defaults to ``"obo_token"``.

    Raises:
        ValueError: If ``user_assertion`` is absent from ``context.metadata``.
        OBOTokenError: If the token exchange fails.

    Example::

        behavior = OBOTokenBehavior(
            provider=EntraOBOProvider(
                tenant_id="...",
                client_id="...",
                client_secret="...",
                scopes=["api://servicenow/.default"],
            )
        )
        agent = ReActAgent(llm_gateway=gateway, behaviors=[behavior])

        # FastAPI layer passes the user's token in context
        result = await agent.execute(task, context={"user_assertion": bearer_token})
    """

    def __init__(
        self,
        provider: OBOTokenProvider,
        token_metadata_key: str = "obo_token",
    ) -> None:
        self._provider = provider
        self._token_metadata_key = token_metadata_key
        # Most recent token exchanged by this instance, used by debug_token().
        # NOTE(review): if one behavior instance serves several users, this
        # holds whichever user's token was exchanged last.
        self._last_token: Optional[str] = None

    async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
        """Exchange ``user_assertion`` for an OBO token and publish it.

        Raises:
            ValueError: If ``user_assertion`` is missing or empty.
            OBOTokenError: If the provider fails; unexpected provider
                exceptions are wrapped in OBOTokenError.
        """
        user_assertion: Optional[str] = context.metadata.get("user_assertion")
        if not user_assertion:
            raise ValueError(
                "OBOTokenBehavior requires 'user_assertion' in context.metadata. "
                "Inject the user's access token before calling agent.execute()."
            )

        try:
            obo_token = await self._provider.exchange(user_assertion)
        except OBOTokenError:
            raise
        except Exception as exc:
            raise OBOTokenError(
                f"Unexpected error during token exchange: {exc}",
                self._provider.provider_name,
            ) from exc

        self._last_token = obo_token
        current_obo_token.set(obo_token)
        logger.debug(
            "OBOTokenBehavior: token exchanged and stored in current_obo_token context var",
            provider=self._provider.provider_name,
            key=self._token_metadata_key,
        )
        context.metadata[self._token_metadata_key] = obo_token
        return context

    def debug_token(self, mask: bool = False) -> str:
        """Return the OBO token for debugging purposes.

        Intended for local development only. Do NOT call this in production
        code or log the return value to any persistent log sink — bearer
        tokens grant full user-level access to downstream APIs. The value is
        the token most recently exchanged by this instance, whichever user
        it belonged to.

        Args:
            mask: If ``True``, returns a masked version showing only the first
                8 and last 4 characters (e.g. ``"eyJhbGci...a1b2"``); tokens of
                12 characters or fewer are returned as ``"***"``.
                If ``False`` (default), returns the full token.

        Usage::

            behavior = OBOTokenBehavior(provider=provider)
            # after agent.execute() has run:
            print(behavior.debug_token())           # full token
            print(behavior.debug_token(mask=True))  # eyJhbGci...a1b2

        Returns:
            The OBO token string (full or masked), or ``"<not set>"`` if the
            token has not been exchanged yet.
        """
        if not self._last_token:
            return "<not set>"
        if mask:
            if len(self._last_token) <= 12:
                return "***"
            return f"{self._last_token[:8]}...{self._last_token[-4:]}"
        return self._last_token

    async def on_error(
        self, context: BehaviorContext, error: Exception
    ) -> Optional[Any]:
        """Re-raise non-retryable errors; return ``None`` for everything else."""
        # OBOTokenError and ValueError (missing assertion) are not retryable —
        # re-raise immediately so RetryBehavior does not attempt to retry them.
        # NOTE(review): this also re-raises ValueErrors raised by the agent itself.
        if isinstance(error, (OBOTokenError, ValueError)):
            raise error
        return None
Exchange a user's access token for a downstream OBO token before execution.
Reads context.metadata["user_assertion"] (the user's raw access token,
injected by the FastAPI layer), calls OBOTokenProvider.exchange(),
and stores the result in context.metadata[token_metadata_key].
The token is also published through the current_obo_token context variable; with the
default token_metadata_key, the agent and its tools can read it via context.metadata["obo_token"].
Args:
provider: An OBOTokenProvider instance (Entra or Okta).
token_metadata_key: Key under which the exchanged token is stored in
BehaviorContext.metadata. Defaults to "obo_token".
Raises:
ValueError: If user_assertion is absent from context.metadata.
OBOTokenError: If the token exchange fails.
Example::
behavior = OBOTokenBehavior(
provider=EntraOBOProvider(
tenant_id="...",
client_id="...",
client_secret="...",
scopes=["api://servicenow/.default"],
)
)
agent = ReActAgent(llm_gateway=gateway, behaviors=[behavior])
# FastAPI layer passes the user's token in context
result = await agent.execute(task, context={"user_assertion": bearer_token})
async def before_execute(self, context: BehaviorContext) -> BehaviorContext:
    """Swap the caller's ``user_assertion`` for an OBO token and publish it.

    The token is stored on this instance (for ``debug_token``), in the
    ``current_obo_token`` context variable and in
    ``context.metadata[self._token_metadata_key]``.

    Raises:
        ValueError: If ``user_assertion`` is missing or empty.
        OBOTokenError: If the provider fails; any other provider exception
            is wrapped in OBOTokenError.
    """
    assertion: Optional[str] = context.metadata.get("user_assertion")
    if not assertion:
        raise ValueError(
            "OBOTokenBehavior requires 'user_assertion' in context.metadata. "
            "Inject the user's access token before calling agent.execute()."
        )

    provider = self._provider
    try:
        token = await provider.exchange(assertion)
    except Exception as exc:
        if isinstance(exc, OBOTokenError):
            raise
        raise OBOTokenError(
            f"Unexpected error during token exchange: {exc}",
            provider.provider_name,
        ) from exc

    self._last_token = token
    current_obo_token.set(token)
    logger.debug(
        "OBOTokenBehavior: token exchanged and stored in current_obo_token context var",
        provider=provider.provider_name,
        key=self._token_metadata_key,
    )
    context.metadata[self._token_metadata_key] = token
    return context
Called before the agent executes its task.
May mutate and return a modified context (e.g. to inject guardrails).
def debug_token(self, mask: bool = False) -> str:
    """Return the most recently exchanged OBO token, for local debugging only.

    Never call this in production code and never write the result to a
    persistent log: a bearer token grants user-level access to downstream
    APIs. The value is whatever token this instance exchanged last, whichever
    user it belonged to.

    Args:
        mask: When ``True``, show only the first 8 and last 4 characters
            (e.g. ``"eyJhbGci...a1b2"``); tokens of 12 characters or fewer
            become ``"***"``. When ``False`` (default), return the full token.

    Usage::

        behavior = OBOTokenBehavior(provider=provider)
        # after agent.execute() has run:
        print(behavior.debug_token())           # full token
        print(behavior.debug_token(mask=True))  # eyJhbGci...a1b2

    Returns:
        The full or masked token, or ``"<not set>"`` before any exchange.
    """
    token = self._last_token
    if not token:
        return "<not set>"
    if not mask:
        return token
    # An 8 + 4 character preview would expose the entire token when it is this short.
    return "***" if len(token) <= 12 else f"{token[:8]}...{token[-4:]}"
Return the OBO token for debugging purposes.
Intended for local development only. Do NOT call this in production code or log the return value to any persistent log sink — bearer tokens grant full user-level access to downstream APIs.
Args:
mask: If True, returns a masked version showing only the first
8 and last 4 characters (e.g. "eyJhbGci...a1b2").
If False (default), returns the full token.
Usage::
behavior = OBOTokenBehavior(provider=provider)
# after agent.execute() has run:
print(behavior.debug_token()) # full token
print(behavior.debug_token(mask=True)) # eyJhbGci...a1b2
Returns:
The OBO token string (full or masked), or "<not set>" if the
token has not been exchanged yet.
async def on_error(
    self, context: BehaviorContext, error: Exception
) -> Optional[Any]:
    """Re-raise token-exchange failures immediately; let anything else propagate normally.

    OBOTokenError and ValueError (e.g. a missing ``user_assertion``) are
    treated as non-retryable, so they are raised right away rather than
    handed back to behaviors such as RetryBehavior.
    """
    if not isinstance(error, (OBOTokenError, ValueError)):
        return None
    raise error
Called when the agent raises an exception.
Return a fallback value to suppress the error, or re-raise / return
None to propagate it.
class OBOTokenProvider(ABC):
    """Interface for identity providers that perform On-Behalf-Of token exchange.

    Implement :meth:`exchange` to support another identity provider
    (PingFederate, Auth0, ...) without modifying :class:`OBOTokenBehavior`.
    """

    @abstractmethod
    async def exchange(self, user_assertion: str) -> str:
        """Trade *user_assertion* for a downstream OBO access token.

        Args:
            user_assertion: The user's existing access token.

        Returns:
            A downstream access token scoped to the configured resource.

        Raises:
            OBOTokenError: If the exchange fails for any reason.
        """
        ...

    @property
    def provider_name(self) -> str:
        """Label identifying this provider; the concrete class name."""
        return type(self).__name__
Abstract base for OBO token exchange providers.
Subclass this to support additional identity providers (PingFederate,
Auth0, etc.) without modifying OBOTokenBehavior.
@abstractmethod
async def exchange(self, user_assertion: str) -> str:
    """Trade a user's access token for a downstream OBO token.

    Args:
        user_assertion: The user's existing access token.

    Returns:
        A downstream access token scoped to the configured resource.

    Raises:
        OBOTokenError: If the exchange fails for any reason.
    """
    ...
Exchange a user access token for a downstream OBO token.
Args: user_assertion: The user's existing access token.
Returns: A downstream access token scoped to the configured resource.
Raises: OBOTokenError: If the exchange fails for any reason.
class EntraOBOProvider(OBOTokenProvider):
    """Microsoft Entra ID On-Behalf-Of token exchange using MSAL.

    Requires the ``msal`` package (``pip install msal``). MSAL's API is
    synchronous, so the exchange runs in the event loop's default executor
    instead of blocking the loop.

    The agent's service principal (``client_id`` / ``client_secret``) must
    have been granted permission to perform OBO exchanges for the downstream
    resource scopes.

    Args:
        tenant_id: Azure AD tenant ID.
        client_id: Application (client) ID of the agent's service principal.
        client_secret: Client secret for the agent's service principal.
        scopes: List of downstream resource scopes.
            Example: ``["api://servicenow/.default"]``
    """

    def __init__(
        self,
        tenant_id: str,
        client_id: str,
        client_secret: str,
        scopes: List[str],
    ) -> None:
        self._tenant_id = tenant_id
        self._client_id = client_id
        self._client_secret = client_secret
        self._scopes = scopes
        # sha256(user assertion) -> access token.
        # NOTE(review): entries are never expired or evicted.
        self._cache: Dict[str, str] = {}

    async def exchange(self, user_assertion: str) -> str:
        """Exchange *user_assertion* for a downstream token via MSAL OBO.

        Successful results are cached per assertion (keyed by its SHA-256
        digest) for the lifetime of this provider instance.

        Args:
            user_assertion: The user's existing access token.

        Returns:
            The ``access_token`` returned by MSAL.

        Raises:
            OBOTokenError: If ``msal`` is not installed, MSAL raises, or the
                MSAL result carries no ``access_token``.
        """
        cache_key = hashlib.sha256(user_assertion.encode()).hexdigest()
        if cache_key in self._cache:
            return self._cache[cache_key]

        try:
            import msal  # lazy import — not required unless Entra is used
        except ImportError as exc:
            raise OBOTokenError(
                "msal package is required for EntraOBOProvider. "
                "Install it with: pip install msal",
                self.provider_name,
            ) from exc

        import asyncio  # local import keeps this method self-contained

        authority = f"https://login.microsoftonline.com/{self._tenant_id}"

        def _acquire() -> dict:
            # MSAL calls are synchronous and the OBO grant is a network
            # round-trip, so this runs in a worker thread.
            app = msal.ConfidentialClientApplication(
                client_id=self._client_id,
                client_credential=self._client_secret,
                authority=authority,
            )
            return app.acquire_token_on_behalf_of(
                user_assertion=user_assertion,
                scopes=self._scopes,
            )

        try:
            result = await asyncio.get_running_loop().run_in_executor(None, _acquire)
        except Exception as exc:
            # Honour the OBOTokenProvider contract: every failure surfaces as
            # OBOTokenError (MSAL may raise on network or configuration errors).
            raise OBOTokenError(
                f"MSAL error during token exchange: {exc}",
                self.provider_name,
            ) from exc

        if "access_token" not in result:
            error_desc = result.get("error_description") or result.get("error", "unknown")
            raise OBOTokenError(
                f"Token exchange failed: {error_desc}",
                self.provider_name,
            )

        token: str = result["access_token"]
        self._cache[cache_key] = token
        return token
Microsoft Entra ID On-Behalf-Of token exchange using MSAL.
Requires the msal package (pip install msal).
The agent's service principal (client_id / client_secret) must
have been granted permission to perform OBO exchanges for the downstream
resource scopes.
Args:
tenant_id: Azure AD tenant ID.
client_id: Application (client) ID of the agent's service principal.
client_secret: Client secret for the agent's service principal.
scopes: List of downstream resource scopes.
Example: ["api://servicenow/.default"]
def __init__(
    self,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    scopes: List[str],
) -> None:
    """Remember the Entra tenant, the service-principal credentials and the target scopes."""
    # Identity of the confidential client used for the OBO grant.
    self._tenant_id, self._client_id, self._client_secret = tenant_id, client_id, client_secret
    self._scopes = scopes
    # sha256(user assertion) -> downstream access token (start empty).
    self._cache: Dict[str, str] = {}
async def exchange(self, user_assertion: str) -> str:
    """Exchange *user_assertion* for a downstream token via MSAL OBO.

    Successful results are cached per assertion (keyed by its SHA-256 digest)
    for the lifetime of this provider instance.

    Args:
        user_assertion: The user's existing access token.

    Returns:
        The ``access_token`` returned by MSAL.

    Raises:
        OBOTokenError: If ``msal`` is not installed, MSAL raises, or the MSAL
            result carries no ``access_token``.
    """
    cache_key = hashlib.sha256(user_assertion.encode()).hexdigest()
    if cache_key in self._cache:
        return self._cache[cache_key]

    try:
        import msal  # lazy import — not required unless Entra is used
    except ImportError as exc:
        raise OBOTokenError(
            "msal package is required for EntraOBOProvider. "
            "Install it with: pip install msal",
            self.provider_name,
        ) from exc

    import asyncio  # local import keeps this method self-contained

    authority = f"https://login.microsoftonline.com/{self._tenant_id}"

    def _acquire() -> dict:
        # MSAL calls are synchronous and the OBO grant is a network
        # round-trip, so this runs in a worker thread.
        app = msal.ConfidentialClientApplication(
            client_id=self._client_id,
            client_credential=self._client_secret,
            authority=authority,
        )
        return app.acquire_token_on_behalf_of(
            user_assertion=user_assertion,
            scopes=self._scopes,
        )

    try:
        result = await asyncio.get_running_loop().run_in_executor(None, _acquire)
    except Exception as exc:
        # Honour the OBOTokenProvider contract: every failure surfaces as
        # OBOTokenError (MSAL may raise on network or configuration errors).
        raise OBOTokenError(
            f"MSAL error during token exchange: {exc}",
            self.provider_name,
        ) from exc

    if "access_token" not in result:
        error_desc = result.get("error_description") or result.get("error", "unknown")
        raise OBOTokenError(
            f"Token exchange failed: {error_desc}",
            self.provider_name,
        )

    token: str = result["access_token"]
    self._cache[cache_key] = token
    return token
Exchange a user access token for a downstream OBO token.
Args: user_assertion: The user's existing access token.
Returns: A downstream access token scoped to the configured resource.
Raises: OBOTokenError: If the exchange fails for any reason.
class OktaOBOProvider(OBOTokenProvider):
    """Okta token exchange using RFC 8693 (token-exchange grant type).

    Requires the ``httpx`` package (already a dependency of this library).

    The Okta authorization server must be configured to allow the
    ``urn:ietf:params:oauth:grant-type:token-exchange`` grant type for the
    agent's client application.

    Args:
        domain: Okta domain, e.g. ``"company.okta.com"``. A leading scheme
            (e.g. ``"https://"``) and trailing slashes are stripped.
        client_id: Okta application client ID for the agent's service principal.
        client_secret: Okta application client secret.
        scopes: List of downstream scopes, e.g. ``["servicenow.write"]``.
        authorization_server_id: Okta authorization server ID, used in the
            token URL ``https://{domain}/oauth2/{authorization_server_id}/v1/token``.
            Defaults to ``"default"`` (Okta's default custom authorization
            server, not the org authorization server).
    """

    def __init__(
        self,
        domain: str,
        client_id: str,
        client_secret: str,
        scopes: List[str],
        authorization_server_id: str = "default",
    ) -> None:
        # The token URL is built as f"https://{domain}/...", so a scheme left
        # in place would produce "https://https://...".
        self._domain = domain.split("://", 1)[-1].rstrip("/")
        self._client_id = client_id
        self._client_secret = client_secret
        self._scopes = scopes
        self._authorization_server_id = authorization_server_id
        # sha256(user assertion) -> access token.
        # NOTE(review): entries are never expired or evicted.
        self._cache: Dict[str, str] = {}

    async def exchange(self, user_assertion: str) -> str:
        """Exchange *user_assertion* for a downstream token via RFC 8693.

        Successful results are cached per assertion (keyed by its SHA-256
        digest) for the lifetime of this provider instance.

        Args:
            user_assertion: The user's existing access token.

        Returns:
            The ``access_token`` returned by the Okta token endpoint.

        Raises:
            OBOTokenError: On an HTTP transport error, a non-200 status, a
                body that is not a JSON object, or a body without
                ``access_token``.
        """
        cache_key = hashlib.sha256(user_assertion.encode()).hexdigest()
        if cache_key in self._cache:
            return self._cache[cache_key]

        import httpx  # already a library dependency

        token_url = (
            f"https://{self._domain}/oauth2/{self._authorization_server_id}/v1/token"
        )

        payload = {
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "subject_token": user_assertion,
            "subject_token_type": "urn:ietf:params:oauth:token-type:access_token",
            "scope": " ".join(self._scopes),
        }

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    token_url,
                    data=payload,
                    auth=(self._client_id, self._client_secret),
                    headers={"Accept": "application/json"},
                )
        except httpx.HTTPError as exc:
            raise OBOTokenError(
                f"HTTP error during token exchange: {exc}",
                self.provider_name,
            ) from exc

        if response.status_code != 200:
            try:
                body: Any = response.json()
                error_desc = body.get("error_description") or body.get("error", response.text)
            except Exception:
                # Non-JSON or non-object error bodies: fall back to the raw text.
                error_desc = response.text
            raise OBOTokenError(
                f"Token exchange returned HTTP {response.status_code}: {error_desc}",
                self.provider_name,
            )

        try:
            data: Any = response.json()
        except ValueError as exc:
            # A 200 with a non-JSON body (e.g. an HTML proxy page) must still
            # surface as OBOTokenError, per the provider contract.
            raise OBOTokenError(
                "Token exchange returned a non-JSON response body",
                self.provider_name,
            ) from exc

        if not isinstance(data, dict) or "access_token" not in data:
            # Report field names only: the body may hold other credentials
            # (refresh/ID tokens) that must not leak into errors or logs.
            received = sorted(data) if isinstance(data, dict) else type(data).__name__
            raise OBOTokenError(
                f"Token exchange response missing access_token (received: {received})",
                self.provider_name,
            )

        token: str = data["access_token"]
        self._cache[cache_key] = token
        return token
Okta token exchange using RFC 8693 (token-exchange grant type).
Requires the httpx package (already a dependency of this library).
The Okta authorization server must be configured to allow the
urn:ietf:params:oauth:grant-type:token-exchange grant type for the
agent's client application.
Args:
domain: Okta domain, e.g. "company.okta.com".
client_id: Okta application client ID for the agent's service principal.
client_secret: Okta application client secret.
scopes: List of downstream scopes, e.g. ["servicenow.write"].
authorization_server_id: Okta authorization server ID.
Defaults to "default" (Okta's default custom authorization server, not the org authorization server).
def __init__(
    self,
    domain: str,
    client_id: str,
    client_secret: str,
    scopes: List[str],
    authorization_server_id: str = "default",
) -> None:
    """Store the Okta endpoint settings, client credentials and target scopes.

    Args:
        domain: Okta domain, e.g. ``"company.okta.com"``. A leading scheme
            (e.g. ``"https://"``) and trailing slashes are stripped.
        client_id: Okta application client ID.
        client_secret: Okta application client secret.
        scopes: Downstream scopes to request.
        authorization_server_id: Okta authorization server ID used in the
            token URL path. Defaults to ``"default"``.
    """
    # The token URL is built as f"https://{domain}/...", so a scheme left in
    # place would produce "https://https://...".
    self._domain = domain.split("://", 1)[-1].rstrip("/")
    self._client_id = client_id
    self._client_secret = client_secret
    self._scopes = scopes
    self._authorization_server_id = authorization_server_id
    # sha256(user assertion) -> downstream access token.
    self._cache: Dict[str, str] = {}
async def exchange(self, user_assertion: str) -> str:
    """Exchange *user_assertion* for a downstream token via RFC 8693.

    Successful results are cached per assertion (keyed by its SHA-256 digest)
    for the lifetime of this provider instance.

    Args:
        user_assertion: The user's existing access token.

    Returns:
        The ``access_token`` returned by the Okta token endpoint.

    Raises:
        OBOTokenError: On an HTTP transport error, a non-200 status, a body
            that is not a JSON object, or a body without ``access_token``.
    """
    cache_key = hashlib.sha256(user_assertion.encode()).hexdigest()
    if cache_key in self._cache:
        return self._cache[cache_key]

    import httpx  # already a library dependency

    token_url = (
        f"https://{self._domain}/oauth2/{self._authorization_server_id}/v1/token"
    )

    payload = {
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "subject_token": user_assertion,
        "subject_token_type": "urn:ietf:params:oauth:token-type:access_token",
        "scope": " ".join(self._scopes),
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_url,
                data=payload,
                auth=(self._client_id, self._client_secret),
                headers={"Accept": "application/json"},
            )
    except httpx.HTTPError as exc:
        raise OBOTokenError(
            f"HTTP error during token exchange: {exc}",
            self.provider_name,
        ) from exc

    if response.status_code != 200:
        try:
            body: Any = response.json()
            error_desc = body.get("error_description") or body.get("error", response.text)
        except Exception:
            # Non-JSON or non-object error bodies: fall back to the raw text.
            error_desc = response.text
        raise OBOTokenError(
            f"Token exchange returned HTTP {response.status_code}: {error_desc}",
            self.provider_name,
        )

    try:
        data: Any = response.json()
    except ValueError as exc:
        # A 200 with a non-JSON body (e.g. an HTML proxy page) must still
        # surface as OBOTokenError, per the provider contract.
        raise OBOTokenError(
            "Token exchange returned a non-JSON response body",
            self.provider_name,
        ) from exc

    if not isinstance(data, dict) or "access_token" not in data:
        # Report field names only: the body may hold other credentials
        # (refresh/ID tokens) that must not leak into errors or logs.
        received = sorted(data) if isinstance(data, dict) else type(data).__name__
        raise OBOTokenError(
            f"Token exchange response missing access_token (received: {received})",
            self.provider_name,
        )

    token: str = data["access_token"]
    self._cache[cache_key] = token
    return token
Exchange a user access token for a downstream OBO token.
Args: user_assertion: The user's existing access token.
Returns: A downstream access token scoped to the configured resource.
Raises: OBOTokenError: If the exchange fails for any reason.
class OBOTokenError(Exception):
    """Signals that an OBO token exchange failed.

    The exception message is prefixed with the provider name in square
    brackets, e.g. ``"[EntraOBOProvider] Token exchange failed: ..."``.

    Args:
        message: Human-readable description of the failure.
        provider_name: Name of the provider that failed (e.g. ``"EntraOBOProvider"``).
    """

    def __init__(self, message: str, provider_name: str) -> None:
        self.provider_name = provider_name
        text = f"[{provider_name}] {message}"
        super().__init__(text)
Raised when an OBO token exchange fails.
Args:
message: Human-readable description of the failure.
provider_name: Name of the provider that failed (e.g. "EntraOBOProvider").