wework_api/callback1.py
2025-02-21 08:04:44 +08:00

220 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import hashlib
import hmac
import re
import time
# import xml.etree.ElementTree as ET
from typing import Dict
from xml.etree import ElementTree

from Crypto.Cipher import AES
from flask import Flask, request

from ai_service import (
    DifyService,
    FastGptService,
    MessageHandler,
    OllamaService,
    OpenAIService,
)
from config import (
    AgentId,
    DIFY_API_KEY,
    DIFY_BASE_URL,
    FASTAPI_API_KEY,
    FASTAPI_BASE_URL,
    OPENAI_API_KEY,
    OPENAI_BASE_URL,
    OPENAI_MODEL,
    Secret,
    corpid,
    encodingAESKey,
    token,
)
from keyword_config import KEYWORD_REPLIES
from WXBizMsgCrypt import WXBizMsgCrypt
# Flask application that serves the callback routes defined in this module.
app = Flask(__name__)

# WeCom crypto helper built from the callback token, EncodingAESKey and corp id;
# used to decrypt incoming POST bodies and encrypt replies.
wxcpt = WXBizMsgCrypt(token, encodingAESKey, corpid)

# Active AI backend. To switch provider, replace the assignment below with one
# of these alternatives:
#   ai_service = OllamaService()
#   ai_service = DifyService(DIFY_API_KEY, DIFY_BASE_URL)
#   ai_service = FastGptService(FASTAPI_API_KEY, FASTAPI_BASE_URL)
ai_service = OpenAIService(OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL)
# Ensure base64 data has a valid length (a multiple of 4), padding it if needed.
def check_base64_len(base64_str):
    """Return ``base64_str`` right-padded with '=' to a multiple of 4 characters.

    Args:
        base64_str: base64 text that may have had its trailing '=' removed
            (e.g. a 43-character EncodingAESKey).

    Returns:
        The input unchanged when its length is already a multiple of 4,
        otherwise the input followed by 1-3 '=' characters.
    """
    # (-n) % 4 is 0 for aligned input and the missing pad count otherwise.
    # The former formula 4 - n % 4 produced 4 for aligned input, so the
    # "no padding needed" branch was unreachable and "====" was appended.
    missing = (-len(base64_str)) % 4
    return base64_str + "=" * missing
# Decrypt a base64 AES payload and extract the message body.
def msg_base64_decrypt(ciphertext_base64, key_base64):
    """Decrypt a base64-encoded AES-CBC ciphertext and return the message text.

    Args:
        ciphertext_base64: base64 ciphertext; missing '=' padding is restored.
        key_base64: base64 AES key (missing padding restored). The first 16
            bytes of the decoded key are also used as the CBC IV.

    Returns:
        The message body decoded as UTF-8. The plaintext is read as: 16
        leading bytes (skipped), a 4-byte big-endian unsigned length, then
        that many message bytes. Anything after the message is ignored.
    """
    encrypted = base64.b64decode(check_base64_len(ciphertext_base64))
    aes_key = base64.b64decode(check_base64_len(key_base64))

    cipher = AES.new(aes_key, AES.MODE_CBC, aes_key[:16])
    plain = cipher.decrypt(encrypted)

    # NOTE(review): the bytes after the message (presumably the receive/corp
    # id plus padding) are not verified against corpid here.
    body_len = int.from_bytes(plain[16:20], byteorder='big', signed=False)
    body = plain[20:20 + body_len]
    return body.decode('utf-8')
# Verify the message signature.
def check_msg_signature(msg_signature, token, timestamp, nonce, echostr):
    """Check a WeCom callback signature.

    The expected signature is the hex SHA-1 of ``token``, ``timestamp``,
    ``nonce`` and ``echostr`` after sorting the four strings and
    concatenating them.

    Args:
        msg_signature: signature sent with the request (None if absent).
        token: the shared callback token.
        timestamp: request timestamp string (None if absent).
        nonce: request nonce string (None if absent).
        echostr: request echostr / encrypted payload string (None if absent).

    Returns:
        True if the computed signature equals ``msg_signature``, else False.
        A missing (None) argument yields False instead of raising TypeError
        from sorting a list that contains None.
    """
    parts = [token, timestamp, nonce, echostr]
    if msg_signature is None or any(part is None for part in parts):
        return False

    # sorted() returns a new list; list.sort() sorts in place and returns None.
    joined = "".join(sorted(parts))
    # hashlib needs bytes, hence the explicit encode.
    expected = hashlib.sha1(joined.encode("utf8")).hexdigest()

    # Constant-time comparison so response timing does not reveal how much of
    # a forged signature matched. compare_digest rejects non-ASCII str, so
    # compare as bytes.
    return hmac.compare_digest(expected.encode("ascii"), msg_signature.encode("utf8"))
