MyBatis 动态分表

Subtable.java

package com.xxxxx.common.annotations;

import java.lang.annotation.*;

 /**
  * Marks a MyBatis mapper interface whose SQL must be rewritten to target physical (sharded) tables.
  *
  * <p>Read at runtime (retention {@code RUNTIME}) by the MyBatis interceptor: for each statement of an
  * annotated mapper, every logical name in {@link #tableName()} is replaced with the table name
  * produced by the strategy registered under {@link #strategy()}.
  *
  * @author zeng026
  */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface Subtable {
    /**
     * Logical table names used in this mapper's SQL that must be routed to a physical table.
     *
     * @return the logical table names (one or more)
     */
    String[] tableName();

    /**
     * Key of the sharding strategy to apply (for example {@code BrandSplitTableStrategy.BRAND_STRATEGY}).
     *
     * @return the strategy key as registered with {@code StrategyManager}
     */
    String strategy();

}

AbstractSplitTableStrategy.java

package com.xxxxx.common.mybatis;


import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.ParameterMapping;
import org.springframework.beans.factory.annotation.Autowired;

import javax.annotation.PostConstruct;
import java.util.List;

@Slf4j
public abstract class AbstractSplitTableStrategy {

    /**
     * Registry this strategy adds itself to in {@link #init()}.
     */
    @Autowired
    private StrategyManager strategyManager;

    /**
     * Returns the key this strategy is registered under; {@code @Subtable(strategy = ...)} refers to it.
     *
     * @return the unique strategy key
     */
    abstract String key();

    /**
     * Registers this strategy with the {@link StrategyManager} after dependency injection.
     */
    @PostConstruct
    public void init() {
        this.register();
    }

    /**
     * Maps a logical table name to the physical table the current statement should use.
     *
     * @param logicTableName logical table name declared on the mapper
     * @param list           parameter mappings of the statement being prepared
     * @param val            MyBatis parameter object of the statement
     * @return the physical table name
     */
    public abstract String doSharding(String logicTableName, List<ParameterMapping> list, Object val);

    /**
     * Adds this strategy to the {@link StrategyManager} under {@link #key()}.
     */
    protected final void register() {
        String name = key();
        strategyManager.registerStrategy(name, this);
    }

    /**
     * Reads the sharding-key value from the MyBatis parameter object.
     *
     * <p>The key must be one of the statement's parameter mappings. {@code val} may be a
     * {@link java.util.Map} (e.g. the parameter map MyBatis builds for multi-argument mapper
     * methods) or a bean. Map entries are read directly rather than by re-parsing
     * {@code Map#toString()}, which broke on values containing {@code =}, {@code ,} or spaces and on
     * beans whose {@code toString()} contains {@code =}.
     *
     * @param list        parameter mappings of the statement
     * @param val         MyBatis parameter object
     * @param shardingKey name of the property holding the sharding value
     * @return the sharding value as a string, never {@code null}
     * @throws Exception in practice an {@link IllegalStateException} when the key is not a mapped
     *                   parameter, the parameter object is neither a map nor a bean, or the value is null
     */
    protected String getShardingValue(List<ParameterMapping> list, Object val, String shardingKey) throws Exception {
        if (val == null || list == null || !isMapped(list, shardingKey)) {
            throw new IllegalStateException("Sharding value is null! shardingKey=" + shardingKey);
        }
        JSONObject obj = toJsonObject(val);
        String value = obj == null ? null : obj.getString(shardingKey);
        // DEBUG, not INFO: this runs for every sharded statement
        log.debug("abstract getShardingValue! shardingKey={} value={}", shardingKey, value);
        if (value == null) {
            // returning null would silently produce a table name ending in "_null"
            throw new IllegalStateException("Sharding value is null! shardingKey=" + shardingKey);
        }
        return value;
    }

    /** Returns whether {@code shardingKey} is one of the statement's parameter properties. */
    private static boolean isMapped(List<ParameterMapping> list, String shardingKey) {
        for (ParameterMapping para : list) {
            if (shardingKey.equals(para.getProperty())) {
                return true;
            }
        }
        return false;
    }

    /** Converts a map or bean parameter object to a JSONObject; returns null for scalar parameters. */
    private static JSONObject toJsonObject(Object val) {
        if (val instanceof java.util.Map) {
            JSONObject obj = new JSONObject();
            for (java.util.Map.Entry<?, ?> entry : ((java.util.Map<?, ?>) val).entrySet()) {
                obj.put(String.valueOf(entry.getKey()), entry.getValue());
            }
            return obj;
        }
        Object json = JSON.toJSON(val);
        return json instanceof JSONObject ? (JSONObject) json : null;
    }

}

