Merge branch 'gitee-master-jdk17' into feature-project

# Conflicts:
#	README.md
#	yudao-framework/yudao-spring-boot-starter-web/src/main/java/cn/iocoder/yudao/framework/banner/core/BannerApplicationRunner.java
#	yudao-server/pom.xml
#	yudao-server/src/main/resources/application-dev.yaml
This commit is contained in:
2024-07-18 10:01:15 +08:00
269 changed files with 20598 additions and 1556 deletions

View File

@@ -43,4 +43,6 @@ public class ServiceErrorCodeRange {
// 模块 crm 错误码区间 [1-020-000-000 ~ 1-021-000-000)
// 模块 ai 错误码区间 [1-022-000-000 ~ 1-023-000-000)
}

View File

@@ -1,12 +1,13 @@
package cn.iocoder.yudao.framework.datapermission.config;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionAnnotationAdvisor;
import cn.iocoder.yudao.framework.datapermission.core.db.DataPermissionDatabaseInterceptor;
import cn.iocoder.yudao.framework.datapermission.core.db.DataPermissionRuleHandler;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactoryImpl;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Bean;
@@ -26,14 +27,15 @@ public class YudaoDataPermissionAutoConfiguration {
}
@Bean
public DataPermissionDatabaseInterceptor dataPermissionDatabaseInterceptor(MybatisPlusInterceptor interceptor,
DataPermissionRuleFactory ruleFactory) {
// 创建 DataPermissionDatabaseInterceptor 拦截器
DataPermissionDatabaseInterceptor inner = new DataPermissionDatabaseInterceptor(ruleFactory);
public DataPermissionRuleHandler dataPermissionRuleHandler(MybatisPlusInterceptor interceptor,
DataPermissionRuleFactory ruleFactory) {
// 创建 DataPermissionInterceptor 拦截器
DataPermissionRuleHandler handler = new DataPermissionRuleHandler(ruleFactory);
DataPermissionInterceptor inner = new DataPermissionInterceptor(handler);
// 添加到 interceptor 中
// 需要加在首个,主要是为了在分页插件前面。这个是 MyBatis Plus 的规定
MyBatisUtils.addInterceptor(interceptor, inner, 0);
return inner;
return handler;
}
@Bean

View File

@@ -1,641 +0,0 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
* 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
*
* 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
* 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
*
* @author 芋道源码
*/
@RequiredArgsConstructor
public class DataPermissionDatabaseInterceptor extends JsqlParserSupport implements InnerInterceptor {
private final DataPermissionRuleFactory ruleFactory;
@Getter
private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
@Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
processSelectBody(select.getSelectBody());
List<WithItem> withItemsList = select.getWithItemsList();
if (!CollectionUtils.isEmpty(withItemsList)) {
withItemsList.forEach(this::processSelectBody);
}
}
/**
* update 语句处理
*/
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
final Table table = update.getTable();
update.setWhere(this.builderExpression(update.getWhere(), table));
}
/**
* delete 语句处理
*/
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
}
// ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(this::processSelectBody);
}
}
}
/**
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
mainTables = processJoins(mainTables, joins);
}
// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}
private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}
/**
* 处理where条件内的子查询
* <p>
* 支持如下:
* 1. in
* 2. =
* 3. >
* 4. <
* 5. >=
* 6. <=
* 7. <>
* 8. EXISTS
* 9. NOT EXISTS
* <p>
* 前提条件:
* 1. 子查询必须放在小括号中
* 2. 子查询一般放在比较操作符的右边
*
* @param where where 条件
*/
protected void processWhereSubSelect(Expression where) {
if (where == null) {
return;
}
if (where instanceof FromItem) {
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
// 有子查询
if (where instanceof BinaryExpression) {
// 比较符号 , and , or , 等等
BinaryExpression expression = (BinaryExpression) where;
processWhereSubSelect(expression.getLeftExpression());
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof InExpression) {
// in
InExpression expression = (InExpression) where;
Expression inExpression = expression.getRightExpression();
if (inExpression instanceof SubSelect) {
processSelectBody(((SubSelect) inExpression).getSelectBody());
}
} else if (where instanceof ExistsExpression) {
// exists
ExistsExpression expression = (ExistsExpression) where;
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof NotExpression) {
// not exists
NotExpression expression = (NotExpression) where;
processWhereSubSelect(expression.getExpression());
} else if (where instanceof Parenthesis) {
Parenthesis expression = (Parenthesis) where;
processWhereSubSelect(expression.getExpression());
}
}
}
protected void processSelectItem(SelectItem selectItem) {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
if (selectExpressionItem.getExpression() instanceof SubSelect) {
processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
} else if (selectExpressionItem.getExpression() instanceof Function) {
processFunction((Function) selectExpressionItem.getExpression());
}
}
}
/**
* 处理函数
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p>
*
* @param function
*/
protected void processFunction(Function function) {
ExpressionList parameters = function.getParameters();
if (parameters != null) {
parameters.getExpressions().forEach(expression -> {
if (expression instanceof SubSelect) {
processSelectBody(((SubSelect) expression).getSelectBody());
} else if (expression instanceof Function) {
processFunction((Function) expression);
}
});
}
}
/**
* 处理子查询等
*/
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
} else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subQuery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) {
SubSelect subSelect = lateralSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
}
}
}
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}
/**
* 处理 joins
*
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables == null) {
mainTables = new ArrayList<>();
} else if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join需要记录下前面多个 on 的表名
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem joinItem = join.getRightItem();
// 获取当前 join 的表subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
onTables = Collections.singletonList(joinTable);
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
// 表名压栈,忽略的表压入 null以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
processOtherFromItem(joinItem);
leftTable = null;
}
}
return mainTables;
}
// ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param table 单个表
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
return this.builderExpression(currentExpression, Collections.singletonList(table));
}
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param tables 多个表
*/
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 第一步,获得 Table 对应的数据权限条件
Expression dataPermissionExpression = null;
for (Table table : tables) {
// 构建每个表的权限 Expression 条件
Expression expression = buildDataPermissionExpression(table);
if (expression == null) {
continue;
}
// 合并到 dataPermissionExpression 中
dataPermissionExpression = dataPermissionExpression == null ? expression
: new AndExpression(dataPermissionExpression, expression);
}
// 第二步,合并多个 Expression 条件
if (dataPermissionExpression == null) {
return currentExpression;
}
if (currentExpression == null) {
return dataPermissionExpression;
}
// ① 如果表达式为 Or则需要 (currentExpression) AND dataPermissionExpression
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
}
// ② 如果表达式为 And则直接返回 where AND dataPermissionExpression
return new AndExpression(currentExpression, dataPermissionExpression);
}
/**
* 构建指定表的数据权限的 Expression 过滤条件
*
* @param table 表
* @return Expression 过滤条件
*/
private Expression buildDataPermissionExpression(Table table) {
// 生成条件
Expression allExpression = null;
for (DataPermissionRule rule : ContextHolder.getRules()) {
// 判断表名是否匹配
String tableName = MyBatisUtils.getTableName(table);
if (!rule.getTableNames().contains(tableName)) {
continue;
}
// 如果有匹配的规则,说明可重写。
// 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
// 这样导致第一次无 value被标记成无需重写但是第二次有 value此时会需要重写。
ContextHolder.setRewrite(true);
// 单条规则的条件
Expression oneExpress = rule.getExpression(tableName, table.getAlias());
if (oneExpress == null){
continue;
}
// 拼接到 allExpression 中
allExpression = allExpression == null ? oneExpress
: new AndExpression(allExpression, oneExpress);
}
return allExpression;
}
/**
* 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
*
* @param ms MappedStatement
*/
private void addMappedStatementCache(MappedStatement ms) {
if (ContextHolder.getRewrite()) {
return;
}
// 无重写,进行添加
mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
}
/**
* SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
*
* @author 芋道源码
*/
static final class ContextHolder {
/**
* 该 {@link MappedStatement} 对应的规则
*/
private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
/**
* SQL 是否进行重写
*/
private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
public static void init(List<DataPermissionRule> rules) {
RULES.set(rules);
REWRITE.set(false);
}
public static void clear() {
RULES.remove();
REWRITE.remove();
}
public static boolean getRewrite() {
return REWRITE.get();
}
public static void setRewrite(boolean rewrite) {
REWRITE.set(rewrite);
}
public static List<DataPermissionRule> getRules() {
return RULES.get();
}
}
/**
* {@link MappedStatement} 缓存
* 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
* 如果无效,则可以避免 SQL 的解析,加快速度
*
* @author 芋道源码
*/
static final class MappedStatementCache {
/**
* 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
*
* value{@link MappedStatement#getId()} 编号
*/
@Getter
private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
/**
* 判断是否无需重写
* ps虽然有点中文式英语但是容易读懂即可
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
* @return 是否无需重写
*/
public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
// 如果规则为空,说明无需重写
if (CollUtil.isEmpty(rules)) {
return true;
}
// 任一规则不在 noRewritableMap 中,则说明可能需要重写
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (!CollUtil.contains(mappedStatementIds, ms.getId())) {
return false;
}
}
return true;
}
/**
* 添加无需重写的 MappedStatement
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
*/
public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (CollUtil.isNotEmpty(mappedStatementIds)) {
mappedStatementIds.add(ms.getId());
} else {
noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
}
}
}
/**
* 清空缓存
* 目前主要提供给单元测试
*/
public void clear() {
noRewritableMappedStatements.clear();
}
}
}

