gmf_forge_ai_orchestration.workflows
Workflow engines — DAG, state machine, and event-driven.
1"""Workflow engines — DAG, state machine, and event-driven.""" 2 3from gmf_forge_ai_orchestration.workflows.base import ( 4 BaseWorkflow, 5 WorkflowEdge, 6 WorkflowNode, 7 WorkflowResult, 8) 9from gmf_forge_ai_orchestration.workflows.dag_workflow import DAGWorkflow 10from gmf_forge_ai_orchestration.workflows.state_machine_workflow import StateMachineWorkflow 11from gmf_forge_ai_orchestration.workflows.event_driven_workflow import ( 12 EventDrivenWorkflow, 13 WorkflowEvent, 14) 15 16__all__ = [ 17 "BaseWorkflow", 18 "WorkflowEdge", 19 "WorkflowNode", 20 "WorkflowResult", 21 "DAGWorkflow", 22 "StateMachineWorkflow", 23 "EventDrivenWorkflow", 24 "WorkflowEvent", 25]
class BaseWorkflow(ABC):
    """
    Common foundation shared by every workflow engine.

    Holds the observability collaborators and offers small helpers that
    concrete engines call when a node starts or finishes.

    Args:
        logger: Optional :class:`BasicLogger`. When omitted, a logger named
            ``gmf_forge_ai.workflow.<ClassName>`` is created.
        metrics: Optional :class:`BasicMetricsCollector`.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        # Falsy arguments (not only ``None``) trigger the fallbacks below.
        if not logger:
            logger = BasicLogger(f"gmf_forge_ai.workflow.{self.__class__.__name__}")
        if not tracer:
            tracer = get_tracer()

        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer

    @abstractmethod
    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Execute the workflow and return the aggregated result."""

    def _log_node_start(self, node_id: str) -> None:
        """Log that the node identified by ``node_id`` has started."""
        self._logger.info("Workflow node started", node_id=node_id)

    def _log_node_end(self, node_id: str, success: bool) -> None:
        """Log that a node finished and, if metrics are configured, count it."""
        self._logger.info("Workflow node finished", node_id=node_id, success=success)

        collector = self._metrics
        if not collector:
            return
        # The success flag is passed to the collector as a string label.
        collector.increment(
            "workflow.nodes_executed", node_id=node_id, success=str(success)
        )
Abstract base class for all workflow engines.
Args:
logger: Optional BasicLogger.
metrics: Optional BasicMetricsCollector.
tracer: Optional TracingProvider. Falls back to get_tracer().
@dataclass
class WorkflowEdge:
    """Directed connection from one workflow node to another."""

    #: node_id of the node the edge starts from.
    source: str
    #: node_id of the node the edge points to.
    target: str
    #: Optional guard; the edge is meant to be traversed only if it returns True.
    condition: Optional[Callable[["AgentResult"], bool]] = None
A directed edge between two workflow nodes.
@dataclass
class WorkflowNode:
    """One executable step of a workflow graph, bound to an agent."""

    #: Identifier of this node within the workflow.
    node_id: str
    #: Agent attached to this node.
    agent: "BaseAgent"
    #: Maps node input keys to keys from the initial_input dict or prior node outputs.
    inputs_map: Dict[str, str] = field(default_factory=dict)
    #: Renames node output keys before storing in the accumulated output dict.
    outputs_map: Dict[str, str] = field(default_factory=dict)
    #: Free-form extra data attached to the node.
    metadata: Dict[str, Any] = field(default_factory=dict)
A single node in a workflow graph.
@dataclass
class WorkflowResult:
    """Aggregated outcome of a single workflow run."""

    #: Per-node results, keyed by node_id.
    outputs: Dict[str, "AgentResult"] = field(default_factory=dict)
    #: Output the workflow reports as its overall answer.
    final_output: str = ""
    #: False when the run ended with an error.
    success: bool = True
    #: Error description for a failed run, otherwise None.
    error: Optional[str] = None
    #: Free-form extra data about the run.
    metadata: Dict[str, Any] = field(default_factory=dict)