BrandSplitTableStrategy.java

package com.xxxxx.common.mybatis;


import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.ParameterMapping;
import org.springframework.stereotype.Component;

import com.xxxxx.common.context.QualityApplicationContext;

import java.util.List;

 /**
  * Table-sharding strategy that routes each logical table to {@code <logicTableName>_<brand>}.
  *
  * @author zeng026
  */
@Slf4j
@Component
public class BrandSplitTableStrategy extends AbstractSplitTableStrategy {
    /** Key of this strategy, referenced from {@code @Subtable(strategy = ...)}. */
    public static final String BRAND_STRATEGY = "BRAND_STRATEGY";

    /**
     * The brand is spliced into an SQL identifier, so only unquoted-identifier characters are
     * accepted; anything else (quotes, spaces, comment markers) could inject SQL.
     */
    private static final java.util.regex.Pattern SAFE_BRAND = java.util.regex.Pattern.compile("[A-Za-z0-9_]+");

    @Override
    public String key() {
        return BRAND_STRATEGY;
    }

    /**
     * Resolves the physical table for the current brand.
     *
     * <p>The brand is taken from the statement's {@code brand} parameter. If that cannot be read,
     * it falls back to the brand of the login user held by {@code QualityApplicationContext}.
     *
     * @param logicTableName logical (unsharded) table name
     * @param list           parameter mappings of the statement
     * @param val            MyBatis parameter object of the statement
     * @return {@code logicTableName + "_" + brand}
     * @throws IllegalStateException if no brand can be determined or it contains unsafe characters
     */
    @Override
    public String doSharding(String logicTableName, List<ParameterMapping> list, Object val) {
        String brand;
        try {
            brand = getShardingValue(list, val, "brand");
        } catch (Exception e) {
            log.debug("No brand parameter, falling back to login user: {}", e.getMessage());
            if (QualityApplicationContext.getLoginUser() == null) {
                throw new IllegalStateException(
                        "Cannot shard table " + logicTableName + ": no brand parameter and no login user");
            }
            brand = QualityApplicationContext.getLoginUser().getBrandIdentify();
        }
        if (brand == null || !SAFE_BRAND.matcher(brand).matches()) {
            throw new IllegalStateException("Illegal brand for table " + logicTableName + ": " + brand);
        }
        return logicTableName + "_" + brand;
    }

}

StrategyManager.java

package com.xxxxx.common.mybatis;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
@Component
public class StrategyManager {

    /** Registered sharding strategies, keyed by their strategy key. */
    private final Map<String, AbstractSplitTableStrategy> strategies = new ConcurrentHashMap<>(10);

    /**
     * Looks up a strategy by key.
     *
     * @param key strategy key
     * @return the registered strategy, or {@code null} if none is registered under {@code key}
     */
    public AbstractSplitTableStrategy getStrategy(String key) {
        return strategies.get(key);
    }

    /**
     * Returns the live registry map (not a copy).
     *
     * @return all registered strategies keyed by strategy key
     */
    public Map<String, AbstractSplitTableStrategy> getStrategies() {
        return strategies;
    }

    /**
     * Registers a strategy under a unique key.
     *
     * @param key      strategy key
     * @param strategy strategy instance
     * @throws IllegalStateException if another strategy already uses {@code key}
     */
    public void registerStrategy(String key, AbstractSplitTableStrategy strategy) {
        // putIfAbsent makes the duplicate check and the insert a single atomic step,
        // unlike containsKey() followed by put().
        AbstractSplitTableStrategy existing = strategies.putIfAbsent(key, strategy);
        if (existing != null) {
            log.error("Key is already in use! key={}", key);
            throw new IllegalStateException("Key is already in use! key=" + key);
        }
    }
}

MyBatisInterceptor.java

package com.xxxx.system;


import lombok.extern.slf4j.Slf4j;

import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.stereotype.Component;

import com.xxxx.common.annotations.Subtable;
import com.xxxx.common.mybatis.AbstractSplitTableStrategy;
import com.xxxx.common.mybatis.StrategyManager;
import com.xxxx.tools.SpringUtil;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


@Slf4j
@Component
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
public class MyBatisInterceptor implements Interceptor {

