gmf_forge_ai_orchestration.workflows

Workflow engines — DAG, state machine, and event-driven.

 1"""Workflow engines — DAG, state machine, and event-driven."""
 2
 3from gmf_forge_ai_orchestration.workflows.base import (
 4    BaseWorkflow,
 5    WorkflowEdge,
 6    WorkflowNode,
 7    WorkflowResult,
 8)
 9from gmf_forge_ai_orchestration.workflows.dag_workflow import DAGWorkflow
10from gmf_forge_ai_orchestration.workflows.state_machine_workflow import StateMachineWorkflow
11from gmf_forge_ai_orchestration.workflows.event_driven_workflow import (
12    EventDrivenWorkflow,
13    WorkflowEvent,
14)
15
16__all__ = [
17    "BaseWorkflow",
18    "WorkflowEdge",
19    "WorkflowNode",
20    "WorkflowResult",
21    "DAGWorkflow",
22    "StateMachineWorkflow",
23    "EventDrivenWorkflow",
24    "WorkflowEvent",
25]
class BaseWorkflow(abc.ABC):
class BaseWorkflow(ABC):
    """
    Abstract base class for all workflow engines.

    Stores the observability collaborators shared by every engine and offers
    small helpers for logging when a node starts and finishes.

    Args:
        logger: Optional :class:`BasicLogger`. When ``None``, a logger named
            ``gmf_forge_ai.workflow.<ClassName>`` is created.
        metrics: Optional :class:`BasicMetricsCollector`. When ``None``, no
            metrics are recorded.
        tracer: Optional :class:`TracingProvider`. Falls back to ``get_tracer()``.
    """

    def __init__(
        self,
        logger: Optional[BasicLogger] = None,
        metrics: Optional[BasicMetricsCollector] = None,
        tracer: Optional[TracingProvider] = None,
    ) -> None:
        # Compare against None instead of relying on truthiness: a collaborator
        # that happens to be falsy (e.g. it defines __len__ or __bool__) must
        # still be honoured when the caller passed it explicitly.
        if logger is None:
            logger = BasicLogger(f"gmf_forge_ai.workflow.{self.__class__.__name__}")
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer if tracer is not None else get_tracer()

    @abstractmethod
    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
        """Execute the workflow and return the aggregated result."""

    def _log_node_start(self, node_id: str) -> None:
        """Log that the node identified by ``node_id`` has started."""
        self._logger.info("Workflow node started", node_id=node_id)

    def _log_node_end(self, node_id: str, success: bool) -> None:
        """Log that a node finished and, when metrics are configured, count it."""
        self._logger.info("Workflow node finished", node_id=node_id, success=success)
        if self._metrics is not None:
            self._metrics.increment("workflow.nodes_executed", node_id=node_id, success=str(success))

Abstract base class for all workflow engines.

Args: logger: Optional BasicLogger. metrics: Optional BasicMetricsCollector. tracer: Optional TracingProvider. Falls back to get_tracer().

@abstractmethod
async def run( self, initial_input: Dict[str, Any]) -> WorkflowResult:
72    @abstractmethod
73    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
74        """Execute the workflow and return the aggregated result."""

Execute the workflow and return the aggregated result.

@dataclass
class WorkflowEdge:
@dataclass
class WorkflowEdge:
    """Directed connection from one workflow node to another."""

    source: str
    """``node_id`` of the node the edge starts from."""
    target: str
    """``node_id`` of the node the edge points to."""
    condition: Optional[Callable[["AgentResult"], bool]] = None
    """Optional guard callable; the edge is meant to be followed only when it returns True."""

A directed edge between two workflow nodes.

WorkflowEdge( source: str, target: str, condition: Optional[Callable[[gmf_forge_ai_orchestration.AgentResult], bool]] = None)
source: str

node_id of the source node.

target: str

node_id of the target node.

condition: Optional[Callable[[gmf_forge_ai_orchestration.AgentResult], bool]] = None

Optional guard — edge is only traversed if this returns True.

@dataclass
class WorkflowNode:
@dataclass
class WorkflowNode:
    """One executable step (an agent plus its I/O wiring) in a workflow graph."""

    node_id: str
    agent: "BaseAgent"
    inputs_map: Dict[str, str] = field(default_factory=dict)
    """Node input key -> key looked up in the initial_input dict or in earlier node outputs."""
    outputs_map: Dict[str, str] = field(default_factory=dict)
    """Node output key -> name under which it is stored in the accumulated output dict."""
    metadata: Dict[str, Any] = field(default_factory=dict)

