220 lines
8.3 KiB
Python
220 lines
8.3 KiB
Python
![]() |
from config import AgentId,Secret,corpid,token,encodingAESKey
|
|||
|
#import xml.etree.ElementTree as ET
|
|||
|
from flask import Flask,request
|
|||
|
from WXBizMsgCrypt import WXBizMsgCrypt
|
|||
|
import base64,hashlib
|
|||
|
from Crypto.Cipher import AES
|
|||
|
import time
|
|||
|
from xml.etree import ElementTree
|
|||
|
from keyword_config import KEYWORD_REPLIES
|
|||
|
from ai_service import OllamaService,MessageHandler,DifyService
|
|||
|
from typing import Dict
|
|||
|
from config import OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL,DIFY_API_KEY,DIFY_BASE_URL
|
|||
|
from config import FASTAPI_BASE_URL, FASTAPI_API_KEY
|
|||
|
import re
|
|||
|
|
|||
|
from ai_service import OpenAIService,FastGptService
|
|||
|
# Flask application object; the routes below are registered on it.
app =Flask(__name__)

# Callback message encryptor/decryptor, configured from config's token,
# encodingAESKey and corpid. Used for both DecryptMsg and EncryptMsg below.
wxcpt = WXBizMsgCrypt(token,encodingAESKey,corpid)

# Active AI backend. Exactly one assignment should be enabled; the
# commented-out lines are alternative backends with the constructor
# arguments they take.
#ai_service = OllamaService()

ai_service = OpenAIService(OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL)

#ai_service = DifyService(DIFY_API_KEY,DIFY_BASE_URL)

#ai_service = FastGptService(FASTAPI_API_KEY, FASTAPI_BASE_URL)
|
|||
|
|
|||
|
# Restore the '=' padding that a base64 string may omit.
def check_base64_len(base64_str):
    """Return *base64_str* right-padded with '=' to a multiple of 4 chars.

    Input whose length is already a multiple of 4 (including "") is
    returned unchanged.

    Args:
        base64_str: A base64 string that may be missing its trailing '='.

    Returns:
        The padded string, suitable for ``base64.b64decode``.
    """
    # -len % 4 is 0 for aligned input. The previous form, 4 - len % 4,
    # yields 4 there and appended a spurious "====".
    return base64_str + "=" * (-len(base64_str) % 4)
|
|||
|
# Decrypt an AES-CBC payload and extract the message body.
def msg_base64_decrypt(ciphertext_base64, key_base64):
    """Decrypt a base64 AES-CBC payload and return the embedded message.

    The decrypted plaintext is parsed as::

        16 bytes (skipped) | 4-byte big-endian length N | N bytes of message | ...

    Anything after the N message bytes is ignored.

    Args:
        ciphertext_base64: Base64 ciphertext; missing '=' padding is restored.
        key_base64: Base64 AES key; missing '=' padding is restored. Its
            first 16 bytes double as the CBC IV.

    Returns:
        The message bytes decoded as UTF-8.

    Raises:
        ValueError: if the plaintext is too short to contain the 20-byte
            header, or the declared length runs past the end of the
            plaintext (e.g. wrong key or tampered ciphertext). Errors raised
            by base64/AES decoding or UTF-8 decoding propagate unchanged.
    """
    ciphertext_bytes = base64.b64decode(check_base64_len(ciphertext_base64))
    key_bytes = base64.b64decode(check_base64_len(key_base64))
    iv_bytes = key_bytes[:16]

    plaintext_bytes = AES.new(key_bytes, AES.MODE_CBC, iv_bytes).decrypt(ciphertext_bytes)

    header_end = 20  # 16 skipped bytes + 4-byte length field
    if len(plaintext_bytes) < header_end:
        raise ValueError("decrypted payload is too short for the message header")

    msg_len = int.from_bytes(plaintext_bytes[16:header_end], byteorder='big', signed=False)
    msg_end = header_end + msg_len
    # A slice past the end would silently truncate instead of failing, so
    # validate the declared length explicitly.
    if msg_end > len(plaintext_bytes):
        raise ValueError("declared message length exceeds the decrypted payload")

    return plaintext_bytes[header_end:msg_end].decode('utf-8')
