LangGraph API Flow Summary

In this section we review the important flows behind the LangGraph API, including:

  1. How the StateGraph API maps to Pregel
  2. How the Functional API maps to Pregel

1. How the StateGraph API maps to Pregel

In the earlier summary we already covered how Pregel executes:

                      input
                        |
                        | initialize
                        |
                  updated_channels
                              ^
                   /          \ \
           update /    trigger \ \ produce
                 /              \ \
            channel --- data ---> node

The Pregel abstraction contains only two core components: channel and node (PregelNode).
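
To make the channel/node loop concrete, here is a minimal toy model of it. It is only a sketch: ToyNode and run_toy_pregel are hypothetical names, not LangGraph classes, and the real implementation also handles versions, checkpoints, and reducers that are ignored here.

```python
from typing import Any, Callable


class ToyNode:
    """Toy node: a function plus the channels whose updates trigger it."""

    def __init__(self, triggers: list[str], func: Callable[[dict], dict]):
        self.triggers = triggers
        self.func = func  # returns {channel_name: value} writes


def run_toy_pregel(
    nodes: dict[str, ToyNode], initial: dict[str, Any], max_steps: int = 10
) -> dict[str, Any]:
    channels: dict[str, Any] = {}
    # Initialization: the input is written to channels, marking them updated.
    channels.update(initial)
    updated = set(initial)
    for _ in range(max_steps):
        # A node runs in this step if any of its trigger channels was updated.
        ready = [n for n in nodes.values() if updated & set(n.triggers)]
        if not ready:
            break
        writes: dict[str, Any] = {}
        for node in ready:
            writes.update(node.func(dict(channels)))
        # Writes are applied together after all nodes of the step have run.
        channels.update(writes)
        updated = set(writes)
    return channels


nodes = {
    "double": ToyNode(["x"], lambda s: {"y": s["x"] * 2}),
    "report": ToyNode(["y"], lambda s: {"out": f"y={s['y']}"}),
}
result = run_toy_pregel(nodes, {"x": 3})
```

The loop stops on its own once a step produces no writes that trigger any node, which mirrors the updated_channels -> node -> channel cycle in the diagram.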

StateGraph, by contrast, has the following abstractions:

  1. nodes: dict[str, StateNodeSpec[Any, ContextT]]
  2. edges: set[tuple[str, str]]
  3. branches: defaultdict[str, dict[str, BranchSpec]]

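Beyond these abstractions, StateGraph provides another important feature: type annotations are used to declare:

  1. the channels and ManagedValues that a node uses

What we need to focus on is how nodes, edges, and branches map to channels and nodes.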

1.1 StateNodeSpec -> PregelNode

The StateNodeSpec -> PregelNode mapping lives in CompiledStateGraph.attach_node:

  1. Each node gets its own branch_channel, named branch:to:{node_key}.
  2. A node is triggered by its own branch_channel. This channel only acts as a trigger and carries no value.

When the PregelNode is initialized:

  1. channels: the channels parsed from the node's input_schema
  2. mapper: at execution time, binds the input to the input_schema
_CHANNEL_BRANCH_TO = "branch:to:{}"

branch_channel = _CHANNEL_BRANCH_TO.format(key)
# 1. add the channel
self.channels[branch_channel] = (
    LastValueAfterFinish(Any)
    if node.defer
    else EphemeralValue(Any, guard=False)
)
# 2. add the node
self.nodes[key] = PregelNode(
    triggers=[branch_channel],
    # read state keys and managed values
    channels=("__root__" if is_single_input else input_channels),
    # coerce state dict to schema class (eg. pydantic model)
    mapper=mapper,
    # publish to state keys
    writers=[ChannelWrite(write_entries)],
    metadata=node.metadata,
    retry_policy=node.retry_policy,
    cache_policy=node.cache_policy,
    bound=node.runnable,  # type: ignore[arg-type]
)
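
The idea of a trigger-only channel can be sketched with a toy class. ToyEphemeral below is a hypothetical stand-in written for this post, not LangGraph's EphemeralValue; it only assumes that such a channel keeps a value for a single step and is cleared when nobody writes to it.

```python
_CHANNEL_BRANCH_TO = "branch:to:{}"  # same format string as in the excerpt above


class ToyEphemeral:
    """Toy channel whose value survives for one step only."""

    _MISSING = object()

    def __init__(self) -> None:
        self.value = self._MISSING

    def update(self, values: list) -> bool:
        if not values:
            # Nobody wrote in this step: clear the value, report if it changed.
            had_value = self.value is not self._MISSING
            self.value = self._MISSING
            return had_value
        self.value = values[-1]
        return True

    def is_available(self) -> bool:
        return self.value is not self._MISSING


name = _CHANNEL_BRANCH_TO.format("agent")
trigger = ToyEphemeral()
trigger.update([None])          # writing None still counts as an update
fired = trigger.is_available()
trigger.update([])              # next step without writes clears the channel
cleared = not trigger.is_available()
```

Because the written value is None, the channel carries no data; the only thing that matters is that it was updated in the current step.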

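The writers argument of PregelNode has the most involved initialization:

  1. When StateGraph.add_node adds a node, it parses node.ends from the return type annotation of the input function. The second ChannelWriteTupleEntry uses node.ends to produce static hints (its static argument).
  2. The PregelNode is given two ChannelWriteTupleEntry objects. Each takes a mapper function, and that mapper receives bound_return = PregelNode.bound(input), the result of running the node.
    • output_keys is the full set of channels collected by the StateGraph; output_keys == ["__root__"] means no named keys could be parsed from the schema
    • _get_updates filters bound_return down to the entries whose keys are in output_keys and emits them as channel updates; lists that contain Command objects are processed recursively
    • _get_root is similar to _get_updates, but because output_keys == ["__root__"] there is only one channel, so by default the whole bound_return becomes the value of the __root__ channel; it does not check whether bound_return itself contains a __root__ key
    • _control_branch extracts Send and Command.goto from bound_return and generates the node-jump tasks
  3. To summarize PregelNode.writers: writers=[ChannelWrite(List[ChannelWriteEntry|ChannelWriteTupleEntry])]
    • In the raw Pregel API
      • PregelNode.writers contains ChannelWriteEntry
      • A ChannelWriteEntry holds the channel; PregelNode.bound returns only a value, and ChannelWrite emits (channel, mapper(value))
    • In StateGraph
      • PregelNode.writers contains ChannelWriteTupleEntry
      • A ChannelWriteTupleEntry holds a mapper function and emits mapper(value)
      • The mappers are _get_updates and _control_branch; it is because of these mappers that a StateGraph node function can return dict|Command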
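
The difference between the two value mappers can be illustrated with a simplified sketch. ToyCommand, toy_get_updates, and toy_get_root are hypothetical names; they only imitate the filtering behavior described above and skip the list, parent-graph, and BaseModel cases of the real functions.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ToyCommand:
    """Hypothetical stand-in for Command: a state update plus a jump target."""

    update: dict = field(default_factory=dict)
    goto: Optional[str] = None


def toy_get_updates(ret: Any, output_keys: list[str]) -> list[tuple[str, Any]]:
    """Keep only the entries of the return value that name a known channel."""
    if ret is None:
        return []
    if isinstance(ret, ToyCommand):
        ret = ret.update
    if isinstance(ret, dict):
        return [(k, v) for k, v in ret.items() if k in output_keys]
    raise TypeError(f"Expected dict or ToyCommand, got {type(ret).__name__}")


def toy_get_root(ret: Any) -> list[tuple[str, Any]]:
    """Single-channel case: the whole return value goes to __root__."""
    if ret is None:
        return []
    return [("__root__", ret)]


keys = ["messages", "count"]
filtered = toy_get_updates({"count": 1, "unknown": 2}, keys)
from_command = toy_get_updates(ToyCommand(update={"messages": ["hi"]}), keys)
rooted = toy_get_root({"__root__": 5})
```

Note that toy_get_root stores the dict as-is and does not look for a __root__ key inside it, which matches the behavior described for _get_root.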

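This explains how a node function's return value is handled. For comparison, here is how a tool function's return value is handled.

  1. The tool function is wrapped in the Tool class, and Tool calls _format_output to wrap the function's result:
    • If content is already an instance of ToolOutputMixin, it is returned as-is
    • Command and ToolMessage are both subclasses of ToolOutputMixin, so they are returned as-is
    • Any other type is wrapped in a ToolMessage
    • For this reason, a tool that wants to modify a channel must return a Command
  2. Tools are wrapped in a ToolNode, and ToolNode calls _combine_tool_outputs to merge the results of multiple tools
    • After _format_output, a tool result is always either a ToolMessage or a Command
    • Commands are collected and returned; ToolMessages are wrapped as {message_key: [ToolMessage]} and returned
  3. Returning to the StateGraph flow above for handling node function return values: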
    def attach_node(self, key: str, node: StateNodeSpec[Any, ContextT] | None) -> None:
        if key == START:
            # read the keys that the input may provide from input_schema
            output_keys = [
                k
                for k, v in self.builder.schemas[self.builder.input_schema].items()
                if not is_managed_value(v)
            ]
        else:
            # non-START nodes: every channel and managed value is a key that may be written
            output_keys = list(self.builder.channels) + [
                k for k, v in self.builder.managed.items()
            ]

        # extract the channel updates from the node's return value
        def _get_updates(
            input: None | dict | Any,
        ) -> Sequence[tuple[str, Any]] | None:
            if input is None:
                return None
            elif isinstance(input, dict):
                return [(k, v) for k, v in input.items() if k in output_keys]
            elif isinstance(input, Command):
                # this update is for the parent graph to handle, not the current graph, so return None
                if input.graph == Command.PARENT:
                    return None
                return [
                    # _update_as_tuples yields the Command's channel updates
                    (k, v)
                    for k, v in input._update_as_tuples()
                    if k in output_keys
                ]
            elif (
                isinstance(input, (list, tuple))
                and input
                and any(isinstance(i, Command) for i in input)
            ):
                updates: list[tuple[str, Any]] = []
                for i in input:
                    if isinstance(i, Command):
                        if i.graph == Command.PARENT:
                            continue
                        updates.extend(
                            (k, v) for k, v in i._update_as_tuples() if k in output_keys
                        )
                    else:
                        updates.extend(_get_updates(i) or ())
                return updates
            # other types such as BaseModel: compare current values with the defaults to decide which fields were updated
            elif (t := type(input)) and get_cached_annotated_keys(t):
                return get_update_as_tuples(input, output_keys)
            else:
                msg = create_error_message(
                    message=f"Expected dict, got {input}",
                    error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE,
                )
                raise InvalidUpdateError(msg)

        # state updaters
        # each entry calls mapper(input) to obtain [(channel, value)]
        write_entries: tuple[ChannelWriteEntry | ChannelWriteTupleEntry, ...] = (
            ChannelWriteTupleEntry(
                # handles writes of channel values
                # when the schema has no annotated keys, everything goes into the default __root__ channel
                mapper=_get_root if output_keys == ["__root__"] else _get_updates
            ),
            ChannelWriteTupleEntry(
                # handles jumps to other nodes
                mapper=_control_branch,
                static=(
                    _control_static(node.ends)
                    if node is not None and node.ends is not None
                    else None
                ),
            ),
        )

        # add node and output channel
        if key == START:
            self.nodes[key] = PregelNode(
                tags=[TAG_HIDDEN],
                triggers=[START],  # triggered by the START channel
                channels=START,  # reads its value from the START channel
                # output channels
                writers=[ChannelWrite(write_entries)],
            )
        elif node is not None:
            input_schema = node.input_schema if node else self.builder.state_schema
            input_channels = list(self.builder.schemas[input_schema])
            is_single_input = len(input_channels) == 1 and "__root__" in input_channels
            if input_schema in self.schema_to_mapper:
                mapper = self.schema_to_mapper[input_schema]
            else:
                # build a mapper that instantiates input_schema
                mapper = _pick_mapper(input_channels, input_schema)
                self.schema_to_mapper[input_schema] = mapper

            # every channel that triggers a node is named "branch:to:{}"
            branch_channel = _CHANNEL_BRANCH_TO.format(key)
            self.channels[branch_channel] = (
                LastValueAfterFinish(Any)
                if node.defer
                else EphemeralValue(Any, guard=False)
            )
            self.nodes[key] = PregelNode(
                triggers=[branch_channel],
                # read state keys and managed values
                channels=("__root__" if is_single_input else input_channels),
                # coerce state dict to schema class (eg. pydantic model)
                mapper=mapper,
                # publish to state keys
                writers=[ChannelWrite(write_entries)],
                metadata=node.metadata,
                retry_policy=node.retry_policy,
                cache_policy=node.cache_policy,
                bound=node.runnable,  # type: ignore[arg-type]
            )
        else:
            raise RuntimeError


def _get_root(input: Any) -> Sequence[tuple[str, Any]] | None:
    if isinstance(input, Command):
        # this update is for the parent graph to handle, not the current graph, so return an empty tuple ()
        if input.graph == Command.PARENT:
            return ()
        return input._update_as_tuples()
    elif (
        isinstance(input, (list, tuple))
        and input
        and any(isinstance(i, Command) for i in input)
    ):
        updates: list[tuple[str, Any]] = []
        for i in input:
            if isinstance(i, Command):
                if i.graph == Command.PARENT:
                    continue
                updates.extend(i._update_as_tuples())
            else:
                updates.append(("__root__", i))
        return updates
    elif input is not None:
        return [("__root__", input)]


def _control_branch(value: Any) -> Sequence[tuple[str, Any]]:
    if isinstance(value, Send):
        return ((TASKS, value),)
    commands: list[Command] = []
    if isinstance(value, Command):
        commands.append(value)
    elif isinstance(value, (list, tuple)):
        for cmd in value:
            if isinstance(cmd, Command):
                commands.append(cmd)
    rtn: list[tuple[str, Any]] = []
    for command in commands:
        if command.graph == Command.PARENT:
            raise ParentCommand(command)

        goto_targets = (
            [command.goto] if isinstance(command.goto, (Send, str)) else command.goto
        )

        for go in goto_targets:
            # a Send jumps to a fixed node
            if isinstance(go, Send):
                rtn.append((TASKS, go))
            # otherwise jump by writing to the target node's branch channel
            elif isinstance(go, str) and go != END:
                # END is a special case, it's not actually a node in a practical sense
                # but rather a special terminal node that we don't need to branch to
                rtn.append((_CHANNEL_BRANCH_TO.format(go), None))
    return rtn
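
The goto handling in _control_branch can be reduced to a few lines. The following is a simplified sketch: ToySend, TOY_TASKS, and TOY_END are placeholders defined here for illustration and are not the library's Send, TASKS, or END objects.

```python
from dataclasses import dataclass
from typing import Any, Sequence, Union

TOY_TASKS = "tasks"   # placeholder for the channel that receives Send packets
TOY_END = "__end__"   # placeholder for the terminal marker


@dataclass(frozen=True)
class ToySend:
    """Hypothetical stand-in for Send: a target node plus its own input."""

    node: str
    arg: Any


def toy_control_branch(
    goto: Union[str, ToySend, Sequence[Union[str, ToySend]]],
) -> list[tuple[str, Any]]:
    # A single target is normalized to a list, as in the excerpt above.
    targets = [goto] if isinstance(goto, (str, ToySend)) else list(goto)
    writes: list[tuple[str, Any]] = []
    for go in targets:
        if isinstance(go, ToySend):
            writes.append((TOY_TASKS, go))
        elif go != TOY_END:
            # A plain node name becomes a value-less write to its branch channel.
            writes.append((f"branch:to:{go}", None))
    return writes


writes = toy_control_branch(["a", ToySend("b", {"n": 1}), TOY_END])
single = toy_control_branch("a")
```

The terminal marker produces no write at all, since there is no node to trigger.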

1.2 edges

Edges are simple to handle: an edge is translated into a writer on the start node:

  1. A write to end's branch_channel is appended to start.writers
  2. As noted earlier, every node gets a branch_channel and is triggered by it
    def attach_edge(self, starts: str | Sequence[str], end: str) -> None:
        if isinstance(starts, str):
            # subscribe to start channel
            if end != END:
                # an edge means that one node triggers the execution of another
                self.nodes[starts].writers.append(
                    ChannelWrite(
                        (ChannelWriteEntry(_CHANNEL_BRANCH_TO.format(end), None),)
                    )
                )
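
As a rough illustration of this translation, the sketch below turns a set of edges into per-node lists of (channel, value) writes. compile_edges is a hypothetical helper, and the "__end__" marker is an assumption of this sketch; the real attach_edge appends ChannelWrite objects to existing PregelNode instances instead.

```python
from collections import defaultdict


def compile_edges(
    edges: set[tuple[str, str]], end_marker: str = "__end__"
) -> dict[str, list[tuple[str, None]]]:
    """Map each start node to the branch-channel writes its edges imply."""
    writers: dict[str, list[tuple[str, None]]] = defaultdict(list)
    for start, end in sorted(edges):  # sorted only to make the output stable
        if end != end_marker:
            # The write carries no value; it exists only to trigger `end`.
            writers[start].append((f"branch:to:{end}", None))
    return dict(writers)


writers = compile_edges({("a", "b"), ("a", "c"), ("c", "__end__")})
```

An edge into the terminal marker adds no writer, so node "c" does not appear in the result.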

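attach_edge also has to handle a special case: several start nodes triggering one end node. The end node must only be triggered after every start node has written. This is handled with a special channel.

  1. A NamedBarrierValue channel is created, and the end node is triggered by it
  2. Each start node gets a writer that updates the NamedBarrierValue
  3. When a start node writes, the NamedBarrierValue records that start node as seen
  4. Once every start node has been seen, the NamedBarrierValue becomes available
  5. After applying the updated channels, apply_writes checks once more which channels are available. This is how an available NamedBarrierValue is detected and appended to updated_channels, which in turn triggers the end node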
    def attach_edge(self, starts: str | Sequence[str], end: str) -> None:
        if isinstance(starts, str):
            ...  # single-start case, shown in the previous excerpt
        elif end != END:
            channel_name = f"join:{'+'.join(starts)}:{end}"
            # register channel
            if self.builder.nodes[end].defer:
                # the channel only yields a value once every start has written
                self.channels[channel_name] = NamedBarrierValueAfterFinish(
                    str, set(starts)
                )
            else:
                self.channels[channel_name] = NamedBarrierValue(str, set(starts))
            # subscribe to channel
            # the end node is triggered by the join channel
            self.nodes[end].triggers.append(channel_name)
            # publish to channel
            for start in starts:
                # each start writes its own name to the join channel, marking itself as seen
                # once all starts are seen, the join channel yields a value and triggers end
                self.nodes[start].writers.append(
                    ChannelWrite((ChannelWriteEntry(channel_name, start),))
                )

def apply_writes():
    # Channels that weren't updated in this step are notified of a new step
    if bump_step:
        for chan in channels:
            if channels[chan].is_available() and chan not in updated_channels:
                if channels[chan].update(EMPTY_SEQ) and next_version is not None:
                    checkpoint["channel_versions"][chan] = next_version
                    # unavailable channels can't trigger tasks, so don't add them
                    if channels[chan].is_available():
                        updated_channels.add(chan)
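
The barrier behavior described in the steps above can be sketched as a small class. ToyBarrier is a hypothetical simplification written for this post; it only assumes that the channel records which named writers it has seen and becomes available once all of them have written.

```python
class ToyBarrier:
    """Toy join channel: available only after every named start has written."""

    def __init__(self, names: set[str]) -> None:
        self.names = set(names)
        self.seen: set[str] = set()

    def update(self, values: list[str]) -> bool:
        changed = False
        for value in values:
            if value not in self.names:
                raise ValueError(f"{value!r} is not one of {sorted(self.names)}")
            if value not in self.seen:
                self.seen.add(value)
                changed = True
        return changed

    def is_available(self) -> bool:
        return self.seen == self.names

    def consume(self) -> bool:
        # Reset after the end node has been triggered, ready for the next round.
        if self.seen == self.names:
            self.seen = set()
            return True
        return False


barrier = ToyBarrier({"a", "b"})
barrier.update(["a"])
after_one = barrier.is_available()
barrier.update(["b"])
after_both = barrier.is_available()
barrier.consume()
after_consume = barrier.is_available()
```

Each start writes its own name, which is why the writer in the excerpt above uses ChannelWriteEntry(channel_name, start) rather than a None value.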

1.3 Branch

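A Branch handles a conditional edge. It takes a routing function, path, whose result names the node or nodes to jump to. A conditional edge fixes only the start node, and it cannot have multiple start nodes.

Branch handling is more involved. Let's first look at the two functions that attach_branch provides:

  1. get_writes:
    • As mentioned earlier, every node has a branch_channel that triggers it, so a trigger for a node must be formatted with _CHANNEL_BRANCH_TO
    • get_writes normalizes the channel name and converts the write into a ChannelWriteEntry
    • The default value of a ChannelWriteEntry is PASSTHROUGH; get_writes sets it to None
  2. reader:
    • The argument to path may need values from several channels
    • Which channels are read is defined by path's input_schema
    • reader reads those channels and builds a value of input_schema
    • Internally, reader calls ChannelRead.do_read, which fetches a read function from config[CONF][CONFIG_KEY_READ]; that read function is exactly the one prepare_single_task configures when it builds the task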
    def attach_branch(
        self, start: str, name: str, branch: BranchSpec, *, with_reader: bool = True
    ) -> None:
        def get_writes(
            packets: Sequence[str | Send], static: bool = False
        ) -> Sequence[ChannelWriteEntry | Send]:
            writes = [
                (
                    ChannelWriteEntry(
                        # normalize the channel name and set value=None
                        p if p == END else _CHANNEL_BRANCH_TO.format(p), None
                    )
                    if not isinstance(p, Send)
                    else p
                )
                for p in packets
                if (True if static else p != END)
            ]
            if not writes:
                return []
            return writes

        if with_reader:
            # get schema
            # branch.input_schema is the schema inferred from branch.path
            schema = branch.input_schema or (
                self.builder.nodes[start].input_schema
                if start in self.builder.nodes  # can start be missing here? presumably only when start is START
                else self.builder.state_schema
            )
            # channels to read from
            channels = list(self.builder.schemas[schema])
            # get mapper
            if schema in self.schema_to_mapper:
                mapper = self.schema_to_mapper[schema]
            else:
                # bind the input to the schema
                mapper = _pick_mapper(channels, schema)
                self.schema_to_mapper[schema] = mapper
            # create reader
            reader: Callable[[RunnableConfig], Any] | None = partial(
                # do_read obtains its read function via read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
                ChannelRead.do_read,
                # channels to read
                select=channels[0] if channels == ["__root__"] else channels,
                fresh=True,
                # coerce state dict to schema class (eg. pydantic model)
                mapper=mapper,
            )
        else:
            reader = None

        # attach branch publisher
        # a conditional branch becomes a writer on the start node: on write, the branch function dynamically yields [(channel, value)]
        self.nodes[start].writers.append(branch.run(get_writes, reader))


class ChannelRead(RunnableCallable):
    @staticmethod
    def do_read(
        config: RunnableConfig,
        *,
        select: str | list[str],
        fresh: bool = False,
        mapper: Callable[[Any], Any] | None = None,
    ) -> Any:
        try:
            read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
        except KeyError:
            raise RuntimeError(
                "Not configured with a read function"
                "Make sure to call in the context of a Pregel process"
            )
        if mapper:
            return mapper(read(select, fresh))
        else:
            return read(select, fresh)


def prepare_single_task():
            return PregelExecutableTask(
                patch_config(
                    
                    configurable={
                        CONFIG_KEY_SEND: writes.extend,
                        CONFIG_KEY_READ: partial(
                            local_read,
                            scratchpad,
                            channels,
                            managed,
                            PregelTaskWrites(task_path, name, writes, triggers),
                        ),
                        
                    },
                ),
            )
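
The way reader and the configured read function meet can be sketched with functools.partial alone. Everything below is a toy: toy_do_read and toy_local_read are hypothetical simplifications, and the CONF and CONFIG_KEY_READ values are placeholder strings chosen for this sketch, not the library's actual constants.

```python
from functools import partial
from typing import Any, Callable, Optional, Union

CONF = "configurable"         # placeholder key, assumed for this sketch
CONFIG_KEY_READ = "toy_read"  # placeholder key, assumed for this sketch


def toy_do_read(
    config: dict,
    *,
    select: Union[str, list[str]],
    mapper: Optional[Callable[[Any], Any]] = None,
) -> Any:
    try:
        read = config[CONF][CONFIG_KEY_READ]
    except KeyError:
        raise RuntimeError("Not configured with a read function") from None
    value = read(select)
    return mapper(value) if mapper else value


def toy_local_read(channels: dict, select: Union[str, list[str]]) -> Any:
    # A single name returns the bare value; a list returns a dict of values.
    if isinstance(select, str):
        return channels[select]
    return {k: channels[k] for k in select if k in channels}


# Compile time: the reader is bound to the channels it should select.
reader = partial(toy_do_read, select=["count", "name"])

# Task preparation time: the read function is bound to the live channels.
channels = {"count": 2, "name": "x", "other": 0}
config = {CONF: {CONFIG_KEY_READ: partial(toy_local_read, channels)}}

state = reader(config)
bumped = partial(toy_do_read, select="count", mapper=lambda v: v + 1)(config)
```

The two partial calls happen at different times, which is why the reader only needs a config at call time: what to read is fixed when the graph is compiled, while where to read from is supplied when the task is prepared.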

Now we can come back to branch.run(get_writes, reader):

  1. A branch has three fields
    • path: branch_func, the routing function
    • ends: maps the routing function's outputs to node names
    • input_schema: the input schema of the routing function
  2. run returns a RunnableCallable
    • The entry point is branch._route
    • writer=get_writes
    • reader=reader
    • writer and reader are both passed as kwargs and end up in branch._route
    • ChannelWrite.register_writer can be ignored: it only marks the returned Runnable as a ChannelWrite and attaches static
  3. branch._route
    • The return value of branch.run is appended to node.writers, so it is invoked like a ChannelWrite, i.e. invoke(input=bound_return)
    • value = reader(config) fetches the channel values that the path function needs
    • result = self.path.invoke(value, config) returns the nodes to route to, or Send objects
  4. branch._finish
    • With writer=get_writes, entries = writer(destinations, False) normalizes the branch_channel names and converts each str destination into a ChannelWriteEntry
    • if need_passthrough: return ChannelWrite(entries): it is not yet clear to me when this branch is taken.
    • ChannelWrite.do_write(config, entries): get_writes sets value to None, so this is the branch that runs; it writes to the branch_channel, which in turn triggers the corresponding node
class BranchSpec(NamedTuple):
    path: Runnable[Any, Hashable | list[Hashable]]
    ends: dict[Hashable, str] | None
    input_schema: type[Any] | None = None


    def run(
        self,
        writer: _Writer,
        reader: Callable[[RunnableConfig], Any] | None = None,
    ) -> RunnableCallable:
        return ChannelWrite.register_writer(
            RunnableCallable(
                func=self._route,
                afunc=self._aroute,
                writer=writer,
                reader=reader,
                name=None,
                trace=False,
            ),
            list(
                zip_longest(
                    writer([e for e in self.ends.values()], True),
                    [str(la) for la, e in self.ends.items()],
                )
            )
            if self.ends
            else None,
        )


    def _route(
        self,
        input: Any,
        config: RunnableConfig,
        *,
        reader: Callable[[RunnableConfig], Any] | None,
        writer: _Writer,
    ) -> Runnable:
        if reader:
            value = reader(config)
            # passthrough additional keys from node to branch
            # only doable when using dict states
            if (
                isinstance(value, dict)
                and isinstance(input, dict)
                # if input_schema is not None, the channel updates returned in input should already be present in value
                and self.input_schema is None
            ):
                value = {**input, **value}
        else:
            value = input
        result = self.path.invoke(value, config)
        return self._finish(writer, input, result, config)


    def _finish(
        self,
        writer: _Writer,
        input: Any,
        result: Any,
        config: RunnableConfig,
    ) -> Runnable | Any:
        if not isinstance(result, (list, tuple)):
            result = [result]
        if self.ends:
            # results are looked up in ends, so a branch's ends is a hard constraint
            destinations: Sequence[Send | str] = [
                r if isinstance(r, Send) else self.ends[r] for r in result
            ]
        else:
            destinations = cast(Sequence[Union[Send, str]], result)
        if any(dest is None or dest == START for dest in destinations):
            raise ValueError("Branch did not return a valid destination")
        if any(p.node == END for p in destinations if isinstance(p, Send)):
            raise InvalidUpdateError("Cannot send a packet to the END node")
        entries = writer(destinations, False)
        if not entries:
            return input
        else:
            need_passthrough = False
            for e in entries:
                if isinstance(e, ChannelWriteEntry):
                    if e.value is PASSTHROUGH:
                        need_passthrough = True
                        break
            if need_passthrough:
                return ChannelWrite(entries)
            else:
                ChannelWrite.do_write(config, entries)
                return input
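
The core of _finish, stripped of Send handling and validation, can be sketched as follows. toy_finish is a hypothetical helper, and "__end__" is used here as an assumed terminal marker; it returns the writes instead of performing them.

```python
from typing import Any, Optional


def toy_finish(
    result: Any, ends: Optional[dict], end_marker: str = "__end__"
) -> list[tuple[str, None]]:
    """Turn a routing result into value-less writes to branch channels."""
    if not isinstance(result, (list, tuple)):
        result = [result]
    # With ends, every result must be a key of the mapping (KeyError otherwise).
    destinations = [ends[r] for r in result] if ends else list(result)
    return [(f"branch:to:{d}", None) for d in destinations if d != end_marker]


ends = {True: "tools", False: "__end__"}
to_tools = toy_finish(True, ends)
to_end = toy_finish(False, ends)
fan_out = toy_finish(["a", "b"], None)
```

A routing result that is not a key of ends raises KeyError in this sketch, which is the "hard constraint" noted in the excerpt's comment.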

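1.4 Return values of different node kinds

To wrap up, here is what each kind of node can return:

  1. node: may return dict|Command; a dict represents updates to the corresponding channels
  2. tool: may return Any|Command; to update a channel or jump to a node from inside a tool, a Command must be used
  3. branch: may return str|Send; a str is converted to ChannelWriteEntry(channel, None), i.e. an update to a channel, and a Send represents a node jump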
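
The tool case is the least obvious of the three, so here is a simplified sketch of the wrapping and merging described earlier. All names are hypothetical stand-ins (ToyOutputMixin, ToyToolMessage, ToyCommand, toy_format_output, toy_combine); the real _format_output and _combine_tool_outputs take more arguments and handle more cases.

```python
from dataclasses import dataclass, field
from typing import Any


class ToyOutputMixin:
    """Marker base class: instances are passed through unwrapped."""


@dataclass
class ToyToolMessage(ToyOutputMixin):
    content: str


@dataclass
class ToyCommand(ToyOutputMixin):
    update: dict = field(default_factory=dict)


def toy_format_output(content: Any) -> ToyOutputMixin:
    # Marker instances are returned as-is; anything else becomes a message.
    if isinstance(content, ToyOutputMixin):
        return content
    return ToyToolMessage(str(content))


def toy_combine(outputs: list, messages_key: str = "messages") -> list:
    # Messages are grouped under one key; commands are kept as separate items.
    messages = [o for o in outputs if isinstance(o, ToyToolMessage)]
    commands = [o for o in outputs if isinstance(o, ToyCommand)]
    combined: list = [{messages_key: messages}] if messages else []
    return combined + commands


raw_results = [42, ToyCommand(update={"count": 1})]
combined = toy_combine([toy_format_output(r) for r in raw_results])
```

The plain value 42 can only ever reach the messages channel, while the command keeps its own update; this is why a tool needs a Command to touch any other channel.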

2. How the Functional API maps to Pregel