Files
ragflow_api_test/src/add_chunk_cli_pdf_img.py

410 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ragflow_sdk import RAGFlow
import os
import re
# dependency for downloading remote (http/https) images referenced in chunks
import requests
#from urllib.parse import urlparse
import tempfile
from elasticsearch import Elasticsearch
from minio import Minio
from minio.error import S3Error
from dotenv import load_dotenv  # configuration comes from a .env file

# Load environment variables from the .env file
load_dotenv()

# Configuration pulled from the environment
api_key = os.getenv("RAGFLOW_API_KEY")
base_url = os.getenv("RAGFLOW_BASE_URL")
elastic_tenant_id = os.getenv("ELASTIC_TENANT_ID")

# RAGFlow SDK client (used by main/select_dataset)
rag_object = RAGFlow(api_key=api_key, base_url=base_url)

# Elasticsearch client — plain HTTP with basic auth
es = Elasticsearch(
    [{
        'host': os.getenv("ELASTIC_HOST"),
        # NOTE(review): int() raises TypeError if ELASTIC_PORT is unset — confirm .env always defines it
        'port': int(os.getenv("ELASTIC_PORT")),
        'scheme': 'http'
    }],
    basic_auth=(
        os.getenv("ELASTIC_USERNAME"),
        os.getenv("ELASTIC_PASSWORD")
    )
)

# MinIO connection settings; secure=False means plain HTTP
MINIO_CONFIG = {
    "endpoint": f"{os.getenv('MINIO_HOST')}:{os.getenv('MINIO_PORT')}",
    "access_key": os.getenv("MINIO_USER"),
    "secret_key": os.getenv("MINIO_PASSWORD"),
    "secure": False
}
def update_img_id_in_elasticsearch(tenant_id, doc_id, chunk_id, new_img_id):
    """Update the img_id field of one document chunk in Elasticsearch.

    The index name is derived from the tenant id (``ragflow_{tenant_id}``).
    The chunk is located by doc_id + chunk_id, updated with ``refresh=True``
    so the change is immediately visible, then read back to verify.

    :param tenant_id: tenant ID (used to build the index name)
    :param doc_id: document ID the chunk belongs to
    :param chunk_id: ID of the chunk to update
    :param new_img_id: new img_id value to store
    :return: dict with ``code`` (0 = success, 100 = verify failed,
        101 = error, 102 = chunk not found) and ``message``
    """
    index_name = f"ragflow_{tenant_id}"
    query = {
        "bool": {
            "must": [
                {"term": {"doc_id": doc_id}},
                {"term": {"_id": chunk_id}},
            ]
        }
    }
    try:
        found = es.search(index=index_name, body={"query": query})
        if found['hits']['total']['value'] == 0:
            print(f"在 Elasticsearch 中未找到文档: index={index_name}, doc_id={doc_id}, chunk_id={chunk_id}")
            return {"code": 102, "message": f"Can't find this chunk {chunk_id}"}

        # use the _id reported by the search hit for the update call
        doc_id_in_es = found['hits']['hits'][0]['_id']
        update_result = es.update(
            index=index_name,
            id=doc_id_in_es,
            body={"doc": {"img_id": new_img_id}},
            refresh=True,  # make the update visible to the verification read below
        )
        print(f"Elasticsearch 更新结果: index={index_name}, id={doc_id_in_es}, result={update_result}")

        # read back and confirm the stored value
        stored_img_id = es.get(index=index_name, id=doc_id_in_es)['_source'].get('img_id')
        if stored_img_id == new_img_id:
            print(f"成功更新 img_id 为: {new_img_id}")
            return {"code": 0, "message": ""}
        print(f"更新验证失败,当前 img_id: {stored_img_id}")
        return {"code": 100, "message": "Failed to verify img_id update"}
    except Exception as e:
        print(f"更新 Elasticsearch 时发生错误: {str(e)}")
        return {"code": 101, "message": f"Error updating img_id: {str(e)}"}