@app.route('/hello', methods=['GET'])
def hello():
    """Return a fixed plain-text greeting; handy for checking the server is reachable."""
    message = "Hello, Flask!"
    return message
@app.route('/', methods=['GET', 'POST'])
def reply():
    """WeCom callback endpoint.

    GET:  URL-verification handshake, handled by ``_verify_callback_url``.
    POST: encrypted incoming message, handled by ``_handle_incoming_message``.

    Unexpected exceptions are printed and turned into HTTP 500 responses.
    A request whose method is neither GET nor POST (Flask typically also
    dispatches HEAD to GET routes) falls through and receives "success".
    """
    try:
        if request.method == 'GET':
            return _verify_callback_url()
        elif request.method == 'POST':
            try:
                return _handle_incoming_message()
            except Exception as e:
                print(f"[ERROR] 处理消息时发生异常: {str(e)}")
                return "服务器内部错误", 500
    except Exception as e:
        print(f"处理异常: {str(e)}")
        return "服务器错误", 500
    return "success"


# Messages (compared after strip().lower()) that reset the sender's conversation.
_REFRESH_KEYWORDS = frozenset(
    ["new", "refresh", "00", "restart", "刷新", "新话题", "退下", "结束", "over"]
)


def _xml_text(element):
    """Return an ElementTree element's text as str ('' if the element or its text is missing)."""
    if element is None or element.text is None:
        return ''
    text = element.text
    # Defensive: ElementTree normally yields str, but tolerate bytes.
    return text.decode('utf-8') if isinstance(text, bytes) else str(text)


def _verify_callback_url():
    """Handle the GET URL-verification handshake.

    Returns:
        The decrypted ``echostr`` when the query signature is valid,
        otherwise an empty string.
    """
    args = request.args
    msg_signature = args.get("msg_signature")
    timestamp = args.get("timestamp")
    nonce = args.get("nonce")
    echostr = args.get("echostr")
    print(msg_signature, timestamp, nonce, echostr)

    if not check_msg_signature(msg_signature, token, timestamp, nonce, echostr):
        return ""
    decrypt_result = msg_base64_decrypt(echostr, encodingAESKey)
    print("通过")
    return decrypt_result


def _generate_reply_text(msg_content, from_user_name):
    """Produce the raw (not yet post-processed) reply text for one message.

    A refresh keyword clears the sender's entry in
    ``ai_service.conversation_history`` (if the service has that attribute)
    and returns a fixed acknowledgement. Otherwise an ``OpenAIService`` is
    queried directly with the sender's id, and any other service goes through
    ``MessageHandler`` with the keyword table.
    """
    if msg_content.strip().lower() in _REFRESH_KEYWORDS:
        if hasattr(ai_service, 'conversation_history') and from_user_name in ai_service.conversation_history:
            del ai_service.conversation_history[from_user_name]
        return "会话已重置"

    if isinstance(ai_service, OpenAIService):
        # MessageHandler.get_reply() is deliberately not called here: its
        # result would be discarded in favour of this per-user call, so it was
        # wasted work (and possibly an extra AI request). As a consequence
        # KEYWORD_REPLIES is not consulted for OpenAIService.
        # TODO(review): confirm whether keyword replies should take priority.
        return ai_service.generate_response(msg_content, user_id=from_user_name)

    msg_handler = MessageHandler(KEYWORD_REPLIES, ai_service)
    return msg_handler.get_reply(msg_content)