A single node in a workflow graph.

WorkflowNode( node_id: str, agent: gmf_forge_ai_orchestration.BaseAgent, inputs_map: Dict[str, str] = <factory>, outputs_map: Dict[str, str] = <factory>, metadata: Dict[str, Any] = <factory>)
node_id: str
inputs_map: Dict[str, str]

Maps node input keys to keys from the initial_input dict or prior node outputs.

outputs_map: Dict[str, str]

Renames node output keys before storing in the accumulated output dict.

metadata: Dict[str, Any]
@dataclass
class WorkflowResult:
@dataclass
class WorkflowResult:
    """Aggregated outcome of a single workflow run."""

    outputs: Dict[str, "AgentResult"] = field(default_factory=dict)
    """Per-node results, keyed by ``node_id``."""
    final_output: str = ""
    success: bool = True
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

The aggregated result of a workflow run.

WorkflowResult( outputs: Dict[str, gmf_forge_ai_orchestration.AgentResult] = <factory>, final_output: str = '', success: bool = True, error: Optional[str] = None, metadata: Dict[str, Any] = <factory>)

Keyed by node_id.

final_output: str = ''
success: bool = True
error: Optional[str] = None
metadata: Dict[str, Any]
class DAGWorkflow(gmf_forge_ai_orchestration.workflows.BaseWorkflow):
 18class DAGWorkflow(BaseWorkflow):
 19    """
 20    Executes workflow nodes in topological order using ``networkx``.
 21
 22    Nodes with no mutual dependency are executed concurrently with
 23    ``asyncio.gather``.
 24
 25    Args:
 26        nodes: List of :class:`WorkflowNode` objects.
 27        edges: List of :class:`WorkflowEdge` objects.
 28        logger, metrics, tracer: Observability (optional).
 29
 30    Example::
 31
 32        wf = DAGWorkflow(
 33            nodes=[search_node, summarise_node],
 34            edges=[WorkflowEdge(source="search", target="summarise")],
 35        )
 36        result = await wf.run({"query": "AI trends 2026"})
 37    """
 38
 39    def __init__(
 40        self,
 41        nodes: Optional[List[WorkflowNode]] = None,
 42        edges: Optional[List[WorkflowEdge]] = None,
 43        logger: Optional[BasicLogger] = None,
 44        metrics: Optional[BasicMetricsCollector] = None,
 45        tracer: Optional[TracingProvider] = None,
 46    ) -> None:
 47        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
 48        self._nodes: Dict[str, WorkflowNode] = {n.node_id: n for n in (nodes or [])}
 49        self._edges: List[WorkflowEdge] = edges or []
 50
 51    def add_node(self, node: WorkflowNode) -> None:
 52        self._nodes[node.node_id] = node
 53
 54    def add_edge(self, edge: WorkflowEdge) -> None:
 55        self._edges.append(edge)
 56
 57    # ------------------------------------------------------------------
 58    # Topological sort (no networkx dependency — implemented manually)
 59    # ------------------------------------------------------------------
 60
 61    def _build_adjacency(self) -> tuple[Dict[str, List[str]], Dict[str, int]]:
 62        """Returns (successors dict, in-degree dict)."""
 63        successors: Dict[str, List[str]] = {nid: [] for nid in self._nodes}
 64        in_degree: Dict[str, int] = {nid: 0 for nid in self._nodes}
 65        for edge in self._edges:
 66            if edge.source in successors and edge.target in self._nodes:
 67                successors[edge.source].append(edge.target)
 68                in_degree[edge.target] += 1
 69        return successors, in_degree
 70
 71    def _topological_levels(self) -> List[List[str]]:
 72        """
 73        Returns node IDs grouped into levels. All nodes in a level can run in
 74        parallel because none depend on each other.
 75        """
 76        successors, in_degree = self._build_adjacency()
 77        levels: List[List[str]] = []
 78        ready: List[str] = [nid for nid, deg in in_degree.items() if deg == 0]
 79
 80        while ready:
 81            levels.append(list(ready))
 82            next_ready: List[str] = []
 83            for nid in ready:
 84                for successor in successors[nid]:
 85                    in_degree[successor] -= 1
 86                    if in_degree[successor] == 0:
 87                        next_ready.append(successor)
 88            ready = next_ready
 89
 90        if sum(len(lvl) for lvl in levels) < len(self._nodes):
 91            raise ValueError("DAGWorkflow: cycle detected in the workflow graph.")
 92
 93        return levels
 94
 95    # ------------------------------------------------------------------
 96    # Input resolution
 97    # ------------------------------------------------------------------
 98
 99    def _resolve_inputs(
100        self,
101        node: WorkflowNode,
102        initial_input: Dict[str, Any],
103        all_outputs: Dict[str, Any],
104    ) -> Dict[str, Any]:
105        """Build the input dict for a node from initial_input and prior outputs."""
106        if not node.inputs_map:
107            return {**initial_input, **all_outputs}
108
109        resolved: Dict[str, Any] = {}
110        merged = {**initial_input, **all_outputs}
111        for local_key, source_key in node.inputs_map.items():
112            resolved[local_key] = merged.get(source_key)
113        return resolved
114
115    # ------------------------------------------------------------------
116    # BaseWorkflow implementation
117    # ------------------------------------------------------------------
118
119    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
120        self._logger.info("DAGWorkflow started", node_count=len(self._nodes))
121
122        if self._metrics:
123            self._metrics.increment("workflow.runs", workflow="dag")
124
125        with self._tracer.trace(
126            "dag_workflow.run",
127            input=str(initial_input),
128            metadata={"node_count": len(self._nodes)},
129        ) as trace:
130            try:
131                levels = self._topological_levels()
132                all_outputs: Dict[str, Any] = {}
133                node_results: Dict[str, Any] = {}
134
135                for level in levels:
136                    # Run all nodes in this level concurrently
137                    async def _run_node(node_id: str) -> tuple[str, Any]:
138                        node = self._nodes[node_id]
139                        task_input = self._resolve_inputs(node, initial_input, all_outputs)
140                        task_text = task_input.get("task") or task_input.get("query", str(task_input))
141                        self._log_node_start(node_id)
142
143                        with trace.span(f"node.{node_id}", input=task_text) as span:
144                            try:
145                                result = await node.agent.execute(task_text, context=task_input)
146                                self._log_node_end(node_id, success=result.success)
147                                span.set_output(result.output)
148                                return node_id, result
149                            except Exception as exc:
150                                self._log_node_end(node_id, success=False)
151                                span.set_error(exc)
152                                raise
153
154                    results = await asyncio.gather(
155                        *[_run_node(nid) for nid in level], return_exceptions=False
156                    )
157                    for nid, result in results:
158                        node_results[nid] = result
159                        # Apply outputs_map renaming
160                        node = self._nodes[nid]
161                        if node.outputs_map:
162                            for out_key, mapped_key in node.outputs_map.items():
163                                all_outputs[mapped_key] = getattr(result, out_key, None)
164                        else:
165                            all_outputs[nid] = result.output
166
167                # Final output = last level's last node output
168                last_level = levels[-1] if levels else []
169                final_output = (
170                    node_results[last_level[-1]].output if last_level else ""
171                )
172                wf_result = WorkflowResult(
173                    outputs=node_results,
174                    final_output=final_output,
175                    success=True,
176                )
177                self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
178                trace.set_output(final_output)
179                return wf_result
180
181            except Exception as exc:
182                self._logger.error("DAGWorkflow failed", error=str(exc))
183                trace.set_error(exc)
184                return WorkflowResult(success=False, error=str(exc))