View File

@@ -0,0 +1,57 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.schema.Table;
import java.util.List;
/**
* 基于 {@link DataPermissionRule} 的数据权限处理器
*
* 它的底层,是基于 MyBatis Plus 的 <a href="https://baomidou.com/plugins/data-permission/">数据权限插件</a>
* 核心原理:它会在 SQL 执行前拦截 SQL 语句,并根据用户权限动态添加权限相关的 SQL 片段。这样,只有用户有权限访问的数据才会被查询出来
*
* @author 芋道源码
*/
@RequiredArgsConstructor
public class DataPermissionRuleHandler implements MultiDataPermissionHandler {
private final DataPermissionRuleFactory ruleFactory;
@Override
public Expression getSqlSegment(Table table, Expression where, String mappedStatementId) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(mappedStatementId);
if (CollUtil.isEmpty(rules)) {
return null;
}
// 生成条件
Expression allExpression = null;
for (DataPermissionRule rule : rules) {
// 判断表名是否匹配
String tableName = MyBatisUtils.getTableName(table);
if (!rule.getTableNames().contains(tableName)) {
continue;
}
// 单条规则的条件
Expression oneExpress = rule.getExpression(tableName, table.getAlias());
if (oneExpress == null) {
continue;
}
// 拼接到 allExpression 中
allExpression = allExpression == null ? oneExpress
: new AndExpression(allExpression, oneExpress);
}
return allExpression;
}
}