def get_minio_client():
    """Build a MinIO client from the module-level MINIO_CONFIG settings."""
    # MINIO_CONFIG keys match the Minio() keyword parameters exactly
    return Minio(**MINIO_CONFIG)
def upload_file2minio(bucket_name, object_name, file_path):
    """Upload a local file to MinIO, creating the bucket if it is missing.

    Naming behavior of fput_object:
    - object_name "image\\image.jpg" is stored under that literal name;
    - object_name "image/image.jpg" makes "image" a folder containing
      "image.jpg".

    :param bucket_name: target bucket (created on demand)
    :param object_name: object key inside the bucket
    :param file_path: local path of the file to upload
    :return: True on success, False on any error
    """
    client = get_minio_client()
    try:
        # create the bucket on first use
        if not client.bucket_exists(bucket_name):
            client.make_bucket(bucket_name)
            print(f"Bucket '{bucket_name}' created")
        client.fput_object(
            bucket_name=bucket_name,
            object_name=object_name,
            file_path=file_path,
        )
        # A presigned URL could be generated here if callers ever need one:
        #res = minio_client.get_presigned_url("GET", bucket_name, object_name, expires=timedelta(days=7))
        print(f"文件 '{file_path}' 成功上传到存储桶 '{bucket_name}''{object_name}'")
        return True
    except S3Error as exc:
        print("MinIO错误:", exc)
        return False
    except Exception as e:
        print("发生错误:", e)
        return False
def choose_from_list(options, prompt):
    """Print a numbered menu and prompt until a valid index is entered.

    :param options: items to choose from
    :param prompt: text shown when asking for input
    :return: the selected item
    """
    for position, option in enumerate(options, start=1):
        print(f"{position}. {option}")
    while True:
        raw = input(prompt)
        # accept only a digit within 1..len(options)
        if raw.isdigit() and 1 <= int(raw) <= len(options):
            return options[int(raw) - 1]
        print("输入无效,请重新输入编号。")
def select_files(file_path, file_type="pdf"):
    """Recursively collect all files with the given extension under a directory.

    :param file_path: root directory to walk
    :param file_type: extension without the dot (default "pdf"); matching is
        case-insensitive
    :return: list of full paths of matching files
    """
    suffix = f".{file_type.lower()}"
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(file_path)
        for filename in filenames
        if filename.lower().endswith(suffix)
    ]
def pair_pdf_and_txt(pdf_path, txt_path):
    """Pair PDF files with TXT files that share the same base name.

    :param pdf_path: directory searched (recursively) for PDFs
    :param txt_path: directory searched (recursively) for TXTs
    :return: (pdf_dict, txt_dict), both keyed by the file name without its
        extension; only names present in BOTH sets are kept, so a PDF with
        no matching TXT is dropped from the result.
    """
    def _by_stem(paths):
        # map "name without extension" -> full path
        return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}

    pdf_all = _by_stem(select_files(pdf_path, "pdf"))
    txt_all = _by_stem(select_files(txt_path, "txt"))
    # keep only pdf entries that have a same-named txt (pdf insertion order preserved)
    pdf_aligned = {name: path for name, path in pdf_all.items() if name in txt_all}
    txt_aligned = {name: txt_all[name] for name in pdf_aligned}
    return pdf_aligned, txt_aligned
def select_dataset(rag_object):
    """Interactively pick one of the available RAGFlow datasets.

    :param rag_object: RAGFlow client used to list datasets
    :return: the chosen dataset object, or None when no datasets exist
    """
    datasets = rag_object.list_datasets()
    if not datasets:
        print("没有可用的数据集。")
        return None
    chosen_name = choose_from_list([ds.name for ds in datasets], "请选择数据集编号:")
    # the chosen name always comes from the list, so a match is guaranteed
    return next(ds for ds in datasets if ds.name == chosen_name)
