1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
|
import redis
import json
import time
import argparse
import logging
# Root-logger setup performed when this module is imported: emit records at
# INFO level and above (basicConfig is a no-op if the root logger already has
# handlers).
_DEFAULT_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_DEFAULT_LOG_LEVEL)
class NodeMessage:
    """One barrier-arrival entry read back from the Redis stream.

    Values are stored exactly as received; no type conversion happens here,
    so callers must convert ``stamp`` themselves if they need a number.
    """

    def __init__(self, msg_id, node, stamp):
        # Stream entry ID of the message (needed later to XACK it).
        self.msg_id = msg_id
        # Name of the node that announced its arrival.
        self.node = node
        # Arrival timestamp as stored in the stream entry (likely a string when
        # the client uses decode_responses=True -- convert before comparing).
        self.stamp = stamp

    def __repr__(self):
        # Readable form for the script's print-based debugging output.
        return (
            f"{type(self).__name__}("
            f"msg_id={self.msg_id!r}, node={self.node!r}, stamp={self.stamp!r})"
        )
class RedisBarrier:
    """A simple distributed barrier built on top of a Redis stream.

    Each participating node appends an arrival message to the stream
    ``barrier:<name>:chan`` and then reads that stream through its own consumer
    group (named after the node) until it has seen an in-window arrival from
    every other node, or until the timeout expires.
    """

    def __init__(self, name, nodes, timeout=10, *, host="192.168.2.41", port=6379,
                 db=1, password="infini_rag_flow"):
        """
        :param name: barrier name; used to derive the stream key
        :param nodes: names of all nodes that must participate
        :param timeout: timeout of one barrier round, in seconds
        :param host: Redis host (keyword-only)
        :param port: Redis port (keyword-only)
        :param db: Redis database index (keyword-only)
        :param password: Redis password (keyword-only)
        """
        self.REDIS = None
        self.name = name
        self.chan = f"barrier:{name}:chan"
        self.nodes = nodes
        self.timeout = timeout
        # NOTE(review): the connection defaults (including the password) are
        # the values that used to be hard-coded in __open__, kept for backward
        # compatibility; prefer passing real values in from configuration.
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.__open__()

    def __open__(self):
        """(Re)create the Redis client from the stored connection settings."""
        self.REDIS = redis.StrictRedis(
            host=self.host,
            port=self.port,
            db=self.db,
            password=self.password,
            decode_responses=True,
        )

    def send_message(self, queue_name, node_name, timestamp=None):
        """Append a node-arrival message to a Redis stream (XADD).

        :param queue_name: stream key
        :param node_name: name of the arriving node
        :param timestamp: arrival time in seconds; defaults to ``int(time.time())``
        :return: the new entry ID, or ``None`` if the XADD failed
        """
        if timestamp is None:
            timestamp = int(time.time())
        try:
            msg_id = self.REDIS.xadd(
                name=queue_name,
                fields={
                    "node": node_name,
                    "stamp": timestamp
                }
            )
            print(f"发送消息: {node_name} -> {queue_name}, 消息ID: {msg_id}")
            return msg_id
        except Exception as e:
            print(f"发送消息失败: {e}")
            return None

    def ack_message(self, queue_name, group_name, msg_id):
        """Acknowledge a processed message (XACK) so it leaves the pending list.

        :param queue_name: stream key
        :param group_name: consumer group name
        :param msg_id: entry ID to acknowledge
        :return: number of messages acknowledged (0 on failure)
        """
        try:
            count = self.REDIS.xack(queue_name, group_name, msg_id)
            print(f"ACK消息: {msg_id} -> {queue_name}:{group_name}")
            return count
        except Exception as e:
            print(f"ACK消息失败: {e}")
            return 0

    @staticmethod
    def _block_ms(timeout):
        """Convert a timeout in seconds to an XREADGROUP ``BLOCK`` value in ms.

        Redis treats ``BLOCK 0`` as "block forever", so a sub-millisecond (or
        non-positive) timeout is clamped up to 1 ms instead of truncating to 0.
        """
        return max(1, int(timeout * 1000))

    def queue_consumer(self, queue_name, group_name, consumer_name, timeout, msg_id=b">"):
        """Read at most one message from ``queue_name`` via group ``group_name``.

        The group is created on first use starting at ID ``0`` (so earlier
        stream entries are delivered too), creating the stream if it is missing.
        See https://redis.io/docs/latest/commands/xreadgroup/

        :param queue_name: stream key
        :param group_name: consumer group name
        :param consumer_name: consumer name within the group
        :param timeout: maximum time to block, in seconds
        :param msg_id: ID to read from; ``">"`` means never-delivered messages
        :return: a ``NodeMessage``, or ``None`` when nothing arrived in time or
                 an error occurred (in which case the client is reconnected)
        """
        try:
            try:
                group_info = self.REDIS.xinfo_groups(queue_name)
                if not any(gi["name"] == group_name for gi in group_info):
                    self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
            except redis.exceptions.ResponseError as e:
                err_text = str(e).lower()
                if "no such key" in err_text:
                    # The stream does not exist yet: create it along with the group.
                    self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
                elif "busygroup" in err_text:
                    logging.warning("Group already exists, continue.")
                else:
                    raise
            messages = self.REDIS.xreadgroup(
                groupname=group_name,
                consumername=consumer_name,
                count=1,
                # Never let a tiny remaining timeout become BLOCK 0 (= forever).
                block=self._block_ms(timeout),
                streams={queue_name: msg_id},
            )
            if not messages:
                return None
            _stream, element_list = messages[0]
            if not element_list:
                return None
            entry_id, payload = element_list[0]
            return NodeMessage(entry_id, payload["node"], payload["stamp"])
        except Exception as e:
            if str(e) != 'no such key':
                print(
                    "RedisDB.queue_consumer "
                    + str(queue_name)
                    + " got exception: "
                    + str(e)
                )
                # Assume the connection may be broken and rebuild the client.
                self.__open__()
            return None

    def cleanup_stream(self, queue_name, max_len=1000):
        """Trim a stream to roughly ``max_len`` entries (approximate XTRIM MAXLEN).

        :param queue_name: stream key
        :param max_len: approximate number of entries to keep
        :return: number of entries removed (0 on failure)
        """
        try:
            trimmed = self.REDIS.xtrim(queue_name, maxlen=max_len, approximate=True)
            if trimmed > 0:
                print(f"清理stream {queue_name}: 删除了 {trimmed} 条消息")
            return trimmed
        except Exception as e:
            print(f"清理stream失败: {e}")
            return 0

    def get_stream_info(self, queue_name):
        """Print and return XINFO STREAM details for a stream.

        :param queue_name: stream key
        :return: the info dict, or ``{}`` on failure
        """
        try:
            info = self.REDIS.xinfo_stream(queue_name)
            print(f"Stream {queue_name} 信息:")
            print(f" 长度: {info.get('length', 0)}")
            print(f" 第一条消息: {info.get('first-entry', 'N/A')}")
            print(f" 最后一条消息: {info.get('last-entry', 'N/A')}")
            return info
        except Exception as e:
            print(f"获取stream信息失败: {e}")
            return {}

    def wait(self, node_name):
        """Wait on the barrier until all nodes have arrived or the timeout expires.

        Announces ``node_name`` on ``self.chan``, then consumes the stream with
        a consumer group and consumer both named ``node_name``. A node counts as
        arrived when a message carrying its name has a ``stamp`` inside
        ``[now - timeout, now + timeout]``.

        :param node_name: name of the calling node
        :return: ``{"status": "OK", "lose": []}`` when every node arrived,
                 otherwise ``{"status": "FAIL", "lose": [missing node names]}``
        """
        stamp = int(time.time())
        expected = set(self.nodes)
        # Announce our own arrival on the stream.
        self.send_message(self.chan, node_name, stamp)
        # Our own arrival is known locally; no need to wait for our own message.
        expected.discard(node_name)
        print(f"节点 {node_name} 已到达 (自己)")
        if not expected:
            # Nobody else to wait for (e.g. a single-node barrier).
            return {"status": "OK", "lose": []}

        end = time.time() + self.timeout
        start = time.time() - self.timeout
        while True:
            remain = end - time.time()
            if remain <= 0:
                break
            # Block for at most the remaining time waiting for the next message.
            m = self.queue_consumer(queue_name=self.chan, group_name=node_name,
                                    consumer_name=node_name, timeout=remain)
            if m is None:
                # queue_consumer returns None both after a timed-out block and
                # immediately after an error; pause briefly so a failing
                # connection does not turn this into a busy loop.
                time.sleep(min(0.1, max(0.0, end - time.time())))
                continue
            try:
                recv_node_name = m.node
                node_stamp = int(m.stamp)  # stream values come back as strings
                print(f"收到消息: {recv_node_name} {node_stamp}")
                # Only accept arrivals stamped inside the window, to skip stale
                # messages. NOTE(review): an earlier round's arrival that is
                # still inside the window is accepted as well.
                if start <= node_stamp <= end:
                    expected.discard(recv_node_name)
                    print(f"节点 {recv_node_name} 已到达")
                else:
                    print(f"节点 {recv_node_name} 时间戳过期: {node_stamp} (范围: {start}-{end})")
            except (ValueError, KeyError) as e:
                print(f"解析消息失败: {e}")
            finally:
                # Ack every delivered message (valid, stale or malformed) so the
                # group's pending-entries list does not grow without bound.
                self.ack_message(self.chan, node_name, m.msg_id)
            if not expected:
                return {"status": "OK", "lose": []}
        # Timed out: report the nodes that never arrived.
        return {"status": "FAIL", "lose": list(expected)}
