Mybatis攔截器之數據權限過濾與分頁集成

需求場景

最近項目有個數據權限的業務需求,要求大體爲每一個單位只能查看本級單位及下屬單位的數據,例如:一個集團軍下屬十二個旅,那麼軍級用戶能夠看到全部數據,而每一個旅則只能看到本旅部的數據,以此類推;java

解決方案之改SQL

原sql

SELECT
	a.id AS "id",
	a.NAME AS "name",
	a.sex_cd AS "sexCd",
	a.org_id AS "orgId",
	a.STATUS AS "status",
	a.create_org_id AS "createOrgId"
FROM
	pty_person a
WHERE
	a. STATUS = 0

org_id是單位的標識,也就是where條件裏再加個單位標識的過濾。mysql

改後sql

SELECT
	a.id AS "id",
	a.NAME AS "name",
	a.sex_cd AS "sexCd",
	a.org_id AS "orgId",
	a.STATUS AS "status",
	a.create_org_id AS "createOrgId"
FROM
	pty_person a
WHERE
	a. STATUS = 0
	and a.org_id LIKE concat(710701070102, '%')

固然經過這個辦法也能夠實現數據的過濾,但這樣的話相比你們也都有同感,那就是每一個業務模塊 每一個人都要進行SQL改動,此次是根據單位過濾、明天又再根據其餘的屬性過濾,意味着要不停的改來改去,可謂是場面壯觀也,並且這種集體改造耗費了時間精力不說,還會有不少不肯定因素,好比SQL寫錯,存在漏網之魚等等。所以這個解決方案確定是直接PASS掉咯;spring

解決方案之攔截器

因爲項目大部分採用的持久層框架是Mybatis,也是使用的Mybatis進行分頁攔截處理,所以直接採用了Mybatis攔截器實現數據權限過濾。sql

一、自定義數據權限過濾註解 PermissionAop,負責過濾的開關 

package com.raising.framework.annotation;

import java.lang.annotation.*;

/**
 * 數據權限過濾自定義註解
 * @author lihaoshan
 * @date 2018-07-19
 * */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface PermissionAop {

    String value() default "";
}

二、定義全局配置 PermissionConfig 類加載 權限過濾配置文件

package com.raising.framework.config;

import com.raising.utils.PropertiesLoader;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * 全局配置
 * 對應 permission.properties
 * @author lihaoshan
 */
public class PermissionConfig {
    private static Logger logger = LoggerFactory.getLogger(PropertiesLoader.class);

    /**
     * 保存全局屬性值
     */
    private static Map<String, String> map = new HashMap<>(16);

    /**
     * 屬性文件加載對象
     */
    private static PropertiesLoader loader = new PropertiesLoader(
            "permission.properties");

    /**
     * 獲取配置
     */
    public static String getConfig(String key) {
        if(loader == null){
            logger.info("缺失配置文件 - permission.properties");
            return null;
        }
        String value = map.get(key);
        if (value == null) {
            value = loader.getProperty(key);
            map.put(key, value != null ? value : StringUtils.EMPTY);
        }
        return value;
    }

}

三、建立權限過濾的配置文件 permission.properties,用於配置須要攔截的DAO的 namespace

(因爲註解@PermissionAop是加在DAO層某個接口上的,而咱們分頁接口爲封裝的公共BaseDAO,因此若是僅僅使用註解方式開關攔截的話,會影響到全部的業務模塊,所以須要結合額外的配置文件)數據庫

# 須要進行攔截的SQL所屬namespace
permission.intercept.namespace=com.raising.modules.pty.dao.PtyGroupDao,com.raising.modules.pty.dao.PtyPersonDao

四、自定義權限工具類

根據 StatementHandler 獲取Permission註解對象:

package com.raising.utils.permission;

import com.raising.framework.annotation.PermissionAop;
import org.apache.ibatis.mapping.MappedStatement;

import java.lang.reflect.Method;

/**
 * 自定義權限相關工具類
 * @author lihaoshan
 * @date 2018-07-20
 * */
public class PermissionUtils {