View File

@@ -156,7 +156,8 @@ public class DeptDataPermissionRule implements DataPermissionRule {
}
// 拼接条件
return new InExpression(MyBatisUtils.buildColumn(tableName, tableAlias, columnName),
new ExpressionList(CollectionUtils.convertList(deptIds, LongValue::new)));
// Parenthesis 的目的,是提供 (1,2,3) 的 () 左右括号
new Parenthesis(new ExpressionList<>(CollectionUtils.convertList(deptIds, LongValue::new))));
}
private Expression buildUserExpression(String tableName, Alias tableAlias, Boolean self, Long userId) {

View File

@@ -1,190 +0,0 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import java.sql.Connection;
import java.util.*;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* {@link DataPermissionDatabaseInterceptor} 的单元测试
* 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
* 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
* 以及在这个过程中ContextHolder 和 MappedStatementCache
*
* @author 芋道源码
*/
public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionDatabaseInterceptor interceptor;
@Mock
private DataPermissionRuleFactory ruleFactory;
@BeforeEach
public void setUp() {
// 清理上下文
DataPermissionDatabaseInterceptor.ContextHolder.clear();
// 清空缓存
interceptor.getMappedStatementCache().clear();
}
@Test // 不存在规则,且不匹配
public void testBeforeQuery_withoutRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
}
}
@Test // 存在规则,且不匹配
public void testBeforeQuery_withMatchRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// mock 方法(数据权限)
when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
.thenReturn(singletonList(new DeptDataPermissionRule()));
// mock 方法(MPBoundSql)
PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
// mock 方法(SQL)
String sql = "select * from t_user where id = 1";
when(mpBs.sql()).thenReturn(sql);
// 针对 ContextHolder 和 MappedStatementCache 暂时不 mock主要想校验过程中数据是否正确
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
verify(mpBs, times(1)).sql(
eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
// 断言缓存
assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
}
}
@Test // 存在规则,但不匹配
public void testBeforeQuery_withoutMatchRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// mock 方法(数据权限)
when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
.thenReturn(singletonList(new DeptDataPermissionRule()));
// mock 方法(MPBoundSql)
PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
// mock 方法(SQL)
String sql = "select * from t_role where id = 1";
when(mpBs.sql()).thenReturn(sql);
// 针对 ContextHolder 和 MappedStatementCache 暂时不 mock主要想校验过程中数据是否正确
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
verify(mpBs, times(1)).sql(
eq("SELECT * FROM t_role WHERE id = 1"));
// 断言缓存
assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
}
}
@Test
public void testAddNoRewritable() {
// 准备参数
MappedStatement ms = mock(MappedStatement.class);
List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
// mock 方法
when(ms.getId()).thenReturn("selectById");
// 调用
interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
// 断言
Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements =
interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
assertEquals(1, noRewritableMappedStatements.size());
assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
}
@Test
public void testNoRewritable() {
// 准备参数
MappedStatement ms = mock(MappedStatement.class);
// mock 方法
when(ms.getId()).thenReturn("selectById");
// mock 数据
List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
// 场景一rules 为空
assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
// 场景二rules 非空,可重写
assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
// 场景三rule 非空,不可重写
assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
}
private static class DeptDataPermissionRule implements DataPermissionRule {
private static final String COLUMN = "dept_id";
@Override
public Set<String> getTableNames() {
return SetUtils.asSet("t_user");
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
LongValue value = new LongValue(100L);
return new EqualsTo(column, value);
}
}
private static class EmptyDataPermissionRule implements DataPermissionRule {
@Override
public Set<String> getTableNames() {
return Collections.emptySet();
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
return null;
}
}
}

