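"""Split a Markdown file at level-1 headings into chunks no longer than a
given maximum, send each chunk to an LLM behind an OpenAI-compatible API
(such as LM-Studio), and save the merged result to a new file."""
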
import re
import argparse
import openai
from typing import List

# Configure the OpenAI client (legacy openai<1.0 SDK)
openai.api_key = "sk-no-key-required"
openai.api_base = "http://localhost:1234/v1"  # LM-Studio's default address

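# Note: openai.ChatCompletion and openai.api_base come from the pre-1.0
# openai SDK (e.g. openai==0.28); the 1.x SDK replaced them with
# OpenAI(base_url=...).chat.completions.create.
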
def read_markdown_file(file_path: str) -> str:
    """Read the contents of a Markdown file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    except Exception as e:
        print(f"Error reading file: {e}")
        return ""


def split_by_headers(content: str, max_length: int) -> List[str]:
    """Split content into chunks at level-1 headings (#), keeping each chunk under max_length."""
    # Find the position of every level-1 heading
    pattern = r'^# .+$'
    headers = [(m.start(), m.group()) for m in re.finditer(pattern, content, re.MULTILINE)]

    if not headers:
        # No level-1 headings: treat the whole content as a single chunk
        return [content] if len(content) <= max_length else chunk_content(content, max_length)

    chunks = []
    current_chunk = ""

    # Walk through every heading
    for i, (pos, header) in enumerate(headers):
        if i == 0 and pos > 0:
            # Keep any content between the start of the file and the first heading
            current_chunk = content[:pos]

        # The section runs from this heading to the next one (or to the end of file)
        section_end = headers[i + 1][0] if i + 1 < len(headers) else len(content)
        section = content[pos:section_end]

        # Check whether appending this section would exceed the maximum length
        if len(current_chunk) + len(section) <= max_length:
            current_chunk += section
        else:
            # Flush the current chunk if it is non-empty
            if current_chunk:
                chunks.append(current_chunk)

            # A single section larger than max_length must be split further
            if len(section) > max_length:
                chunks.extend(chunk_content(section, max_length))
                current_chunk = ""
            else:
                current_chunk = section

    # Append the final chunk
    if current_chunk:
        chunks.append(current_chunk)

    return chunks

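# Behavior sketch (illustrative numbers): with max_length=100 and sections of
# 40, 40, and 150 characters, the first two sections share one chunk, while
# the 150-character section is handed to chunk_content and split at fixed
# 100-character offsets.
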
def chunk_content(content: str, max_length: int) -> List[str]:
    """Split content into fixed-size chunks."""
    return [content[i:i + max_length] for i in range(0, len(content), max_length)]

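# Note: chunk_content cuts at raw character offsets, so a boundary can land
# mid-line or mid-word; it is only a fallback for oversized sections.
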
def process_chunk_with_llm(chunk: str, model: str = "gpt-3.5-turbo") -> str:
    """Process a single chunk with the LLM."""
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant. Please process the following Markdown content."},
                {"role": "user", "content": chunk}
            ],
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error processing chunk: {e}")
        return chunk  # Fall back to the original content on error

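# Quick smoke test (hypothetical; assumes a server is already listening at
# openai.api_base):
#   print(process_chunk_with_llm("# Title\n\nHello world"))
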
def save_markdown_file(content: str, output_path: str) -> None:
    """Save the processed Markdown content to a file."""
    try:
        with open(output_path, 'w', encoding='utf-8') as file:
            file.write(content)
        print(f"Saved the processed file to: {output_path}")
    except Exception as e:
        print(f"Error saving file: {e}")


def main():
    parser = argparse.ArgumentParser(description='Process a Markdown file')
    parser.add_argument('input_file', help='Path to the input Markdown file')
    parser.add_argument('output_file', help='Path to the output Markdown file')
    parser.add_argument('--max_length', type=int, default=4000, help='Maximum length of each chunk')
    parser.add_argument('--model', default='gpt-3.5-turbo', help='Name of the LLM model to use')
    parser.add_argument('--api_base', default='http://localhost:1234/v1', help='Base URL of the API')

    args = parser.parse_args()

    # Point the client at the requested API endpoint
    openai.api_base = args.api_base

    # Read the input file
    content = read_markdown_file(args.input_file)
    if not content:
        return

    # Split the file into chunks
    chunks = split_by_headers(content, args.max_length)
    print(f"File split into {len(chunks)} chunks")

    # Process each chunk
    processed_chunks = []
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i + 1}/{len(chunks)}...")
        processed_chunks.append(process_chunk_with_llm(chunk, args.model))

    # Merge the processed content
    final_content = '\n'.join(processed_chunks)

    # Save the result
    save_markdown_file(final_content, args.output_file)

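# Example invocation (the script filename is illustrative):
#   python process_markdown.py input.md output.md \
#       --max_length 4000 --model gpt-3.5-turbo \
#       --api_base http://localhost:1234/v1
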
if __name__ == "__main__":
    main()