    /** Shared so reflection metadata is cached instead of rebuilt for every statement. */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /**
     * Before a statement is prepared, rewrites its SQL when the mapper interface carries
     * {@link Subtable}. All other statements pass through untouched.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // The statement id is the fully qualified name of the mapper method.
        Subtable tableSeg = findSubtable(mappedStatement.getId());
        if (tableSeg != null) {
            BoundSql boundSql = statementHandler.getBoundSql();
            String sql = rewriteSql(tableSeg, boundSql);
            // BoundSql exposes no setter for its SQL text, so overwrite the private field.
            Field field = BoundSql.class.getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql, sql);
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override
    public void setProperties(Properties properties) {
        log.info("MyInterceptor======={}", properties);
    }

    /**
     * Returns the {@link Subtable} annotation of the mapper that declares a statement.
     *
     * @param statementId MyBatis statement id ({@code <namespace>.<statement>})
     * @return the annotation, or {@code null} when the namespace is not an annotated, loadable class
     */
    private static Subtable findSubtable(String statementId) {
        int dot = statementId.lastIndexOf('.');
        if (dot <= 0) {
            return null;
        }
        try {
            return Class.forName(statementId.substring(0, dot)).getAnnotation(Subtable.class);
        } catch (ClassNotFoundException e) {
            // An XML namespace does not have to be a Java type; such statements are never sharded.
            return null;
        }
    }

    /**
     * Rewrites the SQL, replacing each logical table name with its physical (sharded) name.
     *
     * <p>Names are matched as whole words, so a table followed by a newline, {@code (}, {@code ,},
     * {@code .} or the end of the statement is replaced. A longer identifier that merely contains
     * the name (e.g. {@code tbl_x_log} for {@code tbl_x}) is left alone.
     *
     * @param tableSeg annotation listing the logical tables and the strategy key
     * @param boundSql SQL and parameters of the statement
     * @return the rewritten SQL
     * @throws IllegalStateException if no strategy is registered under the annotation's key
     */
    private String rewriteSql(Subtable tableSeg, BoundSql boundSql) {
        StrategyManager strategyManager = (StrategyManager) SpringUtil.getBean("strategyManager");
        AbstractSplitTableStrategy strategy = strategyManager.getStrategy(tableSeg.strategy());
        if (strategy == null) {
            throw new IllegalStateException("No split table strategy registered! key=" + tableSeg.strategy());
        }

        String sql = boundSql.getSql();
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        for (String tableName : tableSeg.tableName()) {
            if (StringUtils.isEmpty(tableName)) {
                continue; // an empty pattern would match between every character
            }
            Matcher matcher = Pattern.compile("\\b" + Pattern.quote(tableName) + "\\b").matcher(sql);
            if (!matcher.find()) {
                continue; // table not referenced by this statement
            }
            String newTableName = strategy.doSharding(tableName, parameterMappings, parameterObject);
            // replaceAll() resets the matcher, so the find() above does not skip the first occurrence
            sql = matcher.replaceAll(Matcher.quoteReplacement(newTableName));
        }
        return sql;
    }
}

SessionFilter.java

package com.xxxxx.system;

import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.stereotype.Component;

import com.alibaba.fastjson.JSON;
import com.arvato.auth.dao.entity.UserEntity;
import com.xxxxx.common.context.QualityApplicationContext;
import com.xxxxx.common.exception.BusinessException;

@Component
public class SessionFilter implements Filter {

    @Override
    public void destroy() {

    }

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;
        //HttpServletResponse response = (HttpServletResponse) res;

        try {
            UserEntity currentUser = JSON.parseObject(request.getHeader("session"), UserEntity.class);
            if (null == currentUser) {
                throw new BusinessException("校验权限错误: 请先登录");
            }
            QualityApplicationContext.setLoginUser(currentUser);
        }catch(Exception e) {

        }

        chain.doFilter(req, res);

    }

    @Override
    public void init(FilterConfig arg0) throws ServletException {

    }

}

使用

package com.xxxxx.quality.review.dao;

import java.util.List;

import org.springframework.stereotype.Repository;

import com.xxxxx.common.annotations.Subtable;
import com.xxxxx.quality.review.bean.Review;
import com.xxxxx.quality.review.dto.ReviewDto;
import com.xxxxx.common.mybatis.BrandSplitTableStrategy;



/**
 * Mapper for {@link Review} records. Its SQL is rewritten by the {@link Subtable} interceptor so that
 * the listed logical tables are routed to brand-specific physical tables.
 *
 * <p>NOTE(review): the interface name does not follow Java naming conventions. Renaming it would
 * change the mapper namespace and the bean identity, so it is kept as-is.
 */
@Repository("dataMapper")
@Subtable(strategy = BrandSplitTableStrategy.BRAND_STRATEGY, tableName = {"xxx1", "xxxx2"})
public interface dataMapper {

    // ---- queries ----

    List<Review> selectList(ReviewDto reviewDto);

    Review selectByPrimaryKey(String id);

    // ---- writes ----

    void insertData(ReviewDto reviewDto);

    int updateByPrimary(Review review);

    void deleteByPrimaryKey(String id);
}
评论