Executes workflow nodes in topological order. Dependency levels are computed by a built-in in-degree (Kahn-style) sort; networkx is not required.

Nodes with no mutual dependency are executed concurrently with asyncio.gather.

Args: nodes: List of WorkflowNode objects. edges: List of WorkflowEdge objects. logger, metrics, tracer: Observability (optional).

Example::

wf = DAGWorkflow(
    nodes=[search_node, summarise_node],
    edges=[WorkflowEdge(source="search", target="summarise")],
)
result = await wf.run({"query": "AI trends 2026"})
DAGWorkflow( nodes: Optional[List[WorkflowNode]] = None, edges: Optional[List[WorkflowEdge]] = None, logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None, metrics: Optional[gmf_forge_ai_shared_core.observability.BasicMetricsCollector] = None, tracer: Optional[gmf_forge_ai_shared_core.observability.TracingProvider] = None)
39    def __init__(
40        self,
41        nodes: Optional[List[WorkflowNode]] = None,
42        edges: Optional[List[WorkflowEdge]] = None,
43        logger: Optional[BasicLogger] = None,
44        metrics: Optional[BasicMetricsCollector] = None,
45        tracer: Optional[TracingProvider] = None,
46    ) -> None:
47        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
48        self._nodes: Dict[str, WorkflowNode] = {n.node_id: n for n in (nodes or [])}
49        self._edges: List[WorkflowEdge] = edges or []
def add_node( self, node: WorkflowNode) -> None:
51    def add_node(self, node: WorkflowNode) -> None:
52        self._nodes[node.node_id] = node
def add_edge( self, edge: WorkflowEdge) -> None:
54    def add_edge(self, edge: WorkflowEdge) -> None:
55        self._edges.append(edge)
async def run( self, initial_input: Dict[str, Any]) -> WorkflowResult:
119    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
120        self._logger.info("DAGWorkflow started", node_count=len(self._nodes))
121
122        if self._metrics:
123            self._metrics.increment("workflow.runs", workflow="dag")
124
125        with self._tracer.trace(
126            "dag_workflow.run",
127            input=str(initial_input),
128            metadata={"node_count": len(self._nodes)},
129        ) as trace:
130            try:
131                levels = self._topological_levels()
132                all_outputs: Dict[str, Any] = {}
133                node_results: Dict[str, Any] = {}
134
135                for level in levels:
136                    # Run all nodes in this level concurrently
137                    async def _run_node(node_id: str) -> tuple[str, Any]:
138                        node = self._nodes[node_id]
139                        task_input = self._resolve_inputs(node, initial_input, all_outputs)
140                        task_text = task_input.get("task") or task_input.get("query", str(task_input))
141                        self._log_node_start(node_id)
142
143                        with trace.span(f"node.{node_id}", input=task_text) as span:
144                            try:
145                                result = await node.agent.execute(task_text, context=task_input)
146                                self._log_node_end(node_id, success=result.success)
147                                span.set_output(result.output)
148                                return node_id, result
149                            except Exception as exc:
150                                self._log_node_end(node_id, success=False)
151                                span.set_error(exc)
152                                raise
153
154                    results = await asyncio.gather(
155                        *[_run_node(nid) for nid in level], return_exceptions=False
156                    )
157                    for nid, result in results:
158                        node_results[nid] = result
159                        # Apply outputs_map renaming
160                        node = self._nodes[nid]
161                        if node.outputs_map:
162                            for out_key, mapped_key in node.outputs_map.items():
163                                all_outputs[mapped_key] = getattr(result, out_key, None)
164                        else:
165                            all_outputs[nid] = result.output
166
167                # Final output = last level's last node output
168                last_level = levels[-1] if levels else []
169                final_output = (
170                    node_results[last_level[-1]].output if last_level else ""
171                )
172                wf_result = WorkflowResult(
173                    outputs=node_results,
174                    final_output=final_output,
175                    success=True,
176                )
177                self._logger.info("DAGWorkflow completed", nodes_run=len(node_results))
178                trace.set_output(final_output)
179                return wf_result
180
181            except Exception as exc:
182                self._logger.error("DAGWorkflow failed", error=str(exc))
183                trace.set_error(exc)
184                return WorkflowResult(success=False, error=str(exc))