def upload_or_get_document(dataset, pdf_path, display_name):
    """Return the dataset document named display_name, uploading it if absent.

    :param dataset: RAGFlow dataset to look in / upload to
    :param pdf_path: local path of the PDF to upload on a miss
    :param display_name: document name used for lookup and upload
    :return: the document object, or None when the upload fails
    """
    try:
        existing = dataset.list_documents(name=display_name)[0]
        print(f"文档已存在: {display_name},跳过上传。")
        return existing
    except Exception:
        # not found — fall through to the upload path
        pass
    try:
        with open(pdf_path, "rb") as pdf_file:
            blob = pdf_file.read()
        dataset.upload_documents([{"display_name": display_name, "blob": blob}])
        return dataset.list_documents(name=display_name)[0]
    except Exception as e:
        print(f"上传PDF失败: {pdf_path},错误: {e}")
        return None
def divid_txt_chunk_img(txt_chunk):
    """Separate markdown image links from the surrounding text of a chunk.

    Input example:
        "text ![image](path/IMAGE1.png) more text ![image](path/IMAGE2.png)"

    :param txt_chunk: raw chunk text possibly containing ``![alt](path)`` tags
    :return: (clean_text, image_paths) — clean_text has every image tag
        removed, consecutive blank lines collapsed, and outer whitespace
        stripped; image_paths lists the captured link targets in order.
    """
    img_tag = re.compile(r'!\[.*?\]\((.*?)\)')
    image_paths = img_tag.findall(txt_chunk)
    without_tags = img_tag.sub('', txt_chunk)
    clean_text = re.sub(r'\n\s*\n', '\n\n', without_tags).strip()
    return clean_text, image_paths
def extract_images_from_chunk(content):
    """Return the target (URL or path) of every markdown image tag in content."""
    return re.findall(r'!\[.*?\]\((.*?)\)', content)
def remove_images_from_content(content):
    """Strip markdown image tags from content and tidy the leftover whitespace.

    Removes every ``![alt](url)`` tag, collapses three-or-more consecutive
    newlines (with optional whitespace between them) down to a single blank
    line, and strips leading/trailing whitespace.
    """
    without_images = re.sub(r'!\[.*?\]\(.*?\)', '', content)
    collapsed = re.sub(r'\n\s*\n\s*\n', '\n\n', without_images)
    return collapsed.strip()