def _handle_incoming_message():
    """Decrypt an incoming POST message, build a reply, and return it encrypted.

    Returns:
        The encrypted reply XML on success, or a ``(message, status)`` tuple
        when query parameters are missing (400) or decryption/encryption
        fails (500).
    """
    raw_data = request.get_data()
    if isinstance(raw_data, bytes):
        raw_data = raw_data.decode('utf-8')

    msg_signature = request.args.get('msg_signature', '')
    timestamp = str(request.args.get('timestamp', ''))
    nonce = request.args.get('nonce', '')
    if not all([msg_signature, timestamp, nonce]):
        return "缺少必要参数", 400

    ret, xml_content = wxcpt.DecryptMsg(raw_data, msg_signature, timestamp, nonce)
    if ret != 0:
        print(f"[ERROR] 解密失败,错误码: {ret}")
        return "消息解密失败", 500

    xml_tree = ElementTree.fromstring(xml_content)
    from_user_name = _xml_text(xml_tree.find('FromUserName'))
    # NOTE(review): non-text messages/events have no <Content>, so this is ''
    # and is still passed to the reply generator.
    msg_content = _xml_text(xml_tree.find('Content'))

    re_text = process_text(_generate_reply_text(msg_content, from_user_name))

    # The reply swaps sender and recipient of the incoming message.
    reply_dict = {
        'ToUserName': from_user_name,
        'FromUserName': _xml_text(xml_tree.find('ToUserName')),
        'CreateTime': str(int(time.time())),
        'MsgType': 'text',
        'Content': re_text,
    }
    reply_msg = ResponseMessage(reply_dict).xml
    if isinstance(reply_msg, bytes):
        reply_msg = reply_msg.decode('utf-8')

    ret, encrypt_xml = wxcpt.EncryptMsg(reply_msg, nonce, timestamp)
    if ret != 0:
        print(f"[ERROR] 加密回复消息失败,错误码: {ret}")
        return "加密回复消息失败", 500
    return encrypt_xml
# Builder for the plaintext XML of a passive text reply.
class ResponseMessage:
    """Build the plaintext XML body of a passive reply message.

    Each value is converted to ``str`` and wrapped in a CDATA section inside
    an element named after its key, in the dict's insertion order. Keys are
    emitted verbatim and must therefore be valid XML element names.
    """

    def __init__(self, dict_data):
        # element name -> text; values are stringified once, up front
        self.dict_data = {k: str(v) for k, v in dict_data.items()}

    @staticmethod
    def _cdata(text):
        """Wrap ``text`` in CDATA, splitting any "]]>" so it cannot close the section early."""
        # "]]>" inside CDATA would end the section and corrupt the XML (AI
        # output containing code can include it). Split it across two
        # adjacent CDATA sections; a parser joins them back to the original text.
        return "<![CDATA[" + text.replace("]]>", "]]]]><![CDATA[>") + "]]>"

    @property
    def xml(self):
        """The reply as an ``<xml>...</xml>`` string."""
        body = "".join(
            f"<{k}>{self._cdata(v)}</{k}>" for k, v in self.dict_data.items()
        )
        return f"<xml>{body}</xml>"
# Line-break characters other than "\n" (the extra ones str.splitlines() splits on).
_STRAY_LINE_BREAKS_RE = re.compile(r'[\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029]+')
# Horizontal (non-newline) whitespace immediately before a "\n".
_TRAILING_SPACE_RE = re.compile(r'[^\S\n]+\n')
# Runs of consecutive "\n", i.e. blank lines.
_NEWLINE_RUN_RE = re.compile(r'\n+')


def process_text(text):
    r"""Normalize reply text before it is sent back to the user.

    Steps:
      1. drop line-break characters other than "\n" (e.g. the "\r" of "\r\n");
      2. strip trailing whitespace from every line;
      3. collapse runs of "\n" (blank lines) into a single "\n";
      4. strip leading/trailing whitespace from the whole text.

    Spaces and tabs inside a line are kept on purpose: deleting every
    whitespace character (``[^\S]`` is the same as ``\s``) glues words
    together, e.g. "Hello world" -> "Helloworld".

    Args:
        text: reply text (str).

    Returns:
        The normalized text.
    """
    text = _STRAY_LINE_BREAKS_RE.sub('', text)
    text = _TRAILING_SPACE_RE.sub('\n', text)
    text = _NEWLINE_RUN_RE.sub('\n', text)
    return text.strip()
if __name__ == '__main__':
    import os

    # Debug mode enables the Werkzeug interactive debugger, which can execute
    # arbitrary Python code; combined with host='0.0.0.0' it would be exposed
    # on every interface. Keep it off unless explicitly requested with
    # FLASK_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)