The aggregated result of a workflow run.
class DAGWorkflow(BaseWorkflow):
    """
    Executes workflow nodes in topological order.

    The order is computed by an in-degree based level sort implemented in
    this class (no ``networkx`` dependency). Nodes with no mutual dependency
    are executed concurrently with ``asyncio.gather``.

    Note:
        ``WorkflowEdge.condition`` is not consulted by this engine; every
        edge is treated as an unconditional dependency.

    Args:
        nodes: List of :class:`WorkflowNode` objects.
        edges: List of :class:`WorkflowEdge` objects.
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = DAGWorkflow(
            nodes=[search_node, summarise_node],
            edges=[WorkflowEdge(source="search", target="summarise")],
        )
        result = await wf.run({"query": "AI trends 2026"})
    """

    def __init__(
        self,
        nodes: Optional[List[WorkflowNode]] = None,
        edges: Optional[List[WorkflowEdge]] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._nodes: Dict[str, WorkflowNode] = {n.node_id: n for n in (nodes or [])}
        # Copy the list so add_edge() never mutates the caller's argument.
        self._edges: List[WorkflowEdge] = list(edges or [])

    def add_node(self, node: WorkflowNode) -> None:
        """Register ``node``, replacing any node with the same ``node_id``."""
        self._nodes[node.node_id] = node

    def add_edge(self, edge: WorkflowEdge) -> None:
        """Append ``edge`` to the dependency list."""
        self._edges.append(edge)

    # ------------------------------------------------------------------
    # Topological sort (no networkx dependency — implemented manually)
    # ------------------------------------------------------------------

    def _build_adjacency(self) -> tuple[Dict[str, List[str]], Dict[str, int]]:
        """
        Returns (successors dict, in-degree dict).

        Edges whose source or target is not a registered node are ignored.
        """
        successors: Dict[str, List[str]] = {nid: [] for nid in self._nodes}
        in_degree: Dict[str, int] = {nid: 0 for nid in self._nodes}
        for edge in self._edges:
            if edge.source in successors and edge.target in self._nodes:
                successors[edge.source].append(edge.target)
                in_degree[edge.target] += 1
        return successors, in_degree

    def _topological_levels(self) -> List[List[str]]:
        """
        Returns node IDs grouped into levels. All nodes in a level can run in
        parallel because none depend on each other.

        Raises:
            ValueError: If some nodes never become ready, i.e. the graph
                contains a cycle.
        """
        successors, in_degree = self._build_adjacency()
        levels: List[List[str]] = []
        ready: List[str] = [nid for nid, deg in in_degree.items() if deg == 0]

        while ready:
            levels.append(list(ready))
            next_ready: List[str] = []
            for nid in ready:
                for successor in successors[nid]:
                    in_degree[successor] -= 1
                    if in_degree[successor] == 0:
                        next_ready.append(successor)
            ready = next_ready

        if sum(len(lvl) for lvl in levels) < len(self._nodes):
            raise ValueError("DAGWorkflow: cycle detected in the workflow graph.")

        return levels

    # ------------------------------------------------------------------
    # Input resolution
    # ------------------------------------------------------------------

    def _resolve_inputs(
        self,
        node: WorkflowNode,
        initial_input: Dict[str, Any],
        all_outputs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Build the input dict for a node from initial_input and prior outputs.

        Prior outputs take precedence over ``initial_input`` on key clashes.
        Without an ``inputs_map`` the node receives the whole merged dict;
        otherwise only the mapped keys (missing sources resolve to ``None``).
        """
        merged = {**initial_input, **all_outputs}
        if not node.inputs_map:
            return merged
        return {
            local_key: merged.get(source_key)
            for local_key, source_key in node.inputs_map.items()
        }

    # ------------------------------------------------------------------
    # BaseWorkflow implementation
    # ------------------------------------------------------------------

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """
        Execute every node level by level and aggregate the results.

        Returns:
            A :class:`WorkflowResult`. On failure ``success`` is ``False``,
            ``error`` holds the exception text and ``outputs`` contains the
            results of the levels that completed before the failure.
        """
        self._logger.info("DAGWorkflow started", node_count=len(self._nodes))

        if self._metrics:
            self._metrics.increment("workflow.runs", workflow="dag")

        # Created before the try block so partial results survive a failure.
        all_outputs: Dict[str, Any] = {}
        node_results: Dict[str, Any] = {}

        with self._tracer.trace(
            "dag_workflow.run",
            input=str(initial_input),
            metadata={"node_count": len(self._nodes)},
        ) as trace:

            async def _run_node(node_id: str) -> tuple[str, Any]:
                node = self._nodes[node_id]
                task_input = self._resolve_inputs(node, initial_input, all_outputs)
                task_text = task_input.get("task") or task_input.get("query", str(task_input))
                self._log_node_start(node_id)

                with trace.span(f"node.{node_id}", input=task_text) as span:
                    try:
                        result = await node.agent.execute(task_text, context=task_input)
                        self._log_node_end(node_id, success=result.success)
                        span.set_output(result.output)
                        return node_id, result
                    except Exception as exc:
                        self._log_node_end(node_id, success=False)
                        span.set_error(exc)
                        raise

            try:
                levels = self._topological_levels()

                for level in levels:
                    # Run all nodes in this level concurrently.
                    tasks = [asyncio.ensure_future(_run_node(nid)) for nid in level]
                    try:
                        results = await asyncio.gather(*tasks)
                    except Exception:
                        # Do not leave sibling nodes running after a failure.
                        for task in tasks:
                            task.cancel()
                        await asyncio.gather(*tasks, return_exceptions=True)
                        raise

                    for nid, result in results:
                        node_results[nid] = result
                        # Apply outputs_map renaming
                        node = self._nodes[nid]
                        if node.outputs_map:
                            for out_key, mapped_key in node.outputs_map.items():
                                all_outputs[mapped_key] = getattr(result, out_key, None)
                        else:
                            all_outputs[nid] = result.output

                # Final output = last level's last node output
                last_level = levels[-1] if levels else []
                final_output = (
                    node_results[last_level[-1]].output if last_level else ""
                )
                wf_result = WorkflowResult(
                    outputs=node_results,
                    final_output=final_output,
                    success=True,
                )
                self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
                trace.set_output(final_output)
                return wf_result

            except Exception as exc:
                self._logger.error("DAGWorkflow failed", error=str(exc))
                trace.set_error(exc)
                return WorkflowResult(
                    outputs=node_results, success=False, error=str(exc)
                )
Executes workflow nodes in topological order, computed by a built-in in-degree level sort (no networkx dependency).
Nodes with no mutual dependency are executed concurrently with
asyncio.gather.
Args:
nodes: List of WorkflowNode objects.
edges: List of WorkflowEdge objects.
logger, metrics, tracer: Observability (optional).
Example::
wf = DAGWorkflow(
nodes=[search_node, summarise_node],
edges=[WorkflowEdge(source="search", target="summarise")],
)
result = await wf.run({"query": "AI trends 2026"})
def __init__(
    self,
    nodes: Optional[List[WorkflowNode]] = None,
    edges: Optional[List[WorkflowEdge]] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """
    Store the graph definition and the observability collaborators.

    Args:
        nodes: Nodes to register, indexed internally by ``node_id``.
        edges: Dependencies between nodes.
        logger, metrics, tracer: Forwarded to the base class.
    """
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)
    self._nodes: Dict[str, WorkflowNode] = {n.node_id: n for n in (nodes or [])}
    # Copy the list so later appends to self._edges never mutate the
    # caller's argument.
    self._edges: List[WorkflowEdge] = list(edges or [])
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """
    Execute every node level by level and aggregate the results.

    Returns:
        A :class:`WorkflowResult`. On failure ``success`` is ``False``,
        ``error`` holds the exception text and ``outputs`` contains the
        results of the levels that completed before the failure.
    """
    self._logger.info("DAGWorkflow started", node_count=len(self._nodes))

    if self._metrics:
        self._metrics.increment("workflow.runs", workflow="dag")

    # Created before the try block so partial results survive a failure.
    all_outputs: Dict[str, Any] = {}
    node_results: Dict[str, Any] = {}

    with self._tracer.trace(
        "dag_workflow.run",
        input=str(initial_input),
        metadata={"node_count": len(self._nodes)},
    ) as trace:

        async def _run_node(node_id: str) -> tuple[str, Any]:
            node = self._nodes[node_id]
            task_input = self._resolve_inputs(node, initial_input, all_outputs)
            task_text = task_input.get("task") or task_input.get("query", str(task_input))
            self._log_node_start(node_id)

            with trace.span(f"node.{node_id}", input=task_text) as span:
                try:
                    result = await node.agent.execute(task_text, context=task_input)
                    self._log_node_end(node_id, success=result.success)
                    span.set_output(result.output)
                    return node_id, result
                except Exception as exc:
                    self._log_node_end(node_id, success=False)
                    span.set_error(exc)
                    raise

        try:
            levels = self._topological_levels()

            for level in levels:
                # Run all nodes in this level concurrently.
                tasks = [asyncio.ensure_future(_run_node(nid)) for nid in level]
                try:
                    results = await asyncio.gather(*tasks)
                except Exception:
                    # Do not leave sibling nodes running after a failure.
                    for task in tasks:
                        task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)
                    raise

                for nid, result in results:
                    node_results[nid] = result
                    # Apply outputs_map renaming
                    node = self._nodes[nid]
                    if node.outputs_map:
                        for out_key, mapped_key in node.outputs_map.items():
                            all_outputs[mapped_key] = getattr(result, out_key, None)
                    else:
                        all_outputs[nid] = result.output

            # Final output = last level's last node output
            last_level = levels[-1] if levels else []
            final_output = (
                node_results[last_level[-1]].output if last_level else ""
            )
            wf_result = WorkflowResult(
                outputs=node_results,
                final_output=final_output,
                success=True,
            )
            self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
            trace.set_output(final_output)
            return wf_result

        except Exception as exc:
            self._logger.error("DAGWorkflow failed", error=str(exc))
            trace.set_error(exc)
            return WorkflowResult(
                outputs=node_results, success=False, error=str(exc)
            )
Execute the workflow and return the aggregated result.
class StateMachineWorkflow(BaseWorkflow):
    """
    Drives execution through an explicit state/transition table.

    Each state maps to a :class:`BaseAgent`. On completion, transition
    rules are evaluated in order; the first matching rule determines the
    next state. ``None`` as a guard means "always transition here". When no
    rule matches, the workflow ends.

    Args:
        states: Mapping of state name → agent to run in that state.
        transitions: Mapping of state name → list of ``(guard_fn | None, next_state)``
            tuples evaluated in order after the state's agent finishes.
        initial_state: Name of the starting state.
        terminal_states: Set of state names that end the workflow.
        state_store: Optional store to persist current state across restarts.
        max_transitions: Safety cap on total state transitions (default: 50).
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = StateMachineWorkflow(
            states={"gather": search_agent, "summarise": summarise_agent},
            transitions={
                "gather": [(None, "summarise")],
                "summarise": [(None, "END")],
            },
            initial_state="gather",
            terminal_states={"END"},
        )
        result = await wf.run({"query": "AI news"})
    """

    # Sentinel state used (and persisted) when no transition rule matches.
    _TERMINAL = "__terminal__"

    def __init__(
        self,
        states: Optional[Dict[str, BaseAgent]] = None,
        transitions: Optional[Dict[str, List[TransitionRule]]] = None,
        initial_state: str = "",
        terminal_states: Optional[set] = None,
        state_store: Optional[BaseStateStore] = None,
        max_transitions: int = 50,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self._states: Dict[str, BaseAgent] = states or {}
        self._transitions: Dict[str, List[TransitionRule]] = transitions or {}
        self.initial_state = initial_state
        self.terminal_states: set = terminal_states or set()
        self._state_store = state_store
        self.max_transitions = max_transitions

    def _next_state(self, current: str, result: AgentResult) -> Optional[str]:
        """Return the target of the first matching rule for ``current``, or ``None``."""
        rules = self._transitions.get(current, [])
        for guard, next_state in rules:
            if guard is None or guard(result):
                return next_state
        return None

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """
        Run agents state by state until a terminal state is reached.

        The loop stops when the current state is terminal, when a state has
        no agent, or after ``max_transitions`` iterations. If a state store
        is configured, the run starts from the state it holds (when truthy)
        and the next state is persisted after every step.

        Returns:
            A :class:`WorkflowResult`; ``success`` is ``False`` only when an
            exception was raised while executing a state.
        """
        self._logger.info(
            "StateMachineWorkflow started", initial_state=self.initial_state
        )

        current_state = self.initial_state
        # Resume from persisted state if available
        if self._state_store:
            persisted = await self._state_store.get("__sm_current_state__")
            if persisted:
                current_state = persisted

        node_results: Dict[str, AgentResult] = {}
        context = dict(initial_input)
        final_output = ""

        def _is_terminal(state: Optional[str]) -> bool:
            # The sentinel counts as terminal so that running out of matching
            # rules is not reported as a missing agent.
            return (
                state is None
                or state == self._TERMINAL
                or state in self.terminal_states
            )

        with self._tracer.trace(
            "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
        ) as trace:
            for _ in range(self.max_transitions):
                if _is_terminal(current_state):
                    break

                agent = self._states.get(current_state)
                if agent is None:
                    self._logger.error(
                        "No agent for state", state=current_state
                    )
                    break

                self._log_node_start(current_state)
                with trace.span(f"state.{current_state}", input=str(context)) as span:
                    try:
                        task = context.get("task") or context.get("query", str(context))
                        result = await agent.execute(task, context=context)
                        self._log_node_end(current_state, success=result.success)
                        node_results[current_state] = result
                        final_output = result.output
                        context["last_output"] = result.output
                        span.set_output(result.output)

                        # Evaluate the guards exactly once so the persisted
                        # state always equals the state actually entered.
                        next_state = (
                            self._next_state(current_state, result) or self._TERMINAL
                        )

                        if self._state_store:
                            await self._state_store.set(
                                "__sm_current_state__", next_state
                            )

                        if self._metrics:
                            self._metrics.increment(
                                "workflow.state_transitions",
                                from_state=current_state,
                            )

                        current_state = next_state

                    except Exception as exc:
                        self._log_node_end(current_state, success=False)
                        span.set_error(exc)
                        self._logger.error(
                            "StateMachineWorkflow state error",
                            state=current_state,
                            error=str(exc),
                        )
                        trace.set_error(exc)
                        return WorkflowResult(
                            outputs=node_results, success=False, error=str(exc)
                        )
            else:
                # Loop ran out of iterations without hitting a break.
                if not _is_terminal(current_state):
                    self._logger.warning(
                        "StateMachineWorkflow stopped: max_transitions reached",
                        state=current_state,
                        max_transitions=self.max_transitions,
                    )

            trace.set_output(final_output)
            return WorkflowResult(outputs=node_results, final_output=final_output, success=True)
Drives execution through an explicit state/transition table.
Each state maps to a BaseAgent. On completion, transition
rules are evaluated in order; the first matching rule determines the
next state. None as a guard means "always transition here".
Args:
states: Mapping of state name → agent to run in that state.
transitions: Mapping of state name → list of (guard_fn | None, next_state)
tuples evaluated in order after the state's agent finishes.
initial_state: Name of the starting state.
terminal_states: Set of state names that end the workflow.
state_store: Optional store to persist current state across restarts.
max_transitions: Safety cap on total state transitions (default: 50).
logger, metrics, tracer: Observability (optional).
Example::
wf = StateMachineWorkflow(
states={"gather": search_agent, "summarise": summarise_agent},
transitions={
"gather": [(None, "summarise")],
"summarise": [(None, "END")],
},
initial_state="gather",
terminal_states={"END"},
)
result = await wf.run({"query": "AI news"})
def __init__(
    self,
    states: Optional[Dict[str, BaseAgent]] = None,
    transitions: Optional[Dict[str, List[TransitionRule]]] = None,
    initial_state: str = "",
    terminal_states: Optional[set] = None,
    state_store: Optional[BaseStateStore] = None,
    max_transitions: int = 50,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Store the state table, its limits and the observability collaborators."""
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)

    # Scalar configuration.
    self.initial_state = initial_state
    self.max_transitions = max_transitions
    self._state_store = state_store

    # Containers: a falsy argument is replaced by a fresh empty container.
    self._states: Dict[str, BaseAgent] = states if states else {}
    self._transitions: Dict[str, List[TransitionRule]] = (
        transitions if transitions else {}
    )
    self.terminal_states: set = terminal_states if terminal_states else set()
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """
    Run agents state by state until a terminal state is reached.

    The loop stops when the current state is terminal, when a state has
    no agent, or after ``max_transitions`` iterations. If a state store
    is configured, the run starts from the state it holds (when truthy)
    and the next state is persisted after every step.

    Returns:
        A :class:`WorkflowResult`; ``success`` is ``False`` only when an
        exception was raised while executing a state.
    """
    self._logger.info(
        "StateMachineWorkflow started", initial_state=self.initial_state
    )

    current_state = self.initial_state
    # Resume from persisted state if available
    if self._state_store:
        persisted = await self._state_store.get("__sm_current_state__")
        if persisted:
            current_state = persisted

    node_results: Dict[str, AgentResult] = {}
    context = dict(initial_input)
    final_output = ""

    def _is_terminal(state: Optional[str]) -> bool:
        # The sentinel counts as terminal so that running out of matching
        # rules is not reported as a missing agent.
        return (
            state is None
            or state == self._TERMINAL
            or state in self.terminal_states
        )

    with self._tracer.trace(
        "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
    ) as trace:
        for _ in range(self.max_transitions):
            if _is_terminal(current_state):
                break

            agent = self._states.get(current_state)
            if agent is None:
                self._logger.error(
                    "No agent for state", state=current_state
                )
                break

            self._log_node_start(current_state)
            with trace.span(f"state.{current_state}", input=str(context)) as span:
                try:
                    task = context.get("task") or context.get("query", str(context))
                    result = await agent.execute(task, context=context)
                    self._log_node_end(current_state, success=result.success)
                    node_results[current_state] = result
                    final_output = result.output
                    context["last_output"] = result.output
                    span.set_output(result.output)

                    # Evaluate the guards exactly once so the persisted
                    # state always equals the state actually entered.
                    next_state = (
                        self._next_state(current_state, result) or self._TERMINAL
                    )

                    if self._state_store:
                        await self._state_store.set(
                            "__sm_current_state__", next_state
                        )

                    if self._metrics:
                        self._metrics.increment(
                            "workflow.state_transitions",
                            from_state=current_state,
                        )

                    current_state = next_state

                except Exception as exc:
                    self._log_node_end(current_state, success=False)
                    span.set_error(exc)
                    self._logger.error(
                        "StateMachineWorkflow state error",
                        state=current_state,
                        error=str(exc),
                    )
                    trace.set_error(exc)
                    return WorkflowResult(
                        outputs=node_results, success=False, error=str(exc)
                    )
        else:
            # Loop ran out of iterations without hitting a break.
            if not _is_terminal(current_state):
                self._logger.warning(
                    "StateMachineWorkflow stopped: max_transitions reached",
                    state=current_state,
                    max_transitions=self.max_transitions,
                )

        trace.set_output(final_output)
        return WorkflowResult(outputs=node_results, final_output=final_output, success=True)
Execute the workflow and return the aggregated result.
class EventDrivenWorkflow(BaseWorkflow):
    """
    Processes events through registered async handler functions.

    Handlers are registered per event name. When an event is emitted, all
    registered handlers for that name run concurrently. If a handler returns
    a new :class:`WorkflowEvent`, it is automatically enqueued for further
    processing (event chaining).

    The workflow terminates when:
      - The event queue is empty, OR
      - ``max_events`` total events have been processed, OR
      - A ``terminal_event`` is returned by a handler or taken from the queue.
        Handlers registered for the terminal event are not invoked.

    Args:
        max_events: Safety cap on total events processed (default: 100).
        terminal_event: Event name that signals completion (optional).
        logger, metrics, tracer: Observability (optional).

    Example::

        wf = EventDrivenWorkflow(terminal_event="done")
        wf.on("query_received", handle_query)
        wf.on("results_ready", handle_results)
        await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
        result = await wf.run({})
    """

    def __init__(
        self,
        max_events: int = 100,
        terminal_event: Optional[str] = None,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
        self.max_events = max_events
        self.terminal_event = terminal_event
        self._handlers: Dict[str, List[EventHandler]] = {}
        self._queue: asyncio.Queue = asyncio.Queue()
        self._processed: List[WorkflowEvent] = []

    def on(self, event_name: str, handler: EventHandler) -> None:
        """Register an async handler for a given event name."""
        self._handlers.setdefault(event_name, []).append(handler)

    async def emit(self, event: WorkflowEvent) -> None:
        """Enqueue an event for processing."""
        await self._queue.put(event)

    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """
        Drain the event queue, dispatching each event to its handlers.

        If ``initial_input`` contains an ``"event"`` entry (a dict of
        :class:`WorkflowEvent` fields or a :class:`WorkflowEvent`), it is
        enqueued first.

        Returns:
            A :class:`WorkflowResult` whose ``metadata["events_processed"]``
            counts the events taken from the queue. ``final_output`` is the
            terminal event's ``payload["output"]`` when one was seen,
            otherwise an empty string.

        Raises:
            Exception: Whatever a handler raises is propagated to the caller
                after the remaining handlers of that event are cancelled.
        """
        self._logger.info("EventDrivenWorkflow started")

        # Optionally seed with an initial event from input
        if "event" in initial_input:
            event_data = initial_input["event"]
            if isinstance(event_data, dict):
                await self.emit(WorkflowEvent(**event_data))
            elif isinstance(event_data, WorkflowEvent):
                await self.emit(event_data)

        async def _invoke(handler: EventHandler, ev: WorkflowEvent) -> Optional[WorkflowEvent]:
            return await handler(ev)

        total_processed = 0
        final_output = ""

        with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
            while total_processed < self.max_events:
                try:
                    event: WorkflowEvent = self._queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

                # try/finally keeps the get/task_done accounting balanced on
                # every exit path, including handler failures.
                try:
                    self._logger.info(
                        "Processing event",
                        event_name=event.name,
                        source=event.source,
                    )
                    self._processed.append(event)
                    total_processed += 1

                    if self._metrics:
                        self._metrics.increment("workflow.events_processed", event=event.name)

                    with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
                        terminal: Optional[WorkflowEvent] = None
                        handlers: List[EventHandler] = []

                        if self.terminal_event and event.name == self.terminal_event:
                            # Terminal event emitted directly (or seeded).
                            terminal = event
                        else:
                            handlers = self._handlers.get(event.name, [])
                            if not handlers:
                                self._logger.warning("No handler for event", event_name=event.name)
                                span.set_output("no_handler")
                                continue

                            tasks = [
                                asyncio.ensure_future(_invoke(h, event)) for h in handlers
                            ]
                            try:
                                follow_ups = await asyncio.gather(*tasks)
                            except Exception as exc:
                                # Do not leave sibling handlers running.
                                for task in tasks:
                                    task.cancel()
                                await asyncio.gather(*tasks, return_exceptions=True)
                                self._logger.error(
                                    "EventDrivenWorkflow handler failed",
                                    event_name=event.name,
                                    error=str(exc),
                                )
                                span.set_error(exc)
                                trace.set_error(exc)
                                raise

                            for follow_up in follow_ups:
                                if not isinstance(follow_up, WorkflowEvent):
                                    continue
                                follow_up.source = event.name
                                if self.terminal_event and follow_up.name == self.terminal_event:
                                    # Not enqueued: it ends this run and must
                                    # not linger in the queue for the next one.
                                    terminal = follow_up
                                    break
                                await self.emit(follow_up)

                        if terminal is not None:
                            final_output = str(terminal.payload.get("output", ""))
                            span.set_output(final_output)
                            trace.set_output(final_output)
                            return WorkflowResult(
                                final_output=final_output,
                                success=True,
                                metadata={"events_processed": total_processed},
                            )

                        span.set_output(f"processed {len(handlers)} handler(s)")
                finally:
                    self._queue.task_done()

            if total_processed >= self.max_events and not self._queue.empty():
                self._logger.warning(
                    "EventDrivenWorkflow stopped: max_events reached",
                    max_events=self.max_events,
                    pending=self._queue.qsize(),
                )

            self._logger.info(
                "EventDrivenWorkflow completed", events_processed=total_processed
            )
            trace.set_output(final_output)
            return WorkflowResult(
                final_output=final_output,
                success=True,
                metadata={"events_processed": total_processed},
            )
Processes events through registered async handler functions.
Handlers are registered per event name. When an event is emitted, all
registered handlers for that name run concurrently. If a handler returns
a new WorkflowEvent, it is automatically enqueued for further
processing (event chaining).
The workflow terminates when:
- The event queue is empty and no handlers are running, OR
- max_events total events have been processed, OR
- A terminal_event is emitted.
Args:
max_events: Safety cap on total events processed (default: 100).
terminal_event: Event name that signals completion (optional).
logger, metrics, tracer: Observability (optional).
Example::
wf = EventDrivenWorkflow(terminal_event="done")
wf.on("query_received", handle_query)
wf.on("results_ready", handle_results)
await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
result = await wf.run({})
def __init__(
    self,
    max_events: int = 100,
    terminal_event: Optional[str] = None,
    logger: Optional[BasicLogger] = None,
    metrics: Optional[BasicMetricsCollector] = None,
    tracer: Optional[TracingProvider] = None,
) -> None:
    """Set the limits and create the empty handler table, queue and history."""
    super().__init__(logger=logger, metrics=metrics, tracer=tracer)

    # Runtime containers, all empty at construction time.
    self._processed: List[WorkflowEvent] = []
    self._queue: asyncio.Queue = asyncio.Queue()
    self._handlers: Dict[str, List[EventHandler]] = {}

    # Termination settings.
    self.terminal_event = terminal_event
    self.max_events = max_events
def on(self, event_name: str, handler: EventHandler) -> None:
    """Add ``handler`` to the list of handlers registered for ``event_name``."""
    registry = self._handlers
    if event_name not in registry:
        # First handler for this event name: start its list.
        registry[event_name] = []
    registry[event_name].append(handler)
Register an async handler for a given event name.
async def emit(self, event: WorkflowEvent) -> None:
    """Put ``event`` on the workflow's queue so a later ``run`` can process it."""
    pending = self._queue
    await pending.put(event)
Enqueue an event for processing.
async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
    """
    Drain the event queue, dispatching each event to its handlers.

    If ``initial_input`` contains an ``"event"`` entry (a dict of
    :class:`WorkflowEvent` fields or a :class:`WorkflowEvent`), it is
    enqueued first. The run ends when the queue is empty, when
    ``max_events`` events have been processed, or when an event named
    ``terminal_event`` is returned by a handler or taken from the queue
    (handlers registered for the terminal event are not invoked).

    Returns:
        A :class:`WorkflowResult` whose ``metadata["events_processed"]``
        counts the events taken from the queue. ``final_output`` is the
        terminal event's ``payload["output"]`` when one was seen,
        otherwise an empty string.

    Raises:
        Exception: Whatever a handler raises is propagated to the caller
            after the remaining handlers of that event are cancelled.
    """
    self._logger.info("EventDrivenWorkflow started")

    # Optionally seed with an initial event from input
    if "event" in initial_input:
        event_data = initial_input["event"]
        if isinstance(event_data, dict):
            await self.emit(WorkflowEvent(**event_data))
        elif isinstance(event_data, WorkflowEvent):
            await self.emit(event_data)

    async def _invoke(handler: EventHandler, ev: WorkflowEvent) -> Optional[WorkflowEvent]:
        return await handler(ev)

    total_processed = 0
    final_output = ""

    with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
        while total_processed < self.max_events:
            try:
                event: WorkflowEvent = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break

            # try/finally keeps the get/task_done accounting balanced on
            # every exit path, including handler failures.
            try:
                self._logger.info(
                    "Processing event",
                    event_name=event.name,
                    source=event.source,
                )
                self._processed.append(event)
                total_processed += 1

                if self._metrics:
                    self._metrics.increment("workflow.events_processed", event=event.name)

                with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
                    terminal: Optional[WorkflowEvent] = None
                    handlers: List[EventHandler] = []

                    if self.terminal_event and event.name == self.terminal_event:
                        # Terminal event emitted directly (or seeded).
                        terminal = event
                    else:
                        handlers = self._handlers.get(event.name, [])
                        if not handlers:
                            self._logger.warning("No handler for event", event_name=event.name)
                            span.set_output("no_handler")
                            continue

                        tasks = [
                            asyncio.ensure_future(_invoke(h, event)) for h in handlers
                        ]
                        try:
                            follow_ups = await asyncio.gather(*tasks)
                        except Exception as exc:
                            # Do not leave sibling handlers running.
                            for task in tasks:
                                task.cancel()
                            await asyncio.gather(*tasks, return_exceptions=True)
                            self._logger.error(
                                "EventDrivenWorkflow handler failed",
                                event_name=event.name,
                                error=str(exc),
                            )
                            span.set_error(exc)
                            trace.set_error(exc)
                            raise

                        for follow_up in follow_ups:
                            if not isinstance(follow_up, WorkflowEvent):
                                continue
                            follow_up.source = event.name
                            if self.terminal_event and follow_up.name == self.terminal_event:
                                # Not enqueued: it ends this run and must
                                # not linger in the queue for the next one.
                                terminal = follow_up
                                break
                            await self.emit(follow_up)

                    if terminal is not None:
                        final_output = str(terminal.payload.get("output", ""))
                        span.set_output(final_output)
                        trace.set_output(final_output)
                        return WorkflowResult(
                            final_output=final_output,
                            success=True,
                            metadata={"events_processed": total_processed},
                        )

                    span.set_output(f"processed {len(handlers)} handler(s)")
            finally:
                self._queue.task_done()

        if total_processed >= self.max_events and not self._queue.empty():
            self._logger.warning(
                "EventDrivenWorkflow stopped: max_events reached",
                max_events=self.max_events,
                pending=self._queue.qsize(),
            )

        self._logger.info(
            "EventDrivenWorkflow completed", events_processed=total_processed
        )
        trace.set_output(final_output)
        return WorkflowResult(
            final_output=final_output,
            success=True,
            metadata={"events_processed": total_processed},
        )
Execute the workflow and return the aggregated result.
@dataclass
class WorkflowEvent:
    """Message placed on the workflow event bus."""

    #: Event name; handlers are looked up by this value.
    name: str
    #: Data carried by the event.
    payload: Dict[str, Any] = field(default_factory=dict)
    #: Origin of the event; "external" unless set otherwise.
    source: str = "external"
An event emitted into the workflow event bus.