# Image-handling logic for process_txt_chunks: supports both local paths and
# remote http(s) image links inside each text chunk.
def process_txt_chunks(dataset_id, document, txt_path):
    """Split a UTF-8 text file on blank lines and add each piece as a chunk.

    For a chunk containing a markdown image tag, only the FIRST image is
    uploaded to MinIO (bucket = dataset_id, object = chunk id); any further
    tags are stripped from the chunk text but otherwise ignored. After all
    chunks are added, Elasticsearch is updated so each uploaded chunk's
    img_id becomes "{dataset_id}-{chunk_id}".

    :param dataset_id: dataset id, doubling as the MinIO bucket name
    :param document: RAGFlow document that receives the chunks
    :param txt_path: path of the text file; relative image paths are resolved
        against this file's directory
    """
    try:
        with open(txt_path, 'r', encoding='utf-8') as file:
            file_content = file.read()
        # chunk ids whose image upload succeeded; their img_id is rewritten in ES below
        img_chunk_ids = []
        for num, txt_chunk in enumerate(file_content.split('\n\n')):
            if txt_chunk.strip():
                print(f"处理文本块: {txt_chunk[:30]}...")
                img_urls = extract_images_from_chunk(txt_chunk)
                # only the first image link is processed
                img_url = img_urls[0] if img_urls else None
                if img_url:
                    print(f"检测到图片链接: {img_url}")
                    # store the chunk text with all image tags removed
                    clean_chunk = remove_images_from_content(txt_chunk)
                    chunk = document.add_chunk(content=clean_chunk)
                    # remote image: download to a temp file, then upload that
                    if img_url.startswith(('http://', 'https://')):
                        try:
                            response = requests.get(img_url)
                            response.raise_for_status()
                            # NOTE(review): temp file is written with a .jpg suffix
                            # regardless of the actual image format
                            with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                                tmp_file.write(response.content)
                                tmp_path = tmp_file.name
                            # object name is the chunk id (see module docstring convention)
                            if upload_file2minio(dataset_id, chunk.id, tmp_path):
                                img_chunk_ids.append(chunk.id)
                            # remove the temp file whether or not the upload succeeded
                            os.unlink(tmp_path)
                        except Exception as e:
                            print(f"下载网络图片失败: {e}")
                    else:
                        # local image: resolve relative paths against the txt file's directory
                        if not os.path.isabs(img_url):
                            img_abs_path = os.path.join(os.path.dirname(txt_path), img_url)
                        else:
                            img_abs_path = img_url
                        print(f"图片绝对路径: {img_abs_path}")
                        if os.path.exists(img_abs_path):
                            if upload_file2minio(dataset_id, chunk.id, img_abs_path):
                                img_chunk_ids.append(chunk.id)
                        else:
                            # missing image file: the chunk is kept, just without an image
                            print(f"图片未找到: {img_abs_path},跳过。")
                else:
                    print("未检测到图片链接,直接添加文本块。")
                    chunk = document.add_chunk(content=txt_chunk)
                print(f"{num+1} Chunk添加成功! ID: {chunk.id}")
        # point each uploaded chunk's img_id at its MinIO object
        for img_chunk_id in img_chunk_ids:
            update_img_id_in_elasticsearch(elastic_tenant_id, document.id, img_chunk_id, f"{dataset_id}-{img_chunk_id}")
    except Exception as e:
        # best-effort per file: log and continue with the next pair
        print(f"处理文本文件时出错: {txt_path},错误: {e}")
def process_pdf_txt_pairs(pdf_dict, txt_dict, dataset):
    """Upload each PDF to the dataset and chunk its paired TXT file.

    :param pdf_dict: mapping of base name (no extension) -> PDF path
    :param txt_dict: mapping of base name -> TXT path (aligned with pdf_dict)
    :param dataset: target RAGFlow dataset (its id is the MinIO bucket name)
    """
    for name, pdf_path in pdf_dict.items():
        display_name = os.path.basename(pdf_path)
        document = upload_or_get_document(dataset, pdf_path, display_name)
        # BUG FIX: guard BEFORE dereferencing — the original printed
        # document.name first, which raised AttributeError whenever
        # upload_or_get_document returned None (failed upload).
        if not document:
            continue
        print(f"选择的文档: {document.name}ID: {document.id}")
        txt_path = txt_dict.get(name)
        if txt_path:
            process_txt_chunks(dataset.id, document, txt_path)
def main():
    """Pair PDFs with same-named TXTs under FILE_PATH and load them into a
    user-chosen dataset.

    Conventions used downstream:
        dataset.id  -> MinIO bucket name
        chunk id    -> MinIO object name
    """
    file_path = os.getenv("FILE_PATH")
    pdf_dict, txt_dict = pair_pdf_and_txt(file_path, file_path)
    if not pdf_dict:
        print("未选择任何文件。")
        return
    dataset = select_dataset(rag_object)
    # BUG FIX: validate the selection BEFORE dereferencing it — the original
    # printed dataset.name/dataset.id first, raising AttributeError when
    # select_dataset returned None (no datasets available).
    if not dataset:
        print("未选择数据集。")
        return
    print(f"选择的数据集: {dataset.name}")
    print(f"选择的数据集id: {dataset.id}")
    process_pdf_txt_pairs(pdf_dict, txt_dict, dataset)


if __name__ == "__main__":
    main()