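1. BaseChannel
A channel is the pipe that carries data between nodes, somewhat like a channel in Go. BaseChannel is the abstract base class. The UML class diagram of the channel classes is shown below:

Unlike a Go channel, however, a channel here only defines how data is stored, not how it is read or written. Reading and writing are defined by ChannelRead and ChannelWrite.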
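1.1 Definition of BaseChannel
BaseChannel defines two attributes:
- key: the channel's unique identifier, used to locate the channel within the graph.
- typ: the channel's data type, used for type checking and serialization.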
```python
from abc import ABC
from typing import Generic, TypeVar

Value = TypeVar("Value")
Update = TypeVar("Update")
Checkpoint = TypeVar("Checkpoint")


class BaseChannel(Generic[Value, Update, Checkpoint], ABC):
    """Base class for all channels."""

    __slots__ = ("key", "typ")

    # full code omitted
```
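The generic definition gives a channel three type parameters:
- Value: the type of the value stored in the channel.
- Update: the type of update that can be written to the channel.
- Checkpoint: the type of the serializable snapshot of the channel's current state.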
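1.2 BaseChannel methods

| Member | Return type | Description |
| --- | --- | --- |
| ValueType | Any | Declares the type of the value stored in the channel (abstract property) |
| UpdateType | Any | Declares the type of update that can be written to the channel (abstract property) |
| copy() | Self | Returns a copy of the channel; by default built on checkpoint() and from_checkpoint() |
| checkpoint() | Checkpoint \| Any | Returns a serializable snapshot of the channel's current state; returns MISSING if the channel is empty |
| from_checkpoint(checkpoint) | Self | Creates a new instance of the same channel from a checkpoint (abstract method) |
| get() | Value | Reads the channel's current value; raises EmptyChannelError if nothing has been written yet (abstract method) |
| is_available() | bool | Reports whether the channel currently holds a value; by default implemented by calling get() |
| update(values) | bool | Applies a sequence of updates to the channel and returns whether its state changed (abstract method) |
| consume() | bool | Signals that a task has used the channel's value; a no-op by default; used to control consumption behavior |
| finish() | bool | Signals that the graph run is about to end so the channel can clean up or change state; a no-op by default |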
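To make the split between abstract and default members more concrete, here is a minimal, self-contained sketch of the pattern the table describes. It is not LangGraph's code: `MiniChannel`, `Box`, `MISSING` and `EmptyChannelError` are stand-ins defined only for this illustration, and the default bodies (for example `checkpoint()` falling back to `get()`, or `consume()` returning False) are assumptions based on the descriptions above.

```python
from abc import ABC, abstractmethod
from typing import Any, Sequence

MISSING = object()  # stand-in sentinel meaning "no value yet"


class EmptyChannelError(Exception):
    """Stand-in for the error raised when reading an empty channel."""


class MiniChannel(ABC):
    """Hypothetical base class mirroring the member table above."""

    @abstractmethod
    def get(self) -> Any: ...

    @abstractmethod
    def update(self, values: Sequence[Any]) -> bool: ...

    @abstractmethod
    def from_checkpoint(self, checkpoint: Any) -> "MiniChannel": ...

    def checkpoint(self) -> Any:
        # Assumed default: snapshot the current value, or MISSING when empty.
        try:
            return self.get()
        except EmptyChannelError:
            return MISSING

    def copy(self) -> "MiniChannel":
        # Default copy: round-trip through the checkpoint.
        return self.from_checkpoint(self.checkpoint())

    def is_available(self) -> bool:
        # Default: available exactly when get() does not raise.
        try:
            self.get()
            return True
        except EmptyChannelError:
            return False

    def consume(self) -> bool:
        return False  # assumed default: nothing changed

    def finish(self) -> bool:
        return False  # assumed default: nothing changed


class Box(MiniChannel):
    """Smallest concrete channel: keeps the last written value."""

    def __init__(self) -> None:
        self.value: Any = MISSING

    def get(self) -> Any:
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def update(self, values: Sequence[Any]) -> bool:
        if not values:
            return False
        self.value = values[-1]
        return True

    def from_checkpoint(self, checkpoint: Any) -> "Box":
        new = Box()
        if checkpoint is not MISSING:
            new.value = checkpoint
        return new


box = Box()
empty_before = box.is_available()  # nothing written yet
box.update([1, 2])
clone = box.copy()  # built from checkpoint() + from_checkpoint()
```

A concrete channel only has to supply the three abstract methods; copying, availability checks and the lifecycle hooks come for free unless a subclass needs different behavior.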
1.3 BaseChannel subclasses
BaseChannel has the following subclasses:
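
| Subclass | Core semantics | Characteristics / use cases |
| --- | --- | --- |
| AnyValue | Stores any value, with no aggregation and no validation | The most general channel; suited to the default case of storing a single value |
| BinaryOperatorAggregate | Folds multiple updates into a single value with a binary operator | Numeric sums, list concatenation and similar cases (such as x + y) |
| EphemeralValue | The value is visible only for one step and is cleared afterwards | Temporary intermediate values that should not outlive a step |
| LastValue | Always keeps the most recently written value (overwrite) | Latest-state storage, such as a final response or an overwritten variable |
| LastValueAfterFinish | Like LastValue, but visible only once the graph finishes | Hiding intermediate state and exposing only the final result |
| NamedBarrierValue | Readable only after all named writers have written | A barrier that synchronizes several nodes and coordinates dependencies |
| NamedBarrierValueAfterFinish | Like NamedBarrierValue, but exposed only after the graph has finished | Deferring the barrier signal until the end of the run |
| Topic | Publish/subscribe message channel that supports broadcasting | Multi-node triggering and multi-output scenarios (such as event streams) |
| UntrackedValue | Readable and writable, but never recorded in a state snapshot | Intermediate values used only during a run that should not be persisted |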
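By behavior, the channel subclasses fall into the following groups:

| Category | Subclasses | Trait |
| --- | --- | --- |
| Single-value channels | AnyValue, LastValue | Store the most recent write |
| Aggregating channels | BinaryOperatorAggregate, NamedBarrierValue | Merge multiple values or coordinate writers |
| Transient channels | EphemeralValue, UntrackedValue | Short-lived (EphemeralValue) or never checkpointed (UntrackedValue) |
| Final-value channels | LastValueAfterFinish, NamedBarrierValueAfterFinish | Exposed only after the run finishes |
| Event broadcast channels | Topic | Triggers multiple subscribers; suited to multi-node fan-out |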
Let's now look at how these subclasses are implemented.
2. AnyValue
AnyValue is a single-value channel that stores any value, with no aggregation and no validation. It is easy to follow; the source is below.
```python
class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the last value received, assumes that if multiple values are
    received, they are all equal."""

    __slots__ = ("typ", "value")

    value: Value | Any

    def __init__(self, typ: Any, key: str = "") -> None:
        super().__init__(typ, key)
        self.value = MISSING

    def copy(self) -> Self:
        """Return a copy of the channel."""
        empty = self.__class__(self.typ, self.key)
        empty.value = self.value
        return empty

    def from_checkpoint(self, checkpoint: Value) -> Self:
        empty = self.__class__(self.typ, self.key)
        if checkpoint is not MISSING:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        if len(values) == 0:
            if self.value is MISSING:
                return False
            else:
                self.value = MISSING
                return True

        self.value = values[-1]
        return True

    def get(self) -> Value:
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def checkpoint(self) -> Value:
        return self.value
```
3. LastValue
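LastValue always keeps the most recently written value (overwrite semantics). It differs from AnyValue in update(): it accepts at most one value per step and raises InvalidUpdateError when it receives more, and an empty update leaves the stored value untouched.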
```python
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the last value received, can receive at most one value per step."""

    __slots__ = ("value",)

    value: Value | Any

    def __init__(self, typ: Any, key: str = "") -> None:
        super().__init__(typ, key)
        self.value = MISSING

    def update(self, values: Sequence[Value]) -> bool:
        if len(values) == 0:
            return False
        if len(values) != 1:
            msg = create_error_message(
                message=f"At key '{self.key}': Can receive only one value per step. Use an Annotated key to handle multiple values.",
                error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
            )
            raise InvalidUpdateError(msg)

        self.value = values[-1]
        return True
```
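The difference between the two update policies is easiest to see side by side. The sketch below restates them as plain functions that return the new value together with the "changed" flag. The function names, `MISSING` and `InvalidUpdateError` are stand-ins defined for this example; only the branching follows the two update() methods shown above.

```python
from typing import Any, Sequence, Tuple

MISSING = object()  # stand-in sentinel meaning "no value yet"


class InvalidUpdateError(Exception):
    """Stand-in for the error raised on an invalid update."""


def any_value_update(current: Any, values: Sequence[Any]) -> Tuple[Any, bool]:
    """AnyValue policy: an empty update clears, otherwise keep the last value."""
    if len(values) == 0:
        if current is MISSING:
            return current, False
        return MISSING, True
    return values[-1], True


def last_value_update(current: Any, values: Sequence[Any]) -> Tuple[Any, bool]:
    """LastValue policy: an empty update is ignored, several values are an error."""
    if len(values) == 0:
        return current, False
    if len(values) != 1:
        raise InvalidUpdateError("Can receive only one value per step.")
    return values[-1], True


# Two writers in the same step: AnyValue keeps the last one, LastValue refuses.
kept, changed = any_value_update(MISSING, ["a", "b"])
try:
    last_value_update(MISSING, ["a", "b"])
    rejected = False
except InvalidUpdateError:
    rejected = True
```

With AnyValue, concurrent writers are assumed to agree, so taking the last value is harmless; LastValue makes the conflict explicit instead.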
3. LastValueAfterFinish
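LastValueAfterFinish is similar to LastValue, but its value becomes visible only when the graph finishes. Visibility is controlled by the finished attribute: the value can be read only after finish() has been called.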
```python
class LastValueAfterFinish(
    Generic[Value], BaseChannel[Value, Value, tuple[Value, bool]]
):
    """Stores the last value received, but only made available after finish().
    Once made available, clears the value."""

    __slots__ = ("value", "finished")

    value: Value | Any
    finished: bool

    def __init__(self, typ: Any, key: str = "") -> None:
        super().__init__(typ, key)
        self.value = MISSING
        self.finished = False

    def checkpoint(self) -> tuple[Value | Any, bool] | Any:
        if self.value is MISSING:
            return MISSING
        return (self.value, self.finished)

    def from_checkpoint(self, checkpoint: tuple[Value | Any, bool] | Any) -> Self:
        empty = self.__class__(self.typ)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.value, empty.finished = checkpoint
        return empty

    def update(self, values: Sequence[Value | Any]) -> bool:
        if len(values) == 0:
            return False

        self.finished = False
        self.value = values[-1]
        return True

    def consume(self) -> bool:
        if self.finished:
            self.finished = False
            self.value = MISSING
            return True

        return False

    def finish(self) -> bool:
        if not self.finished and self.value is not MISSING:
            self.finished = True
            return True
        else:
            return False
```
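The life cycle update → finish → get → consume can be walked through with a small stand-alone sketch. `AfterFinishSketch`, `MISSING`, `EmptyChannelError` and the `readable` helper are illustrative names, not LangGraph's API. The excerpt above does not include get(), so the version here is an assumption derived from the class docstring ("only made available after finish()").

```python
from typing import Any, Sequence

MISSING = object()  # stand-in sentinel meaning "no value yet"


class EmptyChannelError(Exception):
    """Stand-in for the error raised when reading an unavailable channel."""


class AfterFinishSketch:
    def __init__(self) -> None:
        self.value: Any = MISSING
        self.finished = False

    def update(self, values: Sequence[Any]) -> bool:
        if len(values) == 0:
            return False
        self.finished = False  # a new write hides the value again
        self.value = values[-1]
        return True

    def finish(self) -> bool:
        if not self.finished and self.value is not MISSING:
            self.finished = True
            return True
        return False

    def consume(self) -> bool:
        if self.finished:
            self.finished = False
            self.value = MISSING
            return True
        return False

    def get(self) -> Any:
        # Assumption: readable only when a value exists and finish() was called.
        if self.value is MISSING or not self.finished:
            raise EmptyChannelError()
        return self.value


def readable(channel: AfterFinishSketch) -> bool:
    """Helper for the demo: True when get() does not raise."""
    try:
        channel.get()
        return True
    except EmptyChannelError:
        return False


ch = AfterFinishSketch()
ch.update(["draft"])
visible_before_finish = readable(ch)  # written, but not yet exposed
ch.finish()
value_after_finish = ch.get()         # exposed once finish() has run
ch.consume()
visible_after_consume = readable(ch)  # cleared again after consumption
```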
4. BinaryOperatorAggregate
BinaryOperatorAggregate folds multiple updates into a single value. At its core it takes a binary operator (operator) and applies it to the current value and each incoming value in turn.
```python
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the result of applying a binary operator to the current value and each new value.

    import operator

    total = Channels.BinaryOperatorAggregate(int, operator.add)
    """

    __slots__ = ("value", "operator")

    def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
        super().__init__(typ)
        self.operator = operator
        # special forms from typing or collections.abc are not instantiable
        # so we need to replace them with their concrete counterparts
        # recover the underlying base type
        typ = _strip_extras(typ)
        if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
            typ = list
        if typ in (collections.abc.Set, collections.abc.MutableSet):
            typ = set
        if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
            typ = dict
        try:
            self.value = typ()
        except Exception:
            self.value = MISSING

    def checkpoint(self) -> Value:
        return self.value

    def update(self, values: Sequence[Value]) -> bool:
        if not values:
            return False
        if self.value is MISSING:
            self.value = values[0]
            values = values[1:]
        for value in values:
            self.value = self.operator(self.value, value)
        return True
```
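The aggregation in update() is a left fold over the incoming values, seeded either with the default instance of the type or, when no default could be built, with the first update. A minimal sketch of that fold as a pure function follows; `fold_updates` and `MISSING` are names made up for this example.

```python
import operator
from typing import Any, Callable, Sequence

MISSING = object()  # stand-in sentinel meaning "no value yet"


def fold_updates(
    current: Any, values: Sequence[Any], op: Callable[[Any, Any], Any]
) -> Any:
    """Apply `op` to the current value and each update, left to right."""
    if not values:
        return current
    if current is MISSING:
        # No starting value: seed the fold with the first update.
        current, values = values[0], values[1:]
    for value in values:
        current = op(current, value)
    return current


# int() is 0 and list() is [], matching the typ() default in the constructor.
total = fold_updates(int(), [1, 2, 3], operator.add)
merged = fold_updates(list(), [[1], [2, 3]], operator.add)
# With no usable default, the first update becomes the starting value.
seeded = fold_updates(MISSING, [5, 6], operator.mul)
```

Because the fold runs in the order the updates arrive, a non-commutative operator (list concatenation, for instance) makes the result depend on that order.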
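_strip_extras(t) is meant to strip wrapper types such as Annotated, Required and NotRequired from a type annotation, recursively, until the underlying base type is left.
In Python's typing system, an annotation can be wrapped in several layers:
- Annotated[int, "some metadata"] wraps int.
- Required[str] / NotRequired[float] are used for TypedDict fields to say whether a key is mandatory.
These wrappers are type factories: subscripting them builds an alias object. For Annotated, __origin__ points at the base type. For Required and NotRequired, __origin__ is the special form itself and the wrapped type is stored in __args__[0].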
```python
from typing import Annotated, NotRequired, Required  # Required/NotRequired need Python 3.11+

Annotated[int, "meta"].__origin__  # <class 'int'>
Required[str].__origin__           # typing.Required
Required[str].__args__[0]          # <class 'str'>
```
Its source is shown below:
```python
# Adapted from typing_extensions
def _strip_extras(t):  # type: ignore[no-untyped-def]
    """Strips Annotated, Required and NotRequired from a given type."""
    # If `t` is a wrapped type (e.g. `Annotated[int, ...]`), take its
    # `__origin__` (the original type) and keep recursing.
    if hasattr(t, "__origin__"):
        return _strip_extras(t.__origin__)
    # Never reached: the check above already returns for anything with `__origin__`.
    if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
        return _strip_extras(t.__args__[0])

    return t
```
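To see what the reachable branch does in practice, the sketch below follows `__origin__` until a plain class is left and then picks an instantiable default, in the way the BinaryOperatorAggregate constructor does. `strip_origin` and `default_for` are names invented for this example; the dead Required/NotRequired branch is left out on purpose, so this is a simplified restatement rather than the library function.

```python
import collections.abc
from typing import Annotated, Any


def strip_origin(t: Any) -> Any:
    """Follow __origin__ until the object no longer has one."""
    if hasattr(t, "__origin__"):
        return strip_origin(t.__origin__)
    return t


def default_for(t: Any) -> Any:
    """Build an empty default value for an annotation, if the type allows it."""
    t = strip_origin(t)
    # Abstract collection types cannot be instantiated, so swap in concrete ones.
    if t in (collections.abc.Sequence, collections.abc.MutableSequence):
        t = list
    if t in (collections.abc.Set, collections.abc.MutableSet):
        t = set
    if t in (collections.abc.Mapping, collections.abc.MutableMapping):
        t = dict
    return t()


base_of_annotated = strip_origin(Annotated[int, "meta"])
base_of_generic = strip_origin(list[str])
base_of_nested = strip_origin(Annotated[list[int], "meta"])
empty_sequence = default_for(collections.abc.Sequence[int])
empty_mapping = default_for(Annotated[dict, "meta"])
```

Note that generic aliases such as list[str] are stripped as well, not only Annotated: anything that exposes `__origin__` is unwrapped.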
5. NamedBarrierValue
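NamedBarrierValue becomes readable only after every named writer has written to it. It does not store an actual value and serves purely as a barrier, so get() returns None. It is readable only when self.seen == self.names.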
```python
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
    """A channel that waits until all named values are received before making the value available."""

    __slots__ = ("names", "seen")

    names: set[Value]
    seen: set[Value]

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        super().__init__(typ)
        self.names = names
        self.seen: set[str] = set()

    def checkpoint(self) -> set[Value]:
        return self.seen

    def from_checkpoint(self, checkpoint: set[Value]) -> Self:
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.seen = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        updated = False
        for value in values:
            if value in self.names:
                if value not in self.seen:
                    self.seen.add(value)
                    updated = True
            else:
                raise InvalidUpdateError(
                    f"At key '{self.key}': Value {value} not in {self.names}"
                )
        return updated

    # get() does not return an actual value
    def get(self) -> Value:
        if self.seen != self.names:
            raise EmptyChannelError()
        return None

    def consume(self) -> bool:
        if self.seen == self.names:
            self.seen = set()
            return True
        return False
```
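As a usage illustration, the following stand-alone sketch replays the barrier logic with two writers named "a" and "b". `BarrierSketch` and the two exception classes are stand-ins written for this example; the method bodies follow the excerpt above.

```python
from typing import Sequence, Set


class EmptyChannelError(Exception):
    """Stand-in for the error raised while the barrier is still closed."""


class InvalidUpdateError(Exception):
    """Stand-in for the error raised on an unknown writer name."""


class BarrierSketch:
    def __init__(self, names: Set[str]) -> None:
        self.names = names
        self.seen: Set[str] = set()

    def update(self, values: Sequence[str]) -> bool:
        updated = False
        for value in values:
            if value not in self.names:
                raise InvalidUpdateError(f"Value {value} not in {self.names}")
            if value not in self.seen:
                self.seen.add(value)
                updated = True
        return updated

    def get(self) -> None:
        # The barrier carries no payload; it only opens or stays closed.
        if self.seen != self.names:
            raise EmptyChannelError()
        return None

    def consume(self) -> bool:
        # Once the barrier has been used, reset it for the next round.
        if self.seen == self.names:
            self.seen = set()
            return True
        return False


barrier = BarrierSketch({"a", "b"})
first = barrier.update(["a"])      # new name recorded
duplicate = barrier.update(["a"])  # same name again: no change
try:
    barrier.get()
    open_after_one = True
except EmptyChannelError:
    open_after_one = False

barrier.update(["b"])
payload = barrier.get()            # barrier is open, payload is None
reset = barrier.consume()          # clears `seen` for the next round
```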
6. NamedBarrierValueAfterFinish
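NamedBarrierValueAfterFinish builds on NamedBarrierValue by adding finish(): the channel becomes readable only after all names have been seen and finish() has been called. As with NamedBarrierValue, get() returns None rather than a stored value. The excerpt below leaves out the remaining methods, including finish() itself.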
```python
class NamedBarrierValueAfterFinish(
    Generic[Value], BaseChannel[Value, Value, set[Value]]
):
    """A channel that waits until all named values are received before making the value ready to be made available. It is only made available after finish() is called."""

    __slots__ = ("names", "seen", "finished")

    names: set[Value]
    seen: set[Value]

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        super().__init__(typ)
        self.names = names
        self.seen: set[str] = set()
        self.finished = False

    def get(self) -> Value:
        if not self.finished or self.seen != self.names:
            raise EmptyChannelError()
        return None

    def checkpoint(self) -> tuple[set[Value], bool]:
        return (self.seen, self.finished)
```
7. EphemeralValue
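EphemeralValue holds a value that is visible for a single step only and is then discarded. The class has no dedicated "discard" method; the clearing happens in update(): when it is called with an empty sequence while a value is stored, the value is reset to MISSING. Compared with LastValue it also takes a guard parameter, which controls whether update() may receive more than one value in a step.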
```python
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the value received in the step immediately preceding, clears after."""

    __slots__ = ("value", "guard")

    value: Value | Any
    guard: bool

    def __init__(self, typ: Any, guard: bool = True) -> None:
        super().__init__(typ)
        self.guard = guard
        self.value = MISSING

    def checkpoint(self) -> Value:
        return self.value

    def update(self, values: Sequence[Value]) -> bool:
        if len(values) == 0:
            if self.value is not MISSING:
                self.value = MISSING
                return True
            else:
                return False
        if len(values) != 1 and self.guard:
            raise InvalidUpdateError(
                f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
            )

        self.value = values[-1]
        return True

    def get(self) -> Value:
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value
```
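The one-step lifetime can be traced by feeding update() one write followed by steps without writes. The sketch below does this with a pure function; `ephemeral_update`, `MISSING` and `InvalidUpdateError` are stand-ins for this example, and the assumption that update() is called with an empty sequence on steps without writes comes from the logic of the excerpt, not from the runtime code, which is not shown here.

```python
from typing import Any, List, Sequence, Tuple

MISSING = object()  # stand-in sentinel meaning "no value yet"


class InvalidUpdateError(Exception):
    """Stand-in for the error raised on an invalid update."""


def ephemeral_update(
    current: Any, values: Sequence[Any], guard: bool = True
) -> Tuple[Any, bool]:
    """Return (new_value, changed) following the EphemeralValue rules."""
    if len(values) == 0:
        if current is not MISSING:
            return MISSING, True  # nothing written this step: drop the old value
        return current, False
    if len(values) != 1 and guard:
        raise InvalidUpdateError("can receive only one value per step")
    return values[-1], True


# Step 1 writes "x"; steps 2 and 3 write nothing.
state: Any = MISSING
history: List[Tuple[bool, bool]] = []
for writes in (["x"], [], []):
    state, changed = ephemeral_update(state, writes)
    history.append((state is not MISSING, changed))

# With guard=False several writes are accepted and the last one wins.
unguarded, _ = ephemeral_update(MISSING, [1, 2], guard=False)
```

Each entry of `history` records whether a value is present after the step and whether the step changed the channel.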
8. UntrackedValue
UntrackedValue can be read and written, but it is never recorded in a state snapshot. This shows in checkpoint(), which always returns MISSING.
```python
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the last value received, never checkpointed."""

    __slots__ = ("value", "guard")

    guard: bool
    value: Value | Any

    def __init__(self, typ: type[Value], guard: bool = True) -> None:
        super().__init__(typ)
        self.guard = guard
        self.value = MISSING

    def checkpoint(self) -> Value | Any:
        return MISSING

    def from_checkpoint(self, checkpoint: Value) -> Self:
        empty = self.__class__(self.typ, self.guard)
        empty.key = self.key
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        if len(values) == 0:
            return False
        if len(values) != 1 and self.guard:
            raise InvalidUpdateError(
                f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
            )

        self.value = values[-1]
        return True

    def get(self) -> Value:
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value
```
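The practical consequence is that a checkpoint round trip loses the value. A small sketch of that effect follows; `UntrackedSketch`, `MISSING` and `EmptyChannelError` are stand-ins for this example, and the guard handling is omitted to keep it short.

```python
from typing import Any, Sequence

MISSING = object()  # stand-in sentinel meaning "no value yet"


class EmptyChannelError(Exception):
    """Stand-in for the error raised when reading an empty channel."""


class UntrackedSketch:
    def __init__(self) -> None:
        self.value: Any = MISSING

    def update(self, values: Sequence[Any]) -> bool:
        if len(values) == 0:
            return False
        self.value = values[-1]
        return True

    def get(self) -> Any:
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def checkpoint(self) -> Any:
        # Never expose the value to the snapshot.
        return MISSING

    def from_checkpoint(self, checkpoint: Any) -> "UntrackedSketch":
        # The checkpoint is ignored: a restored channel always starts empty.
        return UntrackedSketch()


live = UntrackedSketch()
live.update(["secret"])
snapshot = live.checkpoint()               # nothing is persisted
restored = live.from_checkpoint(snapshot)  # comes back empty
```

The live channel still serves the value during the run; only the persisted state and anything restored from it are empty.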
9. Topic
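Topic is a publish/subscribe channel that supports broadcasting messages. The broadcasting itself is not part of this class's implementation. Accumulation is handled in update(): when accumulate is False, self.values is cleared before the new values are added.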
```python
class Topic(
    Generic[Value],
    BaseChannel[Sequence[Value], Union[Value, list[Value]], list[Value]],
):
    """A configurable PubSub Topic.

    Args:
        typ: The type of the value stored in the channel.
        accumulate: Whether to accumulate values across steps. If False, the channel will be emptied after each step.
    """

    __slots__ = ("values", "accumulate")

    def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
        super().__init__(typ)
        # attrs
        self.accumulate = accumulate
        # state
        self.values = list[Value]()

    @property
    def ValueType(self) -> Any:
        """The type of the value stored in the channel."""
        return Sequence[self.typ]  # type: ignore[name-defined]

    @property
    def UpdateType(self) -> Any:
        """The type of the update received by the channel."""
        return Union[self.typ, list[self.typ]]  # type: ignore[name-defined]

    def checkpoint(self) -> list[Value]:
        return self.values

    def from_checkpoint(self, checkpoint: list[Value]) -> Self:
        empty = self.__class__(self.typ, self.accumulate)
        empty.key = self.key
        if checkpoint is not MISSING:
            if isinstance(checkpoint, tuple):
                # backwards compatibility
                empty.values = checkpoint[1]
            else:
                empty.values = checkpoint
        return empty

    def update(self, values: Sequence[Value | list[Value]]) -> bool:
        updated = False
        if not self.accumulate:
            updated = bool(self.values)
            # not accumulating: clear the old values first
            self.values = list[Value]()
        if flat_values := tuple(_flatten(values)):
            updated = True
            self.values.extend(flat_values)
        return updated

    def get(self) -> Sequence[Value]:
        if self.values:
            return list(self.values)
        else:
            raise EmptyChannelError
```
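The effect of accumulate is easiest to check with a few concrete updates. The sketch below restates update() as a pure function. `topic_update` and `flatten` are names made up for this example; the `_flatten` helper used by the class is not shown in the excerpt, so the one-level flattening here is an assumption based on the declared update type (a value or a list of values).

```python
from typing import Any, Iterator, List, Sequence, Tuple


def flatten(values: Sequence[Any]) -> Iterator[Any]:
    """Assumed behavior of the helper: expand lists one level, pass other values through."""
    for value in values:
        if isinstance(value, list):
            yield from value
        else:
            yield value


def topic_update(
    current: List[Any], values: Sequence[Any], accumulate: bool
) -> Tuple[List[Any], bool]:
    """Return (new_values, changed) following the Topic.update rules."""
    updated = False
    if not accumulate:
        # Dropping non-empty old values already counts as a change.
        updated = bool(current)
        current = []
    flat_values = tuple(flatten(values))
    if flat_values:
        updated = True
        current = current + list(flat_values)  # build a new list, leave the input alone
    return current, updated


# A single value and a list of values can be mixed in one step.
mixed, _ = topic_update([], ["a", ["b", "c"]], accumulate=False)
# Without accumulation, a step with no writes empties the topic.
emptied = topic_update(["a"], [], accumulate=False)
# With accumulation, earlier values are kept.
kept = topic_update(["a"], ["b"], accumulate=True)
```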