    /**
     * 根據 StatementHandler 獲取 註解對象
     * @author lihaoshan
     * @date 2018-07-20
     */
    public static PermissionAop getPermissionByDelegate(MappedStatement mappedStatement){
        PermissionAop permissionAop = null;
        try {
            String id = mappedStatement.getId();
            String className = id.substring(0, id.lastIndexOf("."));
            String methodName = id.substring(id.lastIndexOf(".") + 1, id.length());
            final Class cls = Class.forName(className);
            final Method[] method = cls.getMethods();
            for (Method me : method) {
                if (me.getName().equals(methodName) && me.isAnnotationPresent(PermissionAop.class)) {
                    permissionAop = me.getAnnotation(PermissionAop.class);
                }
            }
        }catch (Exception e){
            e.printStackTrace();
        }
        return permissionAop;
    }
}

五、建立分頁攔截器 MybatisSpringPageInterceptor 或進行改造(本文是在Mybatis分頁攔截器基礎上進行的數據權限攔截改造,SQL包裝必定要在執行分頁以前,也就是獲取到原始SQL後就進行數據過濾包裝) 

首先看數據權限攔截核心代碼:apache

  • 獲取須要進行攔截的DAO層namespace拼接串;
  • 獲取當前mapped所屬namespace;
  • 判斷配置文件中的namespace是否包含當前的mapped所屬的namespace,若是包含則繼續,不然直接放行;
  • 獲取數據權限註解對象,及註解的值;
  • 判斷註解值是否爲DATA_PERMISSION_INTERCEPT,是則攔截、並進行過濾SQL包裝,不然放行;
  • 根據包裝後的SQL查分頁總數,不能使用原始SQL進行查詢;
  • 執行請求方法,獲取攔截後的分頁結果;

執行流程圖:微信

攔截器源碼:

package com.raising.framework.interceptor;