View File

@@ -4,9 +4,11 @@ import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
@@ -21,24 +23,30 @@ import java.util.Set;
import static cn.iocoder.yudao.framework.common.util.collection.SetUtils.asSet;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
/**
* {@link DataPermissionDatabaseInterceptor} 的单元测试
* {@link DataPermissionRuleHandler} 的单元测试
* 主要复用了 MyBatis Plus TenantLineInnerInterceptorTest 的单元测试
* 不过它的单元测试不是很规范考虑到是复用的所以暂时不进行修改~
*
* @author 芋道源码
*/
public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest {
public class DataPermissionRuleHandlerTest extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionDatabaseInterceptor interceptor;
private DataPermissionRuleHandler handler;
@Mock
private DataPermissionRuleFactory ruleFactory;
private DataPermissionInterceptor interceptor;
@BeforeEach
public void setUp() {
interceptor = new DataPermissionInterceptor(handler);
// 租户的数据权限规则
DataPermissionRule tenantRule = new DataPermissionRule() {
@@ -71,14 +79,14 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
ExpressionList values = new ExpressionList(new LongValue(10L),
ExpressionList<LongValue> values = new ExpressionList<>(new LongValue(10L),
new LongValue(20L));
return new InExpression(column, values);
return new InExpression(column, new Parenthesis((values)));
}
};
// 设置到上下文保证
DataPermissionDatabaseInterceptor.ContextHolder.init(Arrays.asList(tenantRule, deptRule));
// 设置到上下文
when(ruleFactory.getDataPermissionRule(any())).thenReturn(Arrays.asList(tenantRule, deptRule));
}
@Test
@@ -262,7 +270,7 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
@@ -447,7 +455,6 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertEquals(targetSql, interceptor.parserSingle(sql, null));
}
// ========== 额外的测试 ==========
@Test

