
pregel node

1. PregelNode

1.1 PregelNode attributes

As we saw earlier, PregelNode has no abstract methods. It has the following attributes:

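channels (str | list[str]): name(s) of the input channel(s), one or several. Determines which channels are read and passed to bound. Example: channels="user_input" or channels=["query", "history"].
triggers (list[str]): when any of these channels is written, the node is activated in the next step. Typically used to react to other nodes' output or to external signals.
mapper (Callable[[Any], Any] | None): transforms or preprocesses the input before it is passed to bound. For example, combining several inputs into a specific format.
writers (list[Runnable]): after the node's computation finishes, take over its output and write it to the corresponding channels; the write behavior can be customized. For example, writing bound's output to several channels.
bound (Runnable[Any, Any]): the node's core logic; receives the input read from channels and returns the output. Can be a function, a chain, an LLM call, etc.
retry_policy (Sequence[RetryPolicy] | None): retry strategy used when the node fails, e.g. exponential backoff and a maximum number of attempts. Improves robustness against transient failures.
cache_policy (CachePolicy | None): caching strategy for the node, used to skip re-execution on identical input. A performance optimization that saves LLM tokens or API requests.
tags (Sequence[str] | None): labels attached to the node for tracing, debugging, logging, etc. Example: ["llm", "retriever"].
metadata (Mapping[str, Any] | None): arbitrary key-value pairs attached to the node, used for tracing or runtime identification. For example, recording the node version or model ID.
subgraphs (Sequence[PregelProtocol]): nested graphs, meaning the node may execute subgraphs internally. Used for complex control flow or component-level encapsulation.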
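To see how channels, triggers, mapper, bound and writers fit together, here is a toy, dependency-free sketch. ToyNode, write_to and run_step are hypothetical names rather than langgraph APIs; channels are modeled as a plain dict, "updated" stands for the set of channels written in the previous step, and retries, caching and subgraphs are left out.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyNode:
    """Hypothetical stand-in for PregelNode that keeps only the data-flow fields."""

    channels: str | list[str]
    triggers: list[str]
    bound: Callable[[Any], Any]
    mapper: Callable[[Any], Any] | None = None
    # Each writer receives bound's output plus the channel store to write into.
    writers: list[Callable[[Any, dict], None]] = field(default_factory=list)


def write_to(channel: str) -> Callable[[Any, dict], None]:
    """Build a writer that stores bound's output in a single channel."""

    def writer(output: Any, store: dict) -> None:
        store[channel] = output

    return writer


def run_step(node: ToyNode, store: dict[str, Any], updated: set[str]) -> bool:
    """Run the node once if any of its trigger channels was updated."""
    if not updated & set(node.triggers):
        return False
    # One channel name yields the raw value; a list yields a dict keyed by channel.
    if isinstance(node.channels, str):
        value = store[node.channels]
    else:
        value = {name: store[name] for name in node.channels}
    if node.mapper is not None:
        value = node.mapper(value)
    output = node.bound(value)
    for write in node.writers:
        write(output, store)
    return True


def combine(v: dict) -> str:
    # mapper: merge two channels into one prompt-like string
    return f"{v['history'][-1]} / {v['query']}"


store = {"query": "hi", "history": ["hello"]}
node = ToyNode(
    channels=["query", "history"],
    triggers=["query"],
    mapper=combine,
    bound=str.upper,
    writers=[write_to("answer")],
)
print(run_step(node, store, updated={"history"}))  # False: not a trigger
print(run_step(node, store, updated={"query"}))    # True
print(store["answer"])                             # HELLO / HI
```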

1.2 PregelNode methods

PregelNode implements the Runnable interface by delegating to its bound attribute, as in invoke below:

    def invoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        # Attach this node's tags/metadata to the caller's config, then delegate to bound.
        return self.bound.invoke(
            input,
            merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
            **kwargs,
        )
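The delegation pattern can be reproduced without langgraph. In this sketch, merge_node_config is a hypothetical, simplified stand-in for merge_configs (it only handles tags and metadata, while the real function also deals with other config keys), and Echo is a dummy bound runnable that returns the input and the config it received.

```python
from __future__ import annotations

from typing import Any, Mapping


def merge_node_config(node_cfg: Mapping[str, Any], config: Mapping[str, Any] | None) -> dict[str, Any]:
    """Simplified merge: metadata dicts are combined (caller wins), tags are unioned."""
    config = dict(config or {})
    metadata = {**(node_cfg.get("metadata") or {}), **(config.get("metadata") or {})}
    tags = sorted(set(node_cfg.get("tags") or []) | set(config.get("tags") or []))
    return {**config, "metadata": metadata, "tags": tags}


class Echo:
    """Dummy bound runnable: returns its input and the config it saw."""

    def invoke(self, input: Any, config: Mapping[str, Any] | None = None) -> tuple:
        return input, config


class DelegatingNode:
    def __init__(self, bound: Any, tags: list[str] | None = None,
                 metadata: Mapping[str, Any] | None = None) -> None:
        self.bound = bound
        self.tags = tags
        self.metadata = metadata

    def invoke(self, input: Any, config: Mapping[str, Any] | None = None) -> Any:
        # Same shape as PregelNode.invoke: attach node tags/metadata, then delegate.
        return self.bound.invoke(
            input,
            merge_node_config({"metadata": self.metadata, "tags": self.tags}, config),
        )


node = DelegatingNode(Echo(), tags=["llm"], metadata={"version": 1})
out, seen = node.invoke("hi", {"tags": ["run"], "metadata": {"user": "u1"}})
print(seen["tags"], seen["metadata"])  # ['llm', 'run'] {'version': 1, 'user': 'u1'}
```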

Besides the Runnable interface methods, the following members deserve attention (all three are defined as cached_property):

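flat_writers (returns list[Runnable]): the writers list after optimization; consecutive ChannelWrite entries at the end of the list are merged to avoid redundant write steps.
node (returns Runnable[Any, Any] | None): a composite Runnable that combines bound and writers into the node's main runnable.
input_cache_key (returns INPUT_CACHE_KEY_TYPE, i.e. a (mapper, tuple of channel names) pair): a cache key for the node's input, determined by mapper and channels, used to avoid computing the same input more than once.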

__init__

    def __init__(
        self,
        *,
        channels: str | list[str],
        triggers: Sequence[str],
        mapper: Callable[[Any], Any] | None = None,
        writers: list[Runnable] | None = None,
        tags: list[str] | None = None,
        metadata: Mapping[str, Any] | None = None,
        bound: Runnable[Any, Any] | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
        cache_policy: CachePolicy | None = None,
        subgraphs: Sequence[PregelProtocol] | None = None,
    ) -> None:
        self.channels = channels
        self.triggers = list(triggers)
        self.mapper = mapper
        self.writers = writers or []
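        # With no explicit bound, fall back to the DEFAULT_BOUND sentinel.
        self.bound = bound if bound is not None else DEFAULT_BOUND
        self.cache_policy = cache_policy
        # Wrap a single RetryPolicy in a 1-tuple; sequences and None pass through.
        if isinstance(retry_policy, RetryPolicy):
            self.retry_policy = (retry_policy,)
        else:
            self.retry_policy = retry_policy
        self.tags = tags
        self.metadata = metadata
        # Explicit subgraphs win; otherwise search bound for a nested Pregel.
        if subgraphs is not None:
            self.subgraphs = subgraphs
        elif self.bound is not DEFAULT_BOUND:
            try:
                subgraph = find_subgraph_pregel(self.bound)
            except Exception:
                # Discovery is best-effort: any failure just means "no subgraph".
                subgraph = None
            if subgraph:
                self.subgraphs = [subgraph]
            else:
                self.subgraphs = []
        else:
            self.subgraphs = []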
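The retry_policy branch is a common "accept one or many" idiom. The sketch below models the policy as a NamedTuple, which is an assumption about RetryPolicy made to show why the isinstance check has to come first: a NamedTuple instance is itself a tuple, so a generic "is it a sequence?" test would mistake a single policy for a list of policies. Policy and normalize are hypothetical names.

```python
from typing import NamedTuple, Optional, Sequence, Union


class Policy(NamedTuple):
    """Hypothetical stand-in for RetryPolicy, modeled as a NamedTuple."""

    max_attempts: int = 3
    initial_interval: float = 0.5


def normalize(policy: Union[Policy, Sequence[Policy], None]) -> Optional[Sequence[Policy]]:
    # Check the concrete type first: a NamedTuple is also a tuple, so a
    # generic sequence check would wrongly accept a single policy as-is.
    if isinstance(policy, Policy):
        return (policy,)
    return policy


single = Policy(max_attempts=5)
print(isinstance(single, tuple))  # True: a lone policy looks like a sequence
print(normalize(single))          # (Policy(max_attempts=5, initial_interval=0.5),)
print(normalize(None))            # None
```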

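The hardest part of the PregelNode initializer to follow is the handling of subgraphs. If subgraphs is not passed and bound is a real runnable (not DEFAULT_BOUND), it calls find_subgraph_pregel(self.bound) and stores any match as a one-element list, falling back to an empty list when nothing is found or the search raises.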

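find_subgraph_pregel searches a possibly nested or wrapped Runnable chain for the first qualifying PregelProtocol instance, i.e. a subgraph that is not a Pregel compiled with checkpointer=False. It descends into the steps of RunnableSequence/RunnableSeq, the deps of RunnableLambda, and the nonlocal variables of a RunnableCallable's function (func, or afunc when func is None). Because the loop appends to the very list it is iterating over, the search is an iterative breadth-first walk rather than a recursive one.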

def find_subgraph_pregel(candidate: Runnable) -> PregelProtocol | None:
    from langgraph.pregel import Pregel

    candidates: list[Runnable] = [candidate]

    for c in candidates:
        # Return the first qualifying Pregel
        if (
            isinstance(c, PregelProtocol)
            # subgraphs that disabled checkpointing are not considered
            and (not isinstance(c, Pregel) or c.checkpointer is not False)
        ):
            return c
        elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
            candidates.extend(c.steps)
        # Collect the dependencies of RunnableLambda or RunnableCallable
        elif isinstance(c, RunnableLambda):
            candidates.extend(c.deps)
        elif isinstance(c, RunnableCallable):
            if c.func is not None:
                candidates.extend(
                    # A nonlocal with __self__ is a bound method; use its instance instead.
                    nl.__self__ if hasattr(nl, "__self__") else nl
                    for nl in get_function_nonlocals(c.func)
                )
            elif c.afunc is not None:
                candidates.extend(
                    nl.__self__ if hasattr(nl, "__self__") else nl
                    for nl in get_function_nonlocals(c.afunc)
                )

    return None
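A toy model of the same search, with hypothetical classes standing in for the real ones: Graph for Pregel, Seq for RunnableSequence/RunnableSeq, and Fn for RunnableLambda with its deps. As in the original, the loop extends the list it iterates over, so candidates are visited breadth-first, and a graph with checkpointer=False is skipped.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Graph:
    """Stand-in for a Pregel graph; checkpointer=False opts out of subgraph discovery."""

    name: str
    checkpointer: Any = None


@dataclass
class Seq:
    """Stand-in for RunnableSequence / RunnableSeq."""

    steps: list[Any]


@dataclass
class Fn:
    """Stand-in for RunnableLambda; deps are the objects its function references."""

    deps: list[Any] = field(default_factory=list)


def find_subgraph(candidate: Any) -> Graph | None:
    candidates = [candidate]
    # Appending while iterating turns this loop into a breadth-first walk.
    for c in candidates:
        if isinstance(c, Graph) and c.checkpointer is not False:
            return c
        elif isinstance(c, Seq):
            candidates.extend(c.steps)
        elif isinstance(c, Fn):
            candidates.extend(c.deps)
    return None


no_ckpt = Graph("no_checkpoint", checkpointer=False)
inner = Graph("inner")
chain = Seq([Fn(deps=[no_ckpt]), Fn(deps=[Seq([inner])])])
print(find_subgraph(chain))  # Graph(name='inner', checkpointer=None)
print(find_subgraph(Fn()))   # None
```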

get_function_nonlocals returns the values a function depends on (the global and closure variables it references). Let's start with an example:

from langgraph.pregel._utils import get_function_nonlocals


class A:
    def __init__(self):
        self.a = 1


db_config = {"host": "localhost", "port": 3306}
A_CONFIG = A()


def outer():
    factor = 10

    def inner(x):
        return x * factor + db_config["port"] + A_CONFIG.a

    return inner


inner_fn = outer()
print(get_function_nonlocals(inner_fn))
# [1, {'host': 'localhost', 'port': 3306}, 10]
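get_function_nonlocals builds on the standard library's inspect.getclosurevars. Calling it directly on the same inner function shows the raw namespaces that the implementation later filters: closure variables under nonlocals, module-level names under globals, and attribute names such as a under unbound.

```python
import inspect


class A:
    def __init__(self):
        self.a = 1


db_config = {"host": "localhost", "port": 3306}
A_CONFIG = A()


def outer():
    factor = 10

    def inner(x):
        return x * factor + db_config["port"] + A_CONFIG.a

    return inner


inner_fn = outer()
cv = inspect.getclosurevars(inner_fn)
print(cv.nonlocals)        # {'factor': 10}
print(sorted(cv.globals))  # ['A_CONFIG', 'db_config']
print(cv.unbound)          # {'a'}
```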

Now let's look at how get_function_nonlocals is implemented:

import ast
import inspect
import textwrap
from typing import Any, Callable

from langchain_core.runnables.utils import FunctionNonLocals


def get_function_nonlocals(func: Callable) -> list[Any]:
    """Get the nonlocal variables accessed by a function.

    Args:
        func: The function to check.

    Returns:
        List[Any]: The values of the nonlocal variables (globals and closure
        variables) accessed by the function.
    """
    try:
        # Get the function's source and parse it into an AST
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        # Use an AST visitor to collect the names the function reads
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        # For the example above, visitor.nonlocals is {'A_CONFIG.a', 'factor', 'db_config', 'x'}
        values: list[Any] = []
        # inspect.getclosurevars returns a namedtuple with nonlocals, globals, builtins and unbound.
        # If func is wrapped by a decorator (i.e. __wrapped__ exists), inspect the original function.
        closure = (
            inspect.getclosurevars(func.__wrapped__)
            if hasattr(func, "__wrapped__") and callable(func.__wrapped__)
            else inspect.getclosurevars(func)
        )
        # Variables the function may depend on, for example:
        # {'A_CONFIG': <__main__.A object at 0x00000191C6C70E30>, 'db_config': {'host': 'localhost', 'port': 3306}, 'factor': 10}
        candidates = {**closure.globals, **closure.nonlocals}

        for k, v in candidates.items():
            if k in visitor.nonlocals:
                values.append(v)
            for kk in visitor.nonlocals:
                # Support nested attribute access, e.g. A_CONFIG.a
                if "." in kk and kk.startswith(k):
                    vv = v
                    for part in kk.split(".")[1:]:
                        if vv is None:
                            break
                        else:
                            try:
                                vv = getattr(vv, part)
                            except AttributeError:
                                break
                    else:
                        values.append(vv)
    except (SyntaxError, TypeError, OSError, SystemError):
        return []

    return values
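The AST half of the implementation can be approximated with the standard ast module. LoadCollector and resolve below are hypothetical, simplified helpers, not langchain_core's FunctionNonLocals: they record the names a function reads, keep attribute chains in dotted form (like 'A_CONFIG.a' in the visitor.nonlocals example), and then follow a dotted path with getattr. Unlike the original, resolve raises instead of silently skipping a path that hits None or a missing attribute. The source is parsed from a string so the sketch does not depend on inspect.getsource. As in the original, the parameter x is still collected; it is the later lookup against closure.globals and closure.nonlocals that filters it out.

```python
import ast
from types import SimpleNamespace
from typing import Any


class LoadCollector(ast.NodeVisitor):
    """Record names read by the code; attribute chains are kept in dotted form."""

    def __init__(self) -> None:
        self.loads: set[str] = set()
        self.stores: set[str] = set()

    def visit_Name(self, node: ast.Name) -> None:
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        # Rebuild "cfg.db.port" from nested Attribute nodes. Inner nodes are not
        # visited again, so "cfg.db" is not recorded separately.
        parts = [node.attr]
        value = node.value
        while isinstance(value, ast.Attribute):
            parts.append(value.attr)
            value = value.value
        if isinstance(value, ast.Name) and isinstance(node.ctx, ast.Load):
            self.loads.add(".".join([value.id, *reversed(parts)]))
        else:
            self.generic_visit(node)


def resolve(namespace: dict[str, Any], dotted: str) -> Any:
    """Look up the base name, then follow each attribute with getattr."""
    base, *attrs = dotted.split(".")
    value = namespace[base]
    for attr in attrs:
        value = getattr(value, attr)
    return value


source = """
def handler(x):
    y = x + offset
    return y * cfg.db.port
"""
collector = LoadCollector()
collector.visit(ast.parse(source))
names = collector.loads - collector.stores
print(sorted(names))  # ['cfg.db.port', 'offset', 'x']

namespace = {"offset": 3, "cfg": SimpleNamespace(db=SimpleNamespace(port=5432))}
print(resolve(namespace, "cfg.db.port"))  # 5432
```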

flat_writers

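writers take over the node's output after its computation finishes and write it to the corresponding channels. flat_writers returns an optimized copy of the writers list: while the last two writers are both ChannelWrite, they are merged into one, so consecutive ChannelWrite entries at the end of the list collapse into a single write step. We'll come back to ChannelWrite itself after covering Channel.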

    @cached_property
    def flat_writers(self) -> list[Runnable]:
        """Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
        writers = self.writers.copy()
        while (
            len(writers) > 1
            and isinstance(writers[-1], ChannelWrite)
            and isinstance(writers[-2], ChannelWrite)
        ):
            # we can combine writes if they are consecutive
            # careful to not modify the original writers list or ChannelWrite
            writers[-2] = ChannelWrite(
                writes=writers[-2].writes + writers[-1].writes,
            )
            writers.pop()
        return writers
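The merge can be tried out with hypothetical stand-ins: Write plays the role of ChannelWrite (it holds a tuple of pending writes here, whereas the real ChannelWrite concatenates lists), and Custom is any other writer. Note that only the tail is merged: a non-ChannelWrite writer stops the loop, so earlier ChannelWrites are left alone.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Write:
    """Hypothetical stand-in for ChannelWrite: holds the pending writes."""

    writes: tuple[str, ...]


@dataclass(frozen=True)
class Custom:
    """Any other kind of writer; it blocks merging across it."""

    name: str


def flatten(writers: list) -> list:
    writers = writers.copy()  # never mutate the caller's list
    # Merge from the end while the last two writers are both Write.
    while len(writers) > 1 and isinstance(writers[-1], Write) and isinstance(writers[-2], Write):
        writers[-2] = Write(writers[-2].writes + writers[-1].writes)
        writers.pop()
    return writers


original = [Write(("a",)), Custom("log"), Write(("b",)), Write(("c",))]
print(flatten(original))
# [Write(writes=('a',)), Custom(name='log'), Write(writes=('b', 'c'))]
print(len(original))  # 4: the input list is untouched
```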

node

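node combines bound and writers into a single Runnable. Without a real bound (DEFAULT_BOUND) it returns None, the only writer, or a RunnableSeq of all writers; otherwise it returns bound alone or RunnableSeq(bound, *writers). RunnableSeq is langgraph's own counterpart to langchain's RunnableSequence. It is not yet clear to me why langgraph defines its own RunnableSeq.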

    @cached_property
    def node(self) -> Runnable[Any, Any] | None:
        """Get a runnable that combines `bound` and `writers`."""
        writers = self.flat_writers
        if self.bound is DEFAULT_BOUND and not writers:
            return None
        elif self.bound is DEFAULT_BOUND and len(writers) == 1:
            return writers[0]
        elif self.bound is DEFAULT_BOUND:
            return RunnableSeq(*writers)
        elif writers:
            return RunnableSeq(self.bound, *writers)
        else:
            return self.bound
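The case analysis can be run with plain callables. DEFAULT_BOUND here is a hypothetical sentinel, Seq is a minimal stand-in for RunnableSeq that feeds each step's output into the next, and the toy writer record passes its input through so that a following step would still see the value (a simplification for the sketch).

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical sentinel standing in for langgraph's DEFAULT_BOUND.
DEFAULT_BOUND = object()


class Seq:
    """Minimal stand-in for RunnableSeq: feeds each step's output into the next."""

    def __init__(self, *steps: Callable[[Any], Any]) -> None:
        self.steps = steps

    def __call__(self, value: Any) -> Any:
        for step in self.steps:
            value = step(value)
        return value


def build_node(bound: Any, writers: list[Callable[[Any], Any]]) -> Callable[[Any], Any] | None:
    # Same case analysis as PregelNode.node.
    if bound is DEFAULT_BOUND and not writers:
        return None
    elif bound is DEFAULT_BOUND and len(writers) == 1:
        return writers[0]
    elif bound is DEFAULT_BOUND:
        return Seq(*writers)
    elif writers:
        return Seq(bound, *writers)
    else:
        return bound


written: list[Any] = []


def record(value: Any) -> Any:
    # Toy writer: store the value, then pass it through.
    written.append(value)
    return value


def double(x: int) -> int:
    return x * 2


print(build_node(DEFAULT_BOUND, []))     # None
print(build_node(double, []) is double)  # True
node = build_node(double, [record])
print(node(21), written)                 # 42 [42]
```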

input_cache_key

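mapper transforms or preprocesses the input before it is passed to bound. input_cache_key returns a cache key for the node's input, the pair (mapper, tuple of channel names), so that the same input is not computed more than once.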

    @cached_property
    def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
        """Get a cache key for the input to the node.
        This is used to avoid calculating the same input multiple times."""
        return (
            self.mapper,
            (
                tuple(self.channels)
                if isinstance(self.channels, list)
                else (self.channels,)
            ),
        )
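Because the key is a plain hashable tuple, it can index a dict. read_input below is a hypothetical helper (not a langgraph function), restricted to list-valued channels for simplicity, that shows the idea: reads sharing the same mapper and channels reuse one computed input. Functions hash by identity, so two separately defined but identical mappers would still produce different keys.

```python
from __future__ import annotations

from typing import Any, Callable, Optional


def input_cache_key(mapper: Optional[Callable[[Any], Any]], channels: str | list[str]) -> tuple:
    # Same construction as PregelNode.input_cache_key: (mapper, tuple of channel names).
    return (mapper, tuple(channels) if isinstance(channels, list) else (channels,))


def read_input(state: dict, mapper: Optional[Callable[[Any], Any]],
               channels: list[str], cache: dict) -> Any:
    """Hypothetical helper: build a node's input at most once per cache key."""
    key = input_cache_key(mapper, channels)
    if key not in cache:
        value = {name: state[name] for name in channels}
        cache[key] = mapper(value) if mapper is not None else value
    return cache[key]


mapper_calls: list[dict] = []


def join(value: dict) -> str:
    mapper_calls.append(value)
    return f"{value['query']}|{value['history']}"


state = {"query": "q", "history": "h"}
cache: dict = {}
first = read_input(state, join, ["query", "history"], cache)
second = read_input(state, join, ["query", "history"], cache)
print(first, second, len(mapper_calls))  # q|h q|h 1
```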