Execute the workflow and return the aggregated result.

class StateMachineWorkflow(gmf_forge_ai_orchestration.workflows.BaseWorkflow):
 18class StateMachineWorkflow(BaseWorkflow):
 19    """
 20    Drives execution through an explicit state/transition table.
 21
 22    Each state maps to an :class:`BaseAgent`. On completion, transition
 23    rules are evaluated in order; the first matching rule determines the
 24    next state. ``None`` as a guard means "always transition here".
 25
 26    Args:
 27        states: Mapping of state name → agent to run in that state.
 28        transitions: Mapping of state name → list of ``(guard_fn | None, next_state)``
 29            tuples evaluated in order after the state's agent finishes.
 30        initial_state: Name of the starting state.
 31        terminal_states: Set of state names that end the workflow.
 32        state_store: Optional store to persist current state across restarts.
 33        max_transitions: Safety cap on total state transitions (default: 50).
 34        logger, metrics, tracer: Observability (optional).
 35
 36    Example::
 37
 38        wf = StateMachineWorkflow(
 39            states={"gather": search_agent, "summarise": summarise_agent},
 40            transitions={
 41                "gather": [(None, "summarise")],
 42                "summarise": [(None, "END")],
 43            },
 44            initial_state="gather",
 45            terminal_states={"END"},
 46        )
 47        result = await wf.run({"query": "AI news"})
 48    """
 49
 50    _TERMINAL = "__terminal__"
 51
 52    def __init__(
 53        self,
 54        states: Optional[Dict[str, BaseAgent]] = None,
 55        transitions: Optional[Dict[str, List[TransitionRule]]] = None,
 56        initial_state: str = "",
 57        terminal_states: Optional[set] = None,
 58        state_store: Optional[BaseStateStore] = None,
 59        max_transitions: int = 50,
 60        logger: Optional[BasicLogger] = None,
 61        metrics: Optional[BasicMetricsCollector] = None,
 62        tracer: Optional[TracingProvider] = None,
 63    ) -> None:
 64        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
 65        self._states: Dict[str, BaseAgent] = states or {}
 66        self._transitions: Dict[str, List[TransitionRule]] = transitions or {}
 67        self.initial_state = initial_state
 68        self.terminal_states: set = terminal_states or set()
 69        self._state_store = state_store
 70        self.max_transitions = max_transitions
 71
 72    def _next_state(self, current: str, result: AgentResult) -> Optional[str]:
 73        rules = self._transitions.get(current, [])
 74        for guard, next_state in rules:
 75            if guard is None or guard(result):
 76                return next_state
 77        return None
 78
 79    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
 80        self._logger.info(
 81            "StateMachineWorkflow started", initial_state=self.initial_state
 82        )
 83
 84        current_state = self.initial_state
 85        # Resume from persisted state if available
 86        if self._state_store:
 87            persisted = await self._state_store.get("__sm_current_state__")
 88            if persisted:
 89                current_state = persisted
 90
 91        node_results: Dict[str, AgentResult] = {}
 92        context = dict(initial_input)
 93        final_output = ""
 94
 95        with self._tracer.trace(
 96            "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
 97        ) as trace:
 98            for _ in range(self.max_transitions):
 99                if current_state in self.terminal_states or current_state is None:
100                    break
101
102                agent = self._states.get(current_state)
103                if agent is None:
104                    self._logger.error(
105                        "No agent for state", state=current_state
106                    )
107                    break
108
109                self._log_node_start(current_state)
110                with trace.span(f"state.{current_state}", input=str(context)) as span:
111                    try:
112                        task = context.get("task") or context.get("query", str(context))
113                        result = await agent.execute(task, context=context)
114                        self._log_node_end(current_state, success=result.success)
115                        node_results[current_state] = result
116                        final_output = result.output
117                        context["last_output"] = result.output
118                        span.set_output(result.output)
119
120                        if self._state_store:
121                            next_s = self._next_state(current_state, result)
122                            await self._state_store.set("__sm_current_state__", next_s or "__terminal__")
123
124                        if self._metrics:
125                            self._metrics.increment(
126                                "workflow.state_transitions",
127                                from_state=current_state,
128                            )
129
130                        current_state = self._next_state(current_state, result) or self._TERMINAL
131
132                    except Exception as exc:
133                        self._log_node_end(current_state, success=False)
134                        span.set_error(exc)
135                        self._logger.error(
136                            "StateMachineWorkflow state error",
137                            state=current_state,
138                            error=str(exc),
139                        )
140                        trace.set_error(exc)
141                        return WorkflowResult(
142                            outputs=node_results, success=False, error=str(exc)
143                        )
144
145            trace.set_output(final_output)
146        return WorkflowResult(outputs=node_results, final_output=final_output, success=True)

Drives execution through an explicit state/transition table.

Each state maps to a BaseAgent. On completion, transition rules are evaluated in order; the first matching rule determines the next state. None as a guard means "always transition here".

Args: states: Mapping of state name → agent to run in that state. transitions: Mapping of state name → list of (guard_fn | None, next_state) tuples evaluated in order after the state's agent finishes. initial_state: Name of the starting state. terminal_states: Set of state names that end the workflow. state_store: Optional store to persist current state across restarts. max_transitions: Safety cap on total state transitions (default: 50). logger, metrics, tracer: Observability (optional).

Example::

