1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
# LLM vendors ("factories") and the capabilities their models provide.
class LLMFactoriesService(CommonService):
    # Peewee model backing this service; CommonService CRUD operates on it.
    model = LLMFactories
# Which vendors' models each tenant has registered (per-tenant credentials).
class TenantLLMService(CommonService):
    # Peewee model backing this service; primary key is
    # ("tenant_id", "llm_factory", "llm_name").
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        """
        Resolve the configuration (api_key, llm_factory, api_base, ...) of the
        model a tenant should use for the given model type.

        Args:
            tenant_id: Tenant whose configuration is looked up.
            llm_type: Model category, one of the LLMType values
                (chat / embedding / speech2text / image2text / rerank / tts).
            llm_name: Optional explicit model name (e.g.
                "BAAI/bge-large-zh-v1.5@BAAI"); overrides the tenant default
                for every type except speech2text.

        Returns:
            dict with at least "llm_factory", "api_key", "llm_name",
            "api_base"; "is_tools" is added when the model is known to
            LLMService.

        Raises:
            LookupError: tenant not found, model type not set, or the model
                is not authorized for this tenant.
            AssertionError: llm_type is not a recognized LLMType.
        """
        from api.db.services.llm_service import LLMService

        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        # Pick the model name: the explicit llm_name wins for every type
        # except speech2text, which always uses the tenant default.
        # NOTE(review): some branches compare against `.value` and others
        # against the enum member itself — these are equivalent only if
        # LLMType is a str-based enum; kept as-is to avoid behavior change.
        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            # A bare `assert False` is stripped under `python -O`, which
            # would leave mdlnm unbound and crash later with NameError.
            # Raise explicitly so the validation always fires.
            raise AssertionError("LLM type error")

        # Look up the tenant's stored credentials. First try the full
        # "name@factory" string, then split it and retry with the bare model
        # name in case the stored row carries no factory suffix.
        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:
            model_config = cls.get_api_key(tenant_id, mdlnm)

        if model_config:
            model_config = model_config.to_dict()
            # Enrich with tool-calling support from LLMService
            # (primary key: ("fid", "llm_name")).
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # factory id may not match; retry without it
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools

        if not model_config:
            # Fallback 1: embedding/rerank models served by factories that
            # need no api_key.
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                # Fallback 2: the special "flag-embedding" name maps to
                # Tongyi-Qianwen.
                if mdlnm == "flag-embedding":
                    # NOTE(review): this uses the caller-supplied llm_name
                    # (possibly None) rather than mdlnm — confirm intended.
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        # No model of this type configured for the tenant.
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    # A model name exists but the tenant has no credentials
                    # for it.
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        """
        Instantiate a client object for the tenant's model of the given type.

        Args:
            tenant_id: Tenant whose model configuration is used.
            llm_type: Model category (LLMType value).
            llm_name: Optional explicit model name; see get_model_config.
            lang: Language hint forwarded to image2text / speech2text models.
            **kwargs: Extra keyword arguments forwarded to chat / image2text
                constructors; "provider" is injected from the resolved
                factory name.

        Returns:
            An instance from the matching registry (EmbeddingModel,
            RerankModel, CvModel, ChatModel, Seq2txtModel, TTSModel), or
            None when the resolved factory is absent from that registry.
        """
        # e.g. {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": ...,
        #       "api_base": "", "is_tools": 1}
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        kwargs.update({"provider": model_config["llm_factory"]})
        # Dispatch on the model type, indexing the global factory->class maps
        # by vendor name. Each branch returns None if the factory is unknown.
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
            )
|