import com.raising.StaticParam;
import com.raising.framework.annotation.PermissionAop;
import com.raising.framework.config.PermissionConfig;
import com.raising.modules.sys.entity.User;
import com.raising.utils.JStringUtils;
import com.raising.utils.UserUtils;
import com.raising.utils.permission.PermissionUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * 分頁攔截器
 * @author GaoYuan
 * @author lihaoshan 增長了數據權限的攔截過濾
 * @datetime 2017/12/1 下午5:43
 */
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
        @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class MybatisSpringPageInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(MybatisSpringPageInterceptor.class);

    public static final String MYSQL = "mysql";
    public static final String ORACLE = "oracle";
    /**數據庫類型,不一樣的數據庫有不一樣的分頁方法*/
    protected String databaseType;

    @SuppressWarnings("rawtypes")
    protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();

    public String getDatabaseType() {
        return databaseType;
    }

    public void setDatabaseType(String databaseType) {
        if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
            throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
        }
        this.databaseType = databaseType;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        String databaseType = properties.getProperty("databaseType");
        if (databaseType != null) {
            setDatabaseType(databaseType);
        }
    }

    @Override
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public Object intercept(Invocation invocation) throws Throwable {
        // 控制SQL和查詢總數的地方
        if (invocation.getTarget() instanceof StatementHandler) {
            Page page = pageThreadLocal.get();
            //不是分頁查詢
            if (page == null) {
                return invocation.proceed();
            }

            RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
            StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
            BoundSql boundSql = delegate.getBoundSql();

            Connection connection = (Connection) invocation.getArgs()[0];
            // 準備數據庫類型
            prepareAndCheckDatabaseType(connection);
            MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");

            String sql = boundSql.getSql();

            /** 單位數據權限攔截 begin */
            //獲取須要進行攔截的DAO層namespace拼接串
            String interceptNamespace = PermissionConfig.getConfig("permission.intercept.namespace");

            //獲取當前mapped的namespace
            String mappedStatementId = mappedStatement.getId();
            String className = mappedStatementId.substring(0, mappedStatementId.lastIndexOf("."));

            if(JStringUtils.isNotBlank(interceptNamespace)){
                //判斷配置文件中的namespace是否與當前的mapped namespace匹配,若是包含則進行攔截,不然放行
                if(interceptNamespace.contains(className)){
                    //獲取數據權限註解對象
                    PermissionAop permissionAop = PermissionUtils.getPermissionByDelegate(mappedStatement);
                    if (permissionAop != null){
                        //獲取註解的值
                        String permissionAopValue = permissionAop.value();
                        //判斷註解是否開啓攔截
                        if(StaticParam.DATA_PERMISSION_INTERCEPT.equals(permissionAopValue) ){
                            if(log.isInfoEnabled()){
                                log.info("數據權限攔截【拼接SQL】...");
                            }
                            //返回攔截包裝後的sql
                            sql = permissionSql(sql);
                            ReflectUtil.setFieldValue(boundSql, "sql", sql);
                        } else {
                            if(log.isInfoEnabled()){
                                log.info("數據權限放行...");
                            }
                        }
                    }

                }
            }
            /** 單位數據權限攔截 end */

            if (page.getTotalPage() > -1) {
                if (log.isTraceEnabled()) {
                    log.trace("已經設置了總頁數, 不須要再查詢總數.");
                }
            } else {
                Object parameterObj = boundSql.getParameterObject();
///                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
                queryTotalRecord(page, parameterObj, mappedStatement, sql,connection);
            }

            String pageSql = buildPageSql(page, sql);
            if (log.isDebugEnabled()) {
                log.debug("分頁時, 生成分頁pageSql......");
            }
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);

            return invocation.proceed();
        } else { // 查詢結果的地方
            // 獲取是否有分頁Page對象
            Page<?> page = findPageObject(invocation.getArgs()[1]);
            if (page == null) {
                if (log.isTraceEnabled()) {
                    log.trace("沒有Page對象做爲參數, 不是分頁查詢.");
                }
                return invocation.proceed();
            } else {
                if (log.isTraceEnabled()) {
                    log.trace("檢測到分頁Page對象, 使用分頁查詢.");
                }
            }
            //設置真正的parameterObj
            invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);

            pageThreadLocal.set(page);
            try {
                // Executor.query(..)
                Object resultObj = invocation.proceed();
                if (resultObj instanceof List) {
                    /* @SuppressWarnings({ "unchecked", "rawtypes" }) */
                    page.setResults((List) resultObj);
                }
                return resultObj;
            } finally {
                pageThreadLocal.remove();
            }
        }
    }

    protected Page<?> findPageObject(Object parameterObj) {
        if (parameterObj instanceof Page<?>) {
            return (Page<?>) parameterObj;
        } else if (parameterObj instanceof Map) {
            for (Object val : ((Map<?, ?>) parameterObj).values()) {
                if (val instanceof Page<?>) {
                    return (Page<?>) val;
                }
            }
        }
        return null;
    }

    /**
     * <pre>
     * 把真正的參數對象解析出來
     * Spring會自動封裝對個參數對象爲Map<String, Object>對象
     * 對於經過@Param指定key值參數咱們不作處理,由於XML文件須要該KEY值
     * 而對於沒有@Param指定時,Spring會使用0,1做爲主鍵
     * 對於沒有@Param指定名稱的參數,通常XML文件會直接對真正的參數對象解析,
     * 此時解析出真正的參數做爲根對象
     * </pre>
     * @param parameterObj
     * @return
     */
    protected Object extractRealParameterObject(Object parameterObj) {
        if (parameterObj instanceof Map<?, ?>) {
            Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
            if (parameterMap.size() == 2) {
                boolean springMapWithNoParamName = true;
                for (Object key : parameterMap.keySet()) {
                    if (!(key instanceof String)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                    String keyStr = (String) key;
                    if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                }
                if (springMapWithNoParamName) {
                    for (Object value : parameterMap.values()) {
                        if (!(value instanceof Page<?>)) {
                            return value;
                        }
                    }
                }
            }
        }
        return parameterObj;
    }

    protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
        if (databaseType == null) {
            String productName = connection.getMetaData().getDatabaseProductName();
            if (log.isTraceEnabled()) {
                log.trace("Database productName: " + productName);
            }
            productName = productName.toLowerCase();
            if (productName.indexOf(MYSQL) != -1) {
                databaseType = MYSQL;
            } else if (productName.indexOf(ORACLE) != -1) {
                databaseType = ORACLE;
            } else {
                throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
            }
            if (log.isInfoEnabled()) {
                log.info("自動檢測到的數據庫類型爲: " + databaseType);
            }
        }
    }

    /**
     * <pre>
     * 生成分頁SQL
     * </pre>
     *
     * @param page
     * @param sql
     * @return
     */
    protected String buildPageSql(Page<?> page, String sql) {
        if (MYSQL.equalsIgnoreCase(databaseType)) {
            return buildMysqlPageSql(page, sql);
        } else if (ORACLE.equalsIgnoreCase(databaseType)) {
            return buildOraclePageSql(page, sql);
        }
        return sql;
    }

    /**
     * <pre>
     * 生成Mysql分頁查詢SQL
     * </pre>
     *
     * @param page
     * @param sql
     * @return
     */
    protected String buildMysqlPageSql(Page<?> page, String sql) {
        // 計算第一條記錄的位置,Mysql中記錄的位置是從0開始的。
        int offset = (page.getPageNo() - 1) * page.getPageSize();
        if(offset<0){
            return " limit 0 ";
        }
        return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
    }

    /**
     * <pre>
     * 生成Oracle分頁查詢SQL
     * </pre>
     *
     * @param page
     * @param sql
     * @return
     */
    protected String buildOraclePageSql(Page<?> page, String sql) {
        // 計算第一條記錄的位置,Oracle分頁是經過rownum進行的,而rownum是從1開始的
        int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
        StringBuilder sb = new StringBuilder(sql);
        sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
        sb.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sb.toString();
    }

    /**
     * <pre>
     * 查詢總數
     * </pre>
     *
     * @param page
     * @param parameterObject
     * @param mappedStatement
     * @param sql
     * @param connection
     * @throws SQLException
     */
    protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, String sql, Connection connection) throws SQLException {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
///        String sql = boundSql.getSql();

        String countSql = this.buildCountSql(sql);
        if (log.isDebugEnabled()) {
            log.debug("分頁時, 生成countSql......");
        }

        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                long totalRecord = rs.getLong(1);
                page.setTotalRecord(totalRecord);
            }
        } finally {
            if (rs != null) {
                try {
                    rs.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("關閉ResultSet時異常.", e);
                    }
                }
            }
            if (pstmt != null) {
                try {
                    pstmt.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("關閉PreparedStatement時異常.", e);
                    }
                }
            }
        }
    }

    /**
     * 根據原Sql語句獲取對應的查詢總記錄數的Sql語句
     *
     * @param sql
     * @return
     */
    protected String buildCountSql(String sql) {
        //查出第一個from,先轉成小寫
        sql = sql.toLowerCase();
        int index = sql.indexOf("from");
        return "select count(0) " + sql.substring(index);
    }

    /**
     * 利用反射進行操做的一個工具類
     *
     */
    private static class ReflectUtil {
        /**
         * 利用反射獲取指定對象的指定屬性
         *
         * @param obj 目標對象
         * @param fieldName 目標屬性
         * @return 目標屬性的值
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                field.setAccessible(true);
                try {
                    result = field.get(obj);
                } catch (IllegalArgumentException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
            return result;
        }

        /**
         * 利用反射獲取指定對象裏面的指定屬性
         *
         * @param obj 目標對象
         * @param fieldName 目標屬性
         * @return 目標字段
         */
        private static Field getField(Object obj, String fieldName) {
            Field field = null;
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                try {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                } catch (NoSuchFieldException e) {
                    // 這裏不用作處理,子類沒有該字段可能對應的父類有,都沒有就返回null。
                }
            }
            return field;
        }

        /**
         * 利用反射設置指定對象的指定屬性爲指定的值
         *
         * @param obj 目標對象
         * @param fieldName 目標屬性
         * @param fieldValue 目標值
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                try {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                } catch (IllegalArgumentException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
    }

    public static class PageNotSupportException extends RuntimeException {

        /** serialVersionUID*/
        private static final long serialVersionUID = 1L;

        public PageNotSupportException() {
            super();
        }

        public PageNotSupportException(String message, Throwable cause) {
            super(message, cause);
        }

        public PageNotSupportException(String message) {
            super(message);
        }

        public PageNotSupportException(Throwable cause) {
            super(cause);
        }
    }

    /**
     * 數據權限sql包裝【只能查看本級單位及下屬單位的數據】
     * @author lihaoshan
     * @date 2018-07-19
     */
    protected String permissionSql(String sql) {
        StringBuilder sbSql = new StringBuilder(sql);
        //獲取當前登陸人
        User user = UserUtils.getLoginUser();
        String orgId =null;
        if (user != null) {
            //獲取當前登陸人所屬單位標識
            orgId = user.getOrganizationId();
        }
        //若是有動態參數 orgId
        if(orgId != null){
            sbSql = new StringBuilder("select * from (")
                    .append(sbSql)
                    .append(" ) s ")
                    .append(" where s.createOrgId like concat("+ orgId +",'%') ");
        }
        return sbSql.toString();
    }
}

至此,Mybatis攔截器改造已完成,感謝各位大佬的耐性閱讀,有什麼問題和建議歡迎各位大佬留言,以此互相借鑑學習!

歡迎各位大佬關注個人我的微信訂閱號:session

相關文章
相關標籤/搜索