wf = StateMachineWorkflow(
    states={"gather": search_agent, "summarise": summarise_agent},
    transitions={
        "gather": [(None, "summarise")],
        "summarise": [(None, "END")],
    },
    initial_state="gather",
    terminal_states={"END"},
)
result = await wf.run({"query": "AI news"})
StateMachineWorkflow( states: Optional[Dict[str, gmf_forge_ai_orchestration.BaseAgent]] = None, transitions: Optional[Dict[str, List[Tuple[Optional[Callable[[gmf_forge_ai_orchestration.AgentResult], bool]], str]]]] = None, initial_state: str = '', terminal_states: Optional[set] = None, state_store: Optional[gmf_forge_ai_orchestration.BaseStateStore] = None, max_transitions: int = 50, logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None, metrics: Optional[gmf_forge_ai_shared_core.observability.BasicMetricsCollector] = None, tracer: Optional[gmf_forge_ai_shared_core.observability.TracingProvider] = None)
52    def __init__(
53        self,
54        states: Optional[Dict[str, BaseAgent]] = None,
55        transitions: Optional[Dict[str, List[TransitionRule]]] = None,
56        initial_state: str = "",
57        terminal_states: Optional[set] = None,
58        state_store: Optional[BaseStateStore] = None,
59        max_transitions: int = 50,
60        logger: Optional[BasicLogger] = None,
61        metrics: Optional[BasicMetricsCollector] = None,
62        tracer: Optional[TracingProvider] = None,
63    ) -> None:
64        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
65        self._states: Dict[str, BaseAgent] = states or {}
66        self._transitions: Dict[str, List[TransitionRule]] = transitions or {}
67        self.initial_state = initial_state
68        self.terminal_states: set = terminal_states or set()
69        self._state_store = state_store
70        self.max_transitions = max_transitions
initial_state
terminal_states: set
max_transitions
async def run( self, initial_input: Dict[str, Any]) -> WorkflowResult:
 79    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
 80        self._logger.info(
 81            "StateMachineWorkflow started", initial_state=self.initial_state
 82        )
 83
 84        current_state = self.initial_state
 85        # Resume from persisted state if available
 86        if self._state_store:
 87            persisted = await self._state_store.get("__sm_current_state__")
 88            if persisted:
 89                current_state = persisted
 90
 91        node_results: Dict[str, AgentResult] = {}
 92        context = dict(initial_input)
 93        final_output = ""
 94
 95        with self._tracer.trace(
 96            "state_machine.run", input=str(initial_input), metadata={"initial": current_state}
 97        ) as trace:
 98            for _ in range(self.max_transitions):
 99                if current_state in self.terminal_states or current_state is None:
100                    break
101
102                agent = self._states.get(current_state)
103                if agent is None:
104                    self._logger.error(
105                        "No agent for state", state=current_state
106                    )
107                    break
108
109                self._log_node_start(current_state)
110                with trace.span(f"state.{current_state}", input=str(context)) as span:
111                    try:
112                        task = context.get("task") or context.get("query", str(context))
113                        result = await agent.execute(task, context=context)
114                        self._log_node_end(current_state, success=result.success)
115                        node_results[current_state] = result
116                        final_output = result.output
117                        context["last_output"] = result.output
118                        span.set_output(result.output)
119
120                        if self._state_store:
121                            next_s = self._next_state(current_state, result)
122                            await self._state_store.set("__sm_current_state__", next_s or "__terminal__")
123
124                        if self._metrics:
125                            self._metrics.increment(
126                                "workflow.state_transitions",
127                                from_state=current_state,
128                            )
129
130                        current_state = self._next_state(current_state, result) or self._TERMINAL
131
132                    except Exception as exc:
133                        self._log_node_end(current_state, success=False)
134                        span.set_error(exc)
135                        self._logger.error(
136                            "StateMachineWorkflow state error",
137                            state=current_state,
138                            error=str(exc),
139                        )
140                        trace.set_error(exc)
141                        return WorkflowResult(
142                            outputs=node_results, success=False, error=str(exc)
143                        )
144
145            trace.set_output(final_output)
146        return WorkflowResult(outputs=node_results, final_output=final_output, success=True)

Execute the workflow and return the aggregated result.