View File

@@ -28,7 +28,7 @@ public class AreaUtilsTest {
@Test
public void testFormat() {
assertEquals(AreaUtils.format(110105), "北京 北京市 朝阳区");
assertEquals(AreaUtils.format(110105), "北京 北京市 朝阳区");
assertEquals(AreaUtils.format(1), "中国");
assertEquals(AreaUtils.format(2), "蒙古");
}

View File

@@ -53,6 +53,16 @@
<artifactId>DmJdbcDriver18</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>cn.com.kingbase</groupId>
<artifactId>kingbase8</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.opengauss</groupId>
<artifactId>opengauss-jdbc</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>

View File

@@ -0,0 +1,84 @@
package cn.iocoder.yudao.framework.mybatis.core.enums;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.annotation.DbType;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* 针对 MyBatis Plus 的 {@link DbType} 增强,补充更多信息
*/
@Getter
@AllArgsConstructor
public enum DbTypeEnum {
/**
* MySQL
*/
MY_SQL( DbType.MYSQL, "MySQL", "FIND_IN_SET('#{value}', #{column}) <> 0"),
/**
* Oracle
*/
ORACLE(DbType.ORACLE, "Oracle", "FIND_IN_SET('#{value}', #{column}) <> 0"),
/**
* PostgreSQL
*
* 华为 openGauss 使用 ProductName 与 PostgreSQL 相同
*/
POSTGRE_SQL(DbType.POSTGRE_SQL,"PostgreSQL", "POSITION('#{value}' IN #{column}) <> 0"),
/**
* SQL Server
*/
SQL_SERVER(DbType.SQL_SERVER, "Microsoft SQL Server", "CHARINDEX(',' + #{value} + ',', ',' + #{column} + ',') <> 0"),
/**
* 达梦
*/
DM(DbType.DM, "DM DBMS", "FIND_IN_SET('#{value}', #{column}) <> 0"),
/**
* 人大金仓
*/
KINGBASE_ES(DbType.KINGBASE_ES, "KingbaseES", "POSITION('#{value}' IN #{column}) <> 0"),
;
public static final Map<String, DbTypeEnum> MAP_BY_NAME = Arrays.stream(values())
.collect(Collectors.toMap(DbTypeEnum::getProductName, Function.identity()));
public static final Map<DbType, DbTypeEnum> MAP_BY_MP = Arrays.stream(values())
.collect(Collectors.toMap(DbTypeEnum::getMpDbType, Function.identity()));
/**
* MyBatis Plus 类型
*/
private final DbType mpDbType;
/**
* 数据库产品名
*/
private final String productName;
/**
* SQL FIND_IN_SET 模板
*/
private final String findInSetTemplate;
public static DbType find(String databaseProductName) {
if (StrUtil.isBlank(databaseProductName)) {
return null;
}
return MAP_BY_NAME.get(databaseProductName).getMpDbType();
}
public static String getFindInSetTemplate(DbType dbType) {
return Optional.of(MAP_BY_MP.get(dbType).getFindInSetTemplate())
.orElseThrow(() -> new IllegalArgumentException("FIND_IN_SET not supported"));
}
}

View File

@@ -185,8 +185,8 @@ public interface BaseMapperX<T> extends MPJBaseMapper<T> {
return Db.updateBatchById(entities, size);
}
default Boolean insertOrUpdate(T entity) {
return Db.saveOrUpdate(entity);
default boolean insertOrUpdate(T entity) {
return Db.saveOrUpdate(entity);
}
default Boolean insertOrUpdateBatch(Collection<T> collection) {

View File

@@ -1,31 +0,0 @@
package cn.iocoder.yudao.framework.mybatis.core.type;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import com.baomidou.mybatisplus.extension.handlers.AbstractJsonTypeHandler;
import com.fasterxml.jackson.core.type.TypeReference;
import java.util.Set;
/**
* 参考 {@link com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler} 实现
* 在我们将字符串反序列化为 Set 并且泛型为 Long 时,如果每个元素的数值太小,会被处理成 Integer 类型,导致可能存在隐性的 BUG。
*
* 例如说哦SysUserDO 的 postIds 属性
*
* @author 芋道源码
*/
public class JsonLongSetTypeHandler extends AbstractJsonTypeHandler<Object> {
private static final TypeReference<Set<Long>> TYPE_REFERENCE = new TypeReference<Set<Long>>(){};
@Override
protected Object parse(String json) {
return JsonUtils.parseObject(json, TYPE_REFERENCE);
}
@Override
protected String toJson(Object obj) {
return JsonUtils.toJsonString(obj);
}
}

View File

@@ -1,9 +1,14 @@
package cn.iocoder.yudao.framework.mybatis.core.util;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import cn.iocoder.yudao.framework.mybatis.core.enums.DbTypeEnum;
import com.baomidou.dynamic.datasource.DynamicRoutingDataSource;
import com.baomidou.mybatisplus.annotation.DbType;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
/**
* JDBC 工具类
@@ -35,8 +40,22 @@ public class JdbcUtils {
* @return DB 类型
*/
public static DbType getDbType(String url) {
String name = com.alibaba.druid.util.JdbcUtils.getDbType(url, null);
return DbType.getDbType(name);
return com.baomidou.mybatisplus.extension.toolkit.JdbcUtils.getDbType(url);
}
/**
* 通过当前数据库连接获得对应的 DB 类型
*
* @return DB 类型
*/
public static DbType getDbType() {
DynamicRoutingDataSource dynamicRoutingDataSource = SpringUtils.getBean(DynamicRoutingDataSource.class);
DataSource dataSource = dynamicRoutingDataSource.determineDataSource();
try (Connection conn = dataSource.getConnection()) {
return DbTypeEnum.find(conn.getMetaData().getDatabaseProductName());
} catch (SQLException e) {
throw new IllegalArgumentException(e.getMessage());
}
}
}

View File

@@ -1,8 +1,11 @@
package cn.iocoder.yudao.framework.mybatis.core.util;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.SortingField;
import cn.iocoder.yudao.framework.mybatis.core.enums.DbTypeEnum;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
@@ -34,7 +37,7 @@ public class MyBatisUtils {
// 排序字段
if (!CollectionUtil.isEmpty(sortingFields)) {
page.addOrder(sortingFields.stream().map(sortingField -> SortingField.ORDER_ASC.equals(sortingField.getOrder()) ?
OrderItem.asc(sortingField.getField()) : OrderItem.desc(sortingField.getField()))
OrderItem.asc(sortingField.getField()) : OrderItem.desc(sortingField.getField()))
.collect(Collectors.toList()));
}
return page;
@@ -56,7 +59,7 @@ public class MyBatisUtils {
/**
* 获得 Table 对应的表名
*
* <p>
* 兼容 MySQL 转义表名 `t_xxx`
*
* @param table 表
@@ -85,4 +88,19 @@ public class MyBatisUtils {
return new Column(tableName + StringPool.DOT + column);
}
/**
* 跨数据库的 find_in_set 实现
*
* @param column 字段名称
* @param value 查询值(不带单引号)
* @return sql
*/
public static String findInSet(String column, Object value) {
// 这里不用SqlConstants.DB_TYPE因为它是使用 primary 数据源的 url 推断出来的类型
DbType dbType = JdbcUtils.getDbType();
return DbTypeEnum.getFindInSetTemplate(dbType)
.replace("#{column}", column)
.replace("#{value}", StrUtil.toString(value));
}
}