|
|||
|
# Verify a callback signature.
def check_msg_signature(msg_signature, token, timestamp, nonce, echostr):
    """Return True if *msg_signature* matches the expected SHA-1 signature.

    The expected signature is the hex SHA-1 of token, timestamp, nonce and
    echostr, sorted lexicographically, concatenated and UTF-8 encoded.

    Args:
        msg_signature: Signature received with the request (hex string).
        token: Shared callback token.
        timestamp, nonce, echostr: Values received with the request.

    Returns:
        True on a match. False on a mismatch, or when any argument is
        missing (None) / the signature is not a string. Missing arguments
        previously raised TypeError while sorting.
    """
    import hmac  # function-scope import keeps this block self-contained

    fields = (token, timestamp, nonce, echostr)
    if not isinstance(msg_signature, str) or any(f is None for f in fields):
        return False

    expected = hashlib.sha1("".join(sorted(fields)).encode("utf8")).hexdigest()
    # Constant-time comparison: the signature comes from an untrusted request,
    # and == would leak through timing how many leading characters matched.
    return hmac.compare_digest(expected.encode("utf8"), msg_signature.encode("utf8"))
|
|||
|
@app.route('/hello', methods=['GET'])
def hello():
    """Return a fixed greeting (presumably used as a simple up-check)."""
    greeting = "Hello, Flask!"
    return greeting
|
|||
|
|
|||
|
# Messages that, after strip().lower(), reset the sender's conversation.
_REFRESH_KEYWORDS = frozenset(
    ["new", "refresh", "00", "restart", "刷新", "新话题", "退下", "结束", "over"]
)


@app.route('/', methods=['GET', 'POST'])
def reply():
    """Callback endpoint.

    GET  -- URL verification: validate the signature and return the
            decrypted ``echostr``.
    POST -- message callback: decrypt the incoming message, build a reply
            and return it as encrypted XML.

    Any exception escaping the per-method handlers is printed and turned
    into a 500 response.
    """
    try:
        if request.method == 'GET':
            return _verify_url()
        if request.method == 'POST':
            return _handle_message()
    except Exception as e:
        print(f"处理异常: {str(e)}")
        return "服务器错误", 500

    return "success"


def _verify_url():
    """Handle the GET URL-verification handshake."""
    args = request.args
    msg_signature = args.get("msg_signature")
    timestamp = args.get("timestamp")
    nonce = args.get("nonce")
    echostr = args.get("echostr")
    print(msg_signature, timestamp, nonce, echostr)

    # Reject incomplete requests up front. Otherwise None reaches the
    # signature helper and the request fails with a generic 500.
    if not all([msg_signature, timestamp, nonce, echostr]):
        return "缺少必要参数", 400

    if not check_msg_signature(msg_signature, token, timestamp, nonce, echostr):
        return ""

    decrypt_result = msg_base64_decrypt(echostr, encodingAESKey)
    print("通过")
    return decrypt_result


def _handle_message():
    """Handle a POST callback: decrypt, generate a reply, encrypt it."""
    try:
        raw_data = request.get_data()
        if isinstance(raw_data, bytes):
            raw_data = raw_data.decode('utf-8')

        msg_signature = request.args.get('msg_signature', '')
        timestamp = str(request.args.get('timestamp', ''))
        nonce = request.args.get('nonce', '')
        if not all([msg_signature, timestamp, nonce]):
            return "缺少必要参数", 400

        ret, xml_content = wxcpt.DecryptMsg(raw_data, msg_signature, timestamp, nonce)
        if ret != 0:
            print(f"[ERROR] 解密失败,错误码: {ret}")
            return "消息解密失败", 500

        xml_tree = ElementTree.fromstring(xml_content)
        from_user_name = _element_text(xml_tree.find('FromUserName'))
        msg_content = _element_text(xml_tree.find('Content'))

        re_text = process_text(_generate_reply(from_user_name, msg_content))

        # Sender and recipient are swapped relative to the incoming message.
        reply_dict = {
            'ToUserName': from_user_name,
            'FromUserName': _element_text(xml_tree.find('ToUserName')),
            'CreateTime': str(int(time.time())),
            'MsgType': 'text',
            'Content': re_text,
        }

        reply_msg = ResponseMessage(reply_dict).xml
        if isinstance(reply_msg, bytes):
            reply_msg = reply_msg.decode('utf-8')

        ret, encrypt_xml = wxcpt.EncryptMsg(reply_msg, nonce, timestamp)
        if ret != 0:
            print(f"[ERROR] 加密回复消息失败,错误码: {ret}")
            return "加密回复消息失败", 500

        return encrypt_xml

    except Exception as e:
        print(f"[ERROR] 处理消息时发生异常: {str(e)}")
        return "服务器内部错误", 500