class EventDrivenWorkflow(gmf_forge_ai_orchestration.workflows.BaseWorkflow):
 28class EventDrivenWorkflow(BaseWorkflow):
 29    """
 30    Processes events through registered async handler functions.
 31
 32    Handlers are registered per event name. When an event is emitted, all
 33    registered handlers for that name run concurrently. If a handler returns
 34    a new :class:`WorkflowEvent`, it is automatically enqueued for further
 35    processing (event chaining).
 36
 37    The workflow terminates when:
 38    - The event queue is empty and no handlers are running, OR
 39    - ``max_events`` total events have been processed, OR
 40    - A ``terminal_event`` is emitted.
 41
 42    Args:
 43        max_events: Safety cap on total events processed (default: 100).
 44        terminal_event: Event name that signals completion (optional).
 45        logger, metrics, tracer: Observability (optional).
 46
 47    Example::
 48
 49        wf = EventDrivenWorkflow(terminal_event="done")
 50        wf.on("query_received", handle_query)
 51        wf.on("results_ready", handle_results)
 52        await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
 53        result = await wf.run({})
 54    """
 55
 56    def __init__(
 57        self,
 58        max_events: int = 100,
 59        terminal_event: Optional[str] = None,
 60        logger: Optional[BasicLogger] = None,
 61        metrics: Optional[BasicMetricsCollector] = None,
 62        tracer: Optional[TracingProvider] = None,
 63    ) -> None:
 64        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
 65        self.max_events = max_events
 66        self.terminal_event = terminal_event
 67        self._handlers: Dict[str, List[EventHandler]] = {}
 68        self._queue: asyncio.Queue = asyncio.Queue()
 69        self._processed: List[WorkflowEvent] = []
 70
 71    def on(self, event_name: str, handler: EventHandler) -> None:
 72        """Register an async handler for a given event name."""
 73        self._handlers.setdefault(event_name, []).append(handler)
 74
 75    async def emit(self, event: WorkflowEvent) -> None:
 76        """Enqueue an event for processing."""
 77        await self._queue.put(event)
 78
 79    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
 80        self._logger.info("EventDrivenWorkflow started")
 81
 82        # Optionally seed with an initial event from input
 83        if "event" in initial_input:
 84            event_data = initial_input["event"]
 85            if isinstance(event_data, dict):
 86                await self.emit(WorkflowEvent(**event_data))
 87            elif isinstance(event_data, WorkflowEvent):
 88                await self.emit(event_data)
 89
 90        total_processed = 0
 91        final_output = ""
 92
 93        with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
 94            while total_processed < self.max_events:
 95                try:
 96                    event: WorkflowEvent = self._queue.get_nowait()
 97                except asyncio.QueueEmpty:
 98                    break
 99
100                self._logger.info(
101                    "Processing event",
102                    event_name=event.name,
103                    source=event.source,
104                )
105                self._processed.append(event)
106                total_processed += 1
107
108                if self._metrics:
109                    self._metrics.increment("workflow.events_processed", event=event.name)
110
111                with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
112                    handlers = self._handlers.get(event.name, [])
113                    if not handlers:
114                        self._logger.warning("No handler for event", event_name=event.name)
115                        span.set_output("no_handler")
116                        self._queue.task_done()
117                        continue
118
119                    async def _run_handler(h: EventHandler, ev: WorkflowEvent) -> Optional[WorkflowEvent]:
120                        return await h(ev)
121
122                    follow_ups = await asyncio.gather(
123                        *[_run_handler(h, event) for h in handlers],
124                        return_exceptions=False,
125                    )
126
127                    for follow_up in follow_ups:
128                        if isinstance(follow_up, WorkflowEvent):
129                            follow_up.source = event.name
130                            await self.emit(follow_up)
131                            if self.terminal_event and follow_up.name == self.terminal_event:
132                                final_output = str(follow_up.payload.get("output", ""))
133                                span.set_output(final_output)
134                                self._queue.task_done()
135                                trace.set_output(final_output)
136                                return WorkflowResult(
137                                    final_output=final_output,
138                                    success=True,
139                                    metadata={"events_processed": total_processed},
140                                )
141
142                    span.set_output(f"processed {len(handlers)} handler(s)")
143                    self._queue.task_done()
144
145            self._logger.info(
146                "EventDrivenWorkflow completed", events_processed=total_processed
147            )
148            trace.set_output(final_output)
149        return WorkflowResult(
150            final_output=final_output,
151            success=True,
152            metadata={"events_processed": total_processed},
153        )

Processes events through registered async handler functions.

Handlers are registered per event name. When an event is emitted, all registered handlers for that name run concurrently. If a handler returns a new WorkflowEvent, it is automatically enqueued for further processing (event chaining).

The workflow terminates when:

  • The event queue is empty and no handlers are running, OR
  • max_events total events have been processed, OR
  • A handler returns a follow-up event named terminal_event.

Args: max_events: Safety cap on total events processed (default: 100). terminal_event: Event name that signals completion (optional). logger, metrics, tracer: Observability (optional).

Example::

