上一节我们介绍了 pregel loop 的初始化、tick 函数。tick 函数其实也没有介绍完,并标记了我们要理解的其他方法,包括:
_emit
_put_checkpoint
put_writes
这一节我们来学习这些函数的实现。
1. _emit
1.1 _emit 执行逻辑
_emit
函数的作用是外部传递事件。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
def _emit(
self,
mode: StreamMode,
values: Callable[P, Iterator[Any]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
# stream 是 loop 初始化是传入的 `stream=StreamProtocol(stream.put, stream_modes),`
# stream = SyncQueue()
if self.stream is None:
return
debug_remap = mode in ("checkpoints", "tasks") and "debug" in self.stream.modes
if mode not in self.stream.modes and not debug_remap:
return
for v in values(*args, **kwargs):
if mode in self.stream.modes:
# 向 queue 发送事件
self.stream((self.checkpoint_ns, mode, v))
# "debug" mode is "checkpoints" or "tasks" with a wrapper dict
if debug_remap:
self.stream(
(
self.checkpoint_ns,
"debug",
{
"step": self.step - 1
if mode == "checkpoints"
else self.step,
"timestamp": datetime.now(timezone.utc).isoformat(),
"type": "checkpoint"
if mode == "checkpoints"
else "task_result"
if "result" in v
else "task",
"payload": v,
},
)
)
|
1.2 tick 中调用的 _emit
tick 调用 _emit
时传入了一个 map_debug_checkpoint 函数。
1
2
3
4
5
6
|
if self._checkpointer_put_after_previous is not None:
self._emit(
"checkpoints",
map_debug_checkpoint,
...
)
|
map_debug_checkpoint(…) 是 LangGraph 中用于构建调试用 checkpoint 数据结构 的函数,输出格式为 CheckpointPayload。
2. _put_checkpoint
_put_checkpoint(...)
是 LangGraph 中用于 保存执行中间状态(checkpoint) 的关键函数之一。它负责将当前图的执行状态、通道状态等通过后端存储器(checkpointer)保存下来。
_put_checkpoint 代码的注释一下,重点关注 self._put_checkpoint_fut = self.submit(
。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
# assign step and parents
# 判断是否为当前退出点的 checkpoint
# `exiting == True` 表示这是图执行准备结束时(`exit`)要保存的 checkpoint。
exiting = metadata is self.checkpoint_metadata
# 如果已经保存,直接退出
if exiting and self.checkpoint["id"] == self.checkpoint_id_saved:
# checkpoint already saved
return
# 如果是新 checkpoint,补充元数据(step、parents)
if not exiting:
metadata["step"] = self.step
metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {})
self.checkpoint_metadata = metadata
# do checkpoint?
# 是否要真正执行保存?
do_checkpoint = self._checkpointer_put_after_previous is not None and (
exiting or self.durability != "exit"
)
# create new checkpoint
# 构造新的 checkpoint 数据结构
self.checkpoint = create_checkpoint(
self.checkpoint,
self.channels if do_checkpoint else None,
self.step,
id=self.checkpoint["id"] if exiting else None,
)
# bail if no checkpointer
if do_checkpoint and self._checkpointer_put_after_previous is not None:
self.prev_checkpoint_config = (
self.checkpoint_config
if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
else None
)
self.checkpoint_config = {
**self.checkpoint_config,
CONF: {
**self.checkpoint_config[CONF],
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
},
}
channel_versions = self.checkpoint["channel_versions"].copy()
new_versions = get_new_channel_versions(
self.checkpoint_previous_versions, channel_versions
)
self.checkpoint_previous_versions = channel_versions
# save it, without blocking
# if there's a previous checkpoint save in progress, wait for it
# ensuring checkpointers receive checkpoints in order
# 异步保存
self._put_checkpoint_fut = self.submit(
self._checkpointer_put_after_previous,
# 如果是第一次执行,这个属性为 None
# 第二次执行,这个属性是第一次 submit 返回的 Future
# checkpoint 的保存会延时一个提交?
getattr(self, "_put_checkpoint_fut", None),
self.checkpoint_config,
copy_checkpoint(self.checkpoint),
self.checkpoint_metadata,
new_versions,
)
# 更新 `checkpoint_config` 中的 checkpoint id
self.checkpoint_config = {
**self.checkpoint_config,
CONF: {
**self.checkpoint_config[CONF],
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
}
if not exiting:
# increment step
self.step += 1
|
3. put_writes
put_writes
,主要功能是 将某个 task 的写入信息 writes
存储到 checkpoint_pending_writes
中,供下一个 tick(调度轮次)使用,同时调用 checkpointer 实现持久化。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
class PregelLoop:
def put_writes(self, task_id: str, writes: WritesT) -> None:
"""Put writes for a task, to be read by the next tick."""
# WritesT = Sequence[tuple[str, Any]]
if not writes:
return
# deduplicate writes to special channels, last write wins
# channel 都是 `WRITES_IDX_MAP` 中的特殊通道(常用于状态同步或调度控制)
# 去重(以 channel 为键,保留最后一次写入
if all(w[0] in WRITES_IDX_MAP for w in writes):
writes = list({w[0]: w for w in writes}.values())
if task_id == NULL_TASK_ID:
# writes for the null task are accumulated
# 过滤出非 special channel 写入
self.checkpoint_pending_writes = [
w
# w 保存的是 (task_id, channel, value)
for w in self.checkpoint_pending_writes
if w[0] != task_id or w[1] not in WRITES_IDX_MAP
]
writes_to_save: WritesT = [
w[1:] for w in self.checkpoint_pending_writes if w[0] == task_id
] + list(writes)
else:
# remove existing writes for this task
self.checkpoint_pending_writes = [
w for w in self.checkpoint_pending_writes if w[0] != task_id
]
writes_to_save = writes
# save writes
self.checkpoint_pending_writes.extend((task_id, c, v) for c, v in writes)
# checkpointer_put_writes = BaseCheckpointSaver.put_writes 方法
if self.durability != "exit" and self.checkpointer_put_writes is not None:
config = patch_configurable(
self.checkpoint_config,
{
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
)
if self.checkpointer_put_writes_accepts_task_path:
if hasattr(self, "tasks"):
task = self.tasks.get(task_id)
else:
task = None
self.submit(
self.checkpointer_put_writes,
config,
writes_to_save,
task_id,
task_path_str(task.path) if task else "",
)
else:
self.submit(
self.checkpointer_put_writes,
config,
writes_to_save,
task_id,
)
# output writes
if hasattr(self, "tasks"):
self.output_writes(task_id, writes)
|
4. output_writes
output_writes
用于 对外发出任务写入结果的机制,作用是:
根据 task 的写入结果 writes
,向外部系统或调试工具发送 update(更新)、interrupt(中断)或 task 结果信息。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
|
def output_writes(
self, task_id: str, writes: WritesT, *, cached: bool = False
) -> None:
if task := self.tasks.get(task_id):
# 过滤隐藏任务
if task.config is not None and TAG_HIDDEN in task.config.get(
"tags", EMPTY_SEQ
):
return
# 处理终端
if writes[0][0] == INTERRUPT:
# in loop.py we append a bool to the PUSH task paths to indicate
# whether or not a call was present. If so,
# we don't emit the interrupt as it'll be emitted by the parent
# 说明父图已经处理了中断,当前图不需要再处理
if task.path[0] == PUSH and task.path[-1] is True:
return
interrupts = [
{
INTERRUPT: tuple(
v
for w in writes
if w[0] == INTERRUPT
for v in (w[1] if isinstance(w[1], Sequence) else (w[1],))
)
}
]
# 生成中断事件并发送
self._emit("updates", lambda: iter(interrupts))
elif writes[0][0] != ERROR:
# 普通写入
self._emit(
"updates",
# 通过 `map_output_updates` 映射成 update 格式
map_output_updates,
self.output_keys,
[(task, writes)],
cached,
)
if not cached:
# 如果不是来自缓存(即是新执行产生的 writes)
self._emit(
"tasks",
# 使用 `map_debug_task_results` 转换为调试格式
map_debug_task_results,
(task, writes),
self.stream_keys,
)
|