def _generate_reply(from_user_name, msg_content):
    """Return the unprocessed reply text for *msg_content* from *from_user_name*."""
    if msg_content.strip().lower() in _REFRESH_KEYWORDS:
        if hasattr(ai_service, 'conversation_history') and from_user_name in ai_service.conversation_history:
            del ai_service.conversation_history[from_user_name]
        return "会话已重置"

    if isinstance(ai_service, OpenAIService):
        # Called directly so the sender's user_id is passed along. get_reply()
        # is skipped here because its result would be overwritten anyway, and
        # running it cost an extra (discarded) round of work.
        # NOTE(review): KEYWORD_REPLIES are therefore bypassed for
        # OpenAIService, exactly as before -- confirm that is intended.
        return ai_service.generate_response(msg_content, user_id=from_user_name)

    msg_handler = MessageHandler(KEYWORD_REPLIES, ai_service)
    return msg_handler.get_reply(msg_content)


def _element_text(element):
    """Return *element*'s text as str, or '' if the element or its text is missing."""
    if element is None or element.text is None:
        return ''
    text = element.text
    return text.decode('utf-8') if isinstance(text, bytes) else str(text)
|
|||
|
|
|||
|
# Builder for the plain-text XML reply that is encrypted and returned to the caller.
class ResponseMessage:
    """Render a flat dict as ``<xml><Key><![CDATA[value]]></Key>...</xml>``.

    Values are converted with ``str()`` at construction time. Keys are used
    verbatim as tag names and are assumed to be trusted, fixed field names.
    """

    def __init__(self, dict_data):
        self.dict_data = {k: str(v) for k, v in dict_data.items()}

    @staticmethod
    def _cdata(value):
        """Wrap *value* in CDATA, splitting any ']]>' so it cannot close the section early."""
        # Without this, a value such as an AI reply containing "]]>" would
        # produce malformed XML.
        return "<![CDATA[" + value.replace("]]>", "]]]]><![CDATA[>") + "]]>"

    @property
    def xml(self):
        """The rendered XML document as a str."""
        body = "".join(f"<{k}>{self._cdata(v)}</{k}>" for k, v in self.dict_data.items())
        return f"<xml>{body}</xml>"
|
|||
|
|
|||
|
|
|||
|
# Vertical-whitespace characters other than the newline that are dropped from replies.
_EXTRA_LINE_BREAKS_RE = re.compile(r'[\r\v\f\x1c\x1d\x1e\x85\u2028\u2029]')


def process_text(text):
    """Normalize line breaks in a reply before it is sent to the user.

    - The newline character is the only line separator kept. Carriage
      returns and other vertical-whitespace characters are removed.
    - Leading and trailing whitespace on every line is stripped.
    - Blank lines are dropped, so runs of line breaks collapse to one.
    - Spaces inside a line are preserved. The previous regex, which matched
      every whitespace character, glued English words together.

    Args:
        text: Reply text; None or "" is tolerated.

    Returns:
        The normalized text ("" for empty or None input).
    """
    if not text:
        return ""
    lines = (_EXTRA_LINE_BREAKS_RE.sub('', raw).strip() for raw in text.split('\n'))
    return '\n'.join(line for line in lines if line)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    import os

    # Werkzeug's interactive debugger allows arbitrary code execution, so it
    # must never be reachable on a public interface (host 0.0.0.0).
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_mode = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=5000, debug=debug_mode)
|