wf = EventDrivenWorkflow(terminal_event="done")
wf.on("query_received", handle_query)
wf.on("results_ready", handle_results)
await wf.emit(WorkflowEvent(name="query_received", payload={"q": "AI"}))
result = await wf.run({})
EventDrivenWorkflow( max_events: int = 100, terminal_event: Optional[str] = None, logger: Optional[gmf_forge_ai_shared_core.observability.BasicLogger] = None, metrics: Optional[gmf_forge_ai_shared_core.observability.BasicMetricsCollector] = None, tracer: Optional[gmf_forge_ai_shared_core.observability.TracingProvider] = None)
56    def __init__(
57        self,
58        max_events: int = 100,
59        terminal_event: Optional[str] = None,
60        logger: Optional[BasicLogger] = None,
61        metrics: Optional[BasicMetricsCollector] = None,
62        tracer: Optional[TracingProvider] = None,
63    ) -> None:
64        super().__init__(logger=logger, metrics=metrics, tracer=tracer)
65        self.max_events = max_events
66        self.terminal_event = terminal_event
67        self._handlers: Dict[str, List[EventHandler]] = {}
68        self._queue: asyncio.Queue = asyncio.Queue()
69        self._processed: List[WorkflowEvent] = []
max_events
terminal_event
def on( self, event_name: str, handler: Callable[[WorkflowEvent], Awaitable[Optional[WorkflowEvent]]]) -> None:
71    def on(self, event_name: str, handler: EventHandler) -> None:
72        """Register an async handler for a given event name."""
73        self._handlers.setdefault(event_name, []).append(handler)

Register an async handler for a given event name.

async def emit( self, event: WorkflowEvent) -> None:
75    async def emit(self, event: WorkflowEvent) -> None:
76        """Enqueue an event for processing."""
77        await self._queue.put(event)

Enqueue an event for processing.

async def run( self, initial_input: Dict[str, Any]) -> WorkflowResult:
 79    async def run(self, initial_input: Dict[str, Any]) -> WorkflowResult:
 80        self._logger.info("EventDrivenWorkflow started")
 81
 82        # Optionally seed with an initial event from input
 83        if "event" in initial_input:
 84            event_data = initial_input["event"]
 85            if isinstance(event_data, dict):
 86                await self.emit(WorkflowEvent(**event_data))
 87            elif isinstance(event_data, WorkflowEvent):
 88                await self.emit(event_data)
 89
 90        total_processed = 0
 91        final_output = ""
 92
 93        with self._tracer.trace("event_workflow.run", input=str(initial_input)) as trace:
 94            while total_processed < self.max_events:
 95                try:
 96                    event: WorkflowEvent = self._queue.get_nowait()
 97                except asyncio.QueueEmpty:
 98                    break
 99
100                self._logger.info(
101                    "Processing event",
102                    event_name=event.name,
103                    source=event.source,
104                )
105                self._processed.append(event)
106                total_processed += 1
107
108                if self._metrics:
109                    self._metrics.increment("workflow.events_processed", event=event.name)
110
111                with trace.span(f"event.{event.name}", input=str(event.payload)) as span:
112                    handlers = self._handlers.get(event.name, [])
113                    if not handlers:
114                        self._logger.warning("No handler for event", event_name=event.name)
115                        span.set_output("no_handler")
116                        self._queue.task_done()
117                        continue
118
119                    async def _run_handler(h: EventHandler, ev: WorkflowEvent) -> Optional[WorkflowEvent]:
120                        return await h(ev)
121
122                    follow_ups = await asyncio.gather(
123                        *[_run_handler(h, event) for h in handlers],
124                        return_exceptions=False,
125                    )
126
127                    for follow_up in follow_ups:
128                        if isinstance(follow_up, WorkflowEvent):
129                            follow_up.source = event.name
130                            await self.emit(follow_up)
131                            if self.terminal_event and follow_up.name == self.terminal_event:
132                                final_output = str(follow_up.payload.get("output", ""))
133                                span.set_output(final_output)
134                                self._queue.task_done()
135                                trace.set_output(final_output)
136                                return WorkflowResult(
137                                    final_output=final_output,
138                                    success=True,
139                                    metadata={"events_processed": total_processed},
140                                )
141
142                    span.set_output(f"processed {len(handlers)} handler(s)")
143                    self._queue.task_done()
144
145            self._logger.info(
146                "EventDrivenWorkflow completed", events_processed=total_processed
147            )
148            trace.set_output(final_output)
149        return WorkflowResult(
150            final_output=final_output,
151            success=True,
152            metadata={"events_processed": total_processed},
153        )

Execute the workflow and return the aggregated result.

@dataclass
class WorkflowEvent:
@dataclass
class WorkflowEvent:
    """A named message, with optional payload, placed on the workflow event queue."""

    # Event name.
    name: str

    # Free-form data carried with the event; each instance gets its own
    # empty dict when none is supplied.
    payload: Dict[str, Any] = field(default_factory=dict)

    # Label for where the event came from; "external" unless overridden.
    source: str = "external"

An event emitted into the workflow event bus.

WorkflowEvent( name: str, payload: Dict[str, Any] = <factory>, source: str = 'external')
name: str
payload: Dict[str, Any]
source: str = 'external'