def main():
    """Command-line entry point.

    Depending on the flags, prints stream info (``--info``), trims the stream
    (``--cleanup``), or waits on the barrier as ``--node`` and reports the
    result and elapsed time.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--node", required=True, help="当前节点名称")
    parser.add_argument("--nodes", nargs="+", default=["nodeA", "nodeB", "nodeC"],
                        help="所有参与的节点列表")
    parser.add_argument("--timeout", type=int, default=10, help="barrier 超时时间(秒)")
    parser.add_argument("--info", action="store_true", help="显示stream信息")
    parser.add_argument("--cleanup", action="store_true", help="清理stream")
    # Barrier name selects the Redis stream; the default keeps the old behavior.
    parser.add_argument("--name", default="test", help="barrier name (selects the Redis stream)")
    args = parser.parse_args()

    barrier = RedisBarrier(args.name, args.nodes, args.timeout)

    # --info takes precedence over --cleanup; both skip waiting.
    if args.info:
        barrier.get_stream_info(barrier.chan)
        return
    if args.cleanup:
        barrier.cleanup_stream(barrier.chan)
        return

    print(f"[{args.node}] 等待 barrier...")
    print(f"[{args.node}] 参与节点: {args.nodes}")
    print(f"[{args.node}] 超时时间: {args.timeout}秒")
    # perf_counter is monotonic, so the elapsed time is immune to clock jumps.
    start_time = time.perf_counter()
    result = barrier.wait(args.node)
    elapsed = time.perf_counter() - start_time
    print(f"[{args.node}] 结果: {result}")
    print(f"[{args.node}] 耗时: {elapsed:.2f}秒")


if __name__ == "__main__":
    main()
|