wework_api/ai_service.py
2025-02-21 08:04:44 +08:00

452 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ai_service.py 优化版本
import logging
import time
import uuid
from functools import lru_cache
from typing import Dict, Optional

import requests

from config import OLLAMA_MODEL, OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL  # 确保config.py中有这些配置
# Logging setup: give the root logger a timestamped INFO-level format (a no-op
# if the root logger already has handlers) and create this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
class AIService:
    """Common interface shared by every AI backend in this module.

    Subclasses override :meth:`generate_response`; the base version only
    signals that no override was provided.
    """

    def generate_response(self, prompt: str) -> str:
        """Return the reply text for *prompt*.

        :param prompt: text entered by the user
        :return: the generated reply
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class OllamaService(AIService):
    """Ollama local-model backend using the non-streaming generate endpoint.

    Successful replies are memoised per instance in a small LRU keyed by the
    prompt. Failure replies are returned but never cached. The previous
    ``@lru_cache`` on the method also memoised error strings, so one outage
    kept answering that prompt with an error. It also kept every instance
    alive for the lifetime of the cache.
    """

    def __init__(
        self,
        endpoint: str = "http://localhost:11434/api/generate",
        model: str = OLLAMA_MODEL,
        timeout: int = 10,
        cache_size: int = 100,
    ):
        """
        :param endpoint: URL the generate request is POSTed to
        :param model: model name sent with every request
        :param timeout: request timeout in seconds
        :param cache_size: max successful replies kept per instance; <= 0 disables caching
        """
        self.endpoint = endpoint
        self.default_model = model
        self.timeout = timeout
        self.cache_size = cache_size
        # prompt -> reply; dict insertion order doubles as LRU recency order.
        self._cache: Dict[str, str] = {}

    def generate_response(self, prompt: str) -> str:
        """
        Generate a reply for *prompt*, serving repeated prompts from the cache.

        :param prompt: user input text
        :return: the model's reply, a placeholder when the answer has no
            'response' field, or a user-facing error message on failure
        """
        if prompt in self._cache:
            # Re-insert the hit so it becomes the most recently used entry.
            reply = self._cache.pop(prompt)
            self._cache[prompt] = reply
            return reply
        try:
            response = requests.post(
                self.endpoint,
                json={
                    'model': self.default_model,
                    'prompt': prompt,
                    'stream': False
                },
                timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()
            reply = result.get('response')
            if reply is None:
                # Degenerate answer: return the placeholder but do not cache it.
                return '收到您的消息'
        except requests.exceptions.ConnectionError:
            logger.error("无法连接Ollama服务请检查服务状态")
            return "本地模型服务未启动"
        except requests.exceptions.Timeout:
            logger.warning("Ollama请求超时")
            return "响应超时,请简化问题"
        except Exception as e:
            logger.error("Ollama处理异常: %s", e, exc_info=True)
            return "本地模型服务异常"
        self._remember(prompt, reply)
        return reply

    def _remember(self, prompt: str, reply: str) -> None:
        """Cache a successful reply, evicting least-recently-used entries past cache_size."""
        if self.cache_size <= 0:
            return
        self._cache[prompt] = reply
        while len(self._cache) > self.cache_size:
            del self._cache[next(iter(self._cache))]
class DifyService(AIService):
"""Dify API客户端封装"""
def __init__(
self,
api_key: str,
base_url: str = "http://localhost/v1",
timeout: int = 100,
default_user: str = "system"
):
"""
:param api_key: 应用API密钥
:param base_url: API基础地址 (默认: http://localhost/v1)
:param timeout: 请求超时时间 (秒)
:param default_user: 默认用户标识
"""
self._validate_config(api_key, base_url)
self.api_key = api_key
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self.default_user = default_user
self.logger = logging.getLogger(self.__class__.__name__)
self.session = requests.Session()
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
def _validate_config(self, api_key: str, base_url: str):
"""配置校验"""
if not api_key.startswith('app-'):
raise ValueError("Invalid API key format")
if not base_url.startswith(('http://', 'https://')):
raise ValueError("Invalid base URL protocol")
@lru_cache(maxsize=100)
def generate_response(
self,
query: str,
response_mode: str = "blocking",
conversation_id: Optional[str] = None,
user: Optional[str] = None,
**additional_inputs
) -> str:
"""
生成对话响应
:param query: 用户查询内容
:param response_mode: 响应模式 (blocking/streaming)
:param conversation_id: 会话ID (为空时创建新会话)
:param user: 用户标识 (默认使用初始化参数"""
try:
response = requests.post(
f"{self.base_url}/chat-messages",
headers=self.headers,
json={
"inputs": {},
"query": query,
"response_mode": "blocking",
"conversation_id": "",
"user": "abc-123"
},
timeout=self.timeout
)
response.raise_for_status()
#response.json()["answer"]
return response.json()["answer"]
except requests.exceptions.ConnectionError:
logger.error("无法连接dify服务请检查服务状态")
return "本地模型服务未启动"
except requests.exceptions.Timeout:
logger.warning("dify请求超时")
return "响应超时,请简化问题"
except Exception as e:
logger.error(f"dify处理异常: {str(e)}", exc_info=True)
return "本地模型服务异常"
class OpenAIService(AIService):
    """OpenAI-compatible chat-completions backend with per-user conversation memory."""

    def __init__(
        self,
        api_key: str = OPENAI_API_KEY,
        model: str = OPENAI_MODEL,
        base_url: str = OPENAI_BASE_URL,
        timeout: int = 15,
        temperature: float = 0.7,
        max_conversation_length: int = 10,
        max_time_gap: int = 30
    ):
        """
        :param api_key: bearer token for the API
        :param model: model name sent with each request
        :param base_url: API base URL; a trailing '/' is stripped
        :param timeout: request timeout in seconds
        :param temperature: sampling temperature sent with each request
        :param max_conversation_length: max user/assistant messages kept per user (at least 1)
        :param max_time_gap: minutes of inactivity after which a user's history is dropped
        :raises ValueError: if api_key or model is empty or not a string
        """
        self._validate_config(api_key, model)
        self.api_key = api_key
        self.default_model = model
        # Avoid "...//chat/completions" when the configured URL ends with '/'.
        self.base_url = base_url.rstrip('/') if isinstance(base_url, str) else base_url
        self.timeout = timeout
        self.temperature = temperature
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        # user_id -> {"messages": [...], "last_timestamp": epoch seconds}
        self.conversation_history = {}
        self.max_conversation_length = max_conversation_length
        self.max_time_gap = max_time_gap
        self.system_prompt = '''你是路桥设计院智能助手。你的使命是尽可能地用详尽的、温暖的、友善的话语帮助企业员工,在各种方面提供帮助和支持。无论我需要什么帮助或建议,你都会尽力提供详尽信息。'''

    def _validate_config(self, api_key: str, model: str) -> None:
        """
        Validate the OpenAI settings.

        :param api_key: API key
        :param model: model name
        :raises ValueError: when either value is empty or not a string
        """
        if not api_key:
            raise ValueError("OpenAI API密钥不能为空")
        if not model:
            raise ValueError("模型名称不能为空")
        if not isinstance(api_key, str) or not isinstance(model, str):
            raise ValueError("API密钥和模型名称必须是字符串类型")

    def _expire_sessions(self, now: int) -> None:
        """Drop every session idle for at least max_time_gap minutes (frees idle users too)."""
        ttl = self.max_time_gap * 60
        expired = [
            uid for uid, session in self.conversation_history.items()
            if now - session["last_timestamp"] >= ttl
        ]
        for uid in expired:
            del self.conversation_history[uid]

    def _trim_history(self, messages: list) -> None:
        """Keep only the newest max_conversation_length messages (never fewer than one)."""
        # max(1, ...) also avoids the messages[-0:] pitfall, which keeps everything.
        keep = max(1, self.max_conversation_length)
        if len(messages) > keep:
            del messages[:-keep]

    def _discard_pending_prompt(self, user_id: str) -> None:
        """Remove the trailing unanswered user turn after a failed request."""
        session = self.conversation_history.get(user_id)
        if session and session["messages"] and session["messages"][-1]["role"] == "user":
            session["messages"].pop()

    def _manage_conversation_history(self, user_id: str, message: str):
        """Expire stale sessions, then record *message* as the newest user turn."""
        now = int(time.time())
        self._expire_sessions(now)
        session = self.conversation_history.setdefault(
            user_id, {"messages": [], "last_timestamp": now}
        )
        session["last_timestamp"] = now
        session["messages"].append({"role": "user", "content": message})
        # Trim after appending so the history sent to the API respects the limit.
        self._trim_history(session["messages"])

    def generate_response(self, prompt: str, user_id: str = "default_user") -> str:
        """
        Generate a reply using the user's recent conversation as context.

        On failure the prompt is removed from history again, so a later retry
        does not carry an unanswered turn.

        :param prompt: user input text
        :param user_id: key of the conversation to use
        :return: the reply text, or a user-facing error message on failure
        """
        try:
            self._manage_conversation_history(user_id, prompt)
            messages = [{"role": "system", "content": self.system_prompt}]
            messages.extend(self.conversation_history[user_id]["messages"])
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=self.headers,
                json={
                    "model": self.default_model,
                    "messages": messages,
                    "temperature": self.temperature
                },
                timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()
            choices = result.get('choices') if isinstance(result, dict) else None
            if not choices:
                logger.error("OpenAI响应格式异常: %s", result)
                self._discard_pending_prompt(user_id)
                return "响应解析失败"
            response_text = choices[0]['message']['content']
        except Exception as e:
            logger.error("OpenAI处理异常: %s", e, exc_info=True)
            self._discard_pending_prompt(user_id)
            return "服务暂时不可用"
        history = self.conversation_history[user_id]["messages"]
        history.append({"role": "assistant", "content": response_text})
        self._trim_history(history)
        return response_text
class FastGptService(AIService):
"""FastGPT API客户端封装"""
def __init__(
self,
api_key: str,
base_url: str = "http://localhost:3000/api/v1",
timeout: int = 30,
max_conversation_length: int = 10,
max_time_gap: int = 30
):
"""
初始化FastGPT服务
:param api_key: FastGPT API密钥
:param base_url: API基础地址
:param timeout: 请求超时时间(秒)
:param max_conversation_length: 最大会话长度
:param max_time_gap: 会话超时时间(分钟)
"""
self._validate_config(api_key, base_url)
self.api_key = api_key
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# 会话管理相关属性
self.conversation_history = {}
self.max_conversation_length = max_conversation_length
self.max_time_gap = max_time_gap
def _validate_config(self, api_key: str, base_url: str):
"""配置校验"""
if not api_key:
raise ValueError("FastGPT API密钥不能为空")
if not base_url.startswith(('http://', 'https://')):
raise ValueError("无效的基础URL协议")
def _manage_conversation_history(self, user_id: str, message: str):
"""管理会话历史"""
current_timestamp = int(time.time())
# 检查会话是否超时
if (user_id in self.conversation_history and
current_timestamp - self.conversation_history[user_id]["last_timestamp"] >= self.max_time_gap * 60):
del self.conversation_history[user_id]
# 初始化或更新会话历史
if user_id not in self.conversation_history:
self.conversation_history[user_id] = {
"chat_id": f"chat_{user_id}_{current_timestamp}",
"messages": [],
"last_timestamp": current_timestamp
}
else:
self.conversation_history[user_id]["last_timestamp"] = current_timestamp
# 限制会话历史长度
if len(self.conversation_history[user_id]["messages"]) >= self.max_conversation_length:
self.conversation_history[user_id]["messages"] = (
self.conversation_history[user_id]["messages"][-self.max_conversation_length:]
)
# 添加新消息
msg_id = f"msg_{current_timestamp}"
self.conversation_history[user_id]["messages"].append({
"role": "user",
"content": message
})
return msg_id
def generate_response(
self,
prompt: str,
user_id: str = "default_user",
variables: dict = None,
detail: bool = False
) -> str:
"""
生成回复
:param prompt: 用户输入
:param user_id: 用户标识
:param variables: 模块变量
:param detail: 是否返回详细信息
:return: 生成的回复文本
"""
try:
msg_id = self._manage_conversation_history(user_id, prompt)
chat_info = self.conversation_history.get(user_id, {})
payload = {
"chatId": chat_info.get("chat_id"),
"stream": False,
"detail": detail,
"responseChatItemId": msg_id,
"variables": variables or {},
"messages": [{"role": "user", "content": prompt}]
}
response = requests.post(
f"{self.base_url}/chat/completions",
headers=self.headers,
json=payload,
timeout=self.timeout
)
response.raise_for_status()
result = response.json()
# 处理响应
if detail:
response_text = result.get("responseData", {}).get("content", "响应解析失败")
else:
response_text = result.get("content", "响应解析失败")
# 保存助手回复到会话历史
if user_id in self.conversation_history:
self.conversation_history[user_id]["messages"].append({
"role": "assistant",
"content": response_text
})
return response_text
except requests.exceptions.ConnectionError:
logger.error("无法连接FastGPT服务")
return "服务连接失败"
except requests.exceptions.Timeout:
logger.warning("FastGPT请求超时")
return "响应超时"
except Exception as e:
logger.error(f"FastGPT处理异常: {str(e)}", exc_info=True)
return "服务暂时不可用"
class HybridAIService(AIService):
    """Failover wrapper: ask each service in order and return the first usable reply."""

    # The services in this module catch their own errors and return these
    # strings instead of raising. Treating them as failures is what lets
    # failover actually move on to the next service.
    DEFAULT_FAILURE_REPLIES = frozenset({
        "本地模型服务未启动",
        "响应超时,请简化问题",
        "本地模型服务异常",
        "响应解析失败",
        "服务暂时不可用",
        "服务连接失败",
        "响应超时",
    })

    def __init__(self, services: list[AIService], failure_replies: Optional[set[str]] = None):
        """
        :param services: services to try, in priority order
        :param failure_replies: reply strings that count as a failure; defaults to
            DEFAULT_FAILURE_REPLIES. Any iterable of strings is accepted.
        """
        self.services = services
        self.failure_replies = (
            self.DEFAULT_FAILURE_REPLIES if failure_replies is None else frozenset(failure_replies)
        )

    def generate_response(self, prompt: str) -> str:
        """
        Return the first reply that is non-empty and not a known failure reply.

        :param prompt: user input text
        :return: the reply, or "所有AI服务不可用" when every service fails
        """
        for service in self.services:
            name = type(service).__name__
            try:
                reply = service.generate_response(prompt)
            except Exception as e:
                logger.warning(f"{name} 服务失败: {str(e)}")
                continue
            if not reply or reply in self.failure_replies:
                logger.warning(f"{name} 返回失败回复: {reply!r}")
                continue
            return reply
        return "所有AI服务不可用"
class MessageHandler:
    """Keyword-first message router with an AI fallback."""

    def __init__(self, keyword_config: Dict, ai_service: AIService):
        """
        :param keyword_config: rules keyed by name; each rule is read for its
            'keywords' iterable and its 'reply'
        :param ai_service: service asked when no rule matches
        """
        self.ai_service = ai_service
        self.keyword_config = keyword_config

    def get_reply(self, content: str) -> str:
        """
        Pick a reply for *content*.

        Rules are scanned twice, in config order. The first pass looks for a
        keyword equal to the whitespace-stripped message. The second looks for
        a keyword contained anywhere in the raw message. The first matching
        rule's reply wins; with no match, the AI service answers.
        """
        matchers = (
            lambda kw: kw == content.strip(),  # exact, whole-message match
            lambda kw: kw in content,          # fuzzy, substring match
        )
        for matches in matchers:
            for rule in self.keyword_config.values():
                if any(matches(kw) for kw in rule['keywords']):
                    return rule['reply']
        return self.ai_service.generate_response(content)