前言

由于公司业务中有一张表的数据增长很快,进行关联计算时速度太慢,因此引入了分表方案:根据日期(按月)进行分表,遇到尚不存在的分表时自动创建。

具体实现

引入依赖

<!-- ShardingSphere-JDBC core library; the datasource configuration in this article uses its
     org.apache.shardingsphere.driver.ShardingSphereDriver class as the JDBC driver. -->
<dependency>
  <groupId>org.apache.shardingsphere</groupId>
  <artifactId>shardingsphere-jdbc-core</artifactId>
  <version>5.3.0</version>
</dependency>

基本配置(Spring Boot 主配置文件,以及 classpath 下的 sharding.yaml)

# Spring Boot application configuration (presumably application.yml - confirm the file name).
server:
  port: 8081

spring:
  # Allow bean definition overriding; the original note says this resolves a connection-pool conflict.
  main:
    allow-bean-definition-overriding: true
  datasource:
    # Route all JDBC access through the ShardingSphere driver; the sharding rules are read
    # from sharding.yaml on the classpath.
    driver-class-name: org.apache.shardingsphere.driver.ShardingSphereDriver
    url: jdbc:shardingsphere:classpath:sharding.yaml

# mybatis-plus
mybatis-plus:
  mapper-locations: classpath*:/mapper/*Mapper.xml
  # Entity package scan; separate multiple packages with commas or semicolons.
  typeAliasesPackage: com.demo.*.entity
  # Print SQL to stdout (intended for test environments).
  configuration:
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl

# NOTE(review): the dialect is postgresql while every JDBC URL in this article points at MySQL -
# confirm whether this should be mysql.
pagehelper:
  helperDialect: postgresql

# Direct (non-sharded) connection settings read by ShardingAlgorithmTool via the
# my.sharding.create-table.* keys to list and create the physical shard tables.
my:
  sharding:
    create-table:
      url: jdbc:mysql://localhost:3306/mybd?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
      username: root
      password: 123456
# ---------------------------------------------------------------------------
# sharding.yaml - a separate classpath file, referenced by the datasource URL
# jdbc:shardingsphere:classpath:sharding.yaml
# ---------------------------------------------------------------------------
dataSources:
  ds:
    dataSourceClassName: com.zaxxer.hikari.HikariDataSource
    # NOTE(review): com.mysql.jdbc.Driver is the legacy driver class name; MySQL Connector/J 8.x
    # uses com.mysql.cj.jdbc.Driver - confirm against the connector version on the classpath.
    driverClassName: com.mysql.jdbc.Driver
    jdbcUrl: jdbc:mysql://localhost:3306/mybd?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    username: root
    password: 123456

rules:
- !SHARDING
  tables:
    t_user:
      # Only the logic table is listed here; TimeShardingAlgorithm looks up the monthly
      # t_user_yyyyMM tables at runtime and creates the ones that are missing.
      actualDataNodes: ds.t_user
      tableStrategy:
        standard:
          shardingColumn: create_time
          shardingAlgorithmName: auto-custom

  shardingAlgorithms:
    auto-custom:
      # Class-based sharding: delegate to the custom algorithm class named below.
      type: CLASS_BASED
      props:
        strategy: standard
        algorithmClassName: com.demo.module.config.sharding.TimeShardingAlgorithm

分片配置类

TimeShardingAlgorithm(实现 StandardShardingAlgorithm 接口并重写其方法,自定义按月分片的逻辑,并在分表不存在时自动建表)

/**
 * Month-based table sharding algorithm that creates missing shard tables on demand.
 *
 * <p>A logic table such as {@code t_user} is routed to physical tables named
 * {@code <logicTable>_yyyyMM} (for example {@code t_user_202201}) according to the
 * {@link LocalDateTime} sharding value. When the computed table is not yet in the
 * collection of available target names, {@link ShardingAlgorithmTool#createShardingTable}
 * is asked to create it and the collection is updated in place.
 *
 * @author Tia.chen
 * @date 2023/12/20 11:33
 */
@Slf4j
public final class TimeShardingAlgorithm implements StandardShardingAlgorithm<LocalDateTime>, ShardingAutoTableAlgorithm {

    /**
     * Format of the month suffix appended to the logic table name.
     */
    private static final DateTimeFormatter TABLE_SHARD_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyyMM");

    /**
     * Full timestamp format used to turn a month suffix into the first instant of that month.
     */
    private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd HH:mm:ss");

    /**
     * Separator between the logic table name and the month suffix, e.g. the last "_" in t_contract_202201.
     */
    private static final String TABLE_SPLIT_SYMBOL = "_";

    /**
     * A valid month suffix: four year digits followed by a month 01-12.
     */
    private static final String SHARD_SUFFIX_REGEX = "\\d{4}(0[1-9]|1[0-2])";

    @Getter
    private Properties props;

    @Getter
    private int autoTablesAmount;

    /**
     * Stores the algorithm properties supplied by the sharding configuration.
     *
     * @param props algorithm properties
     */
    @Override
    public void init(final Properties props) {
        this.props = props;
    }

    /**
     * Routes a single sharding value (for example {@code =} or an insert) to its monthly table.
     *
     * @param availableTargetNames cached physical table names; refreshed and extended in place
     * @param preciseShardingValue the sharding value and its logic table name
     * @return the monthly table name, or the logic table name when the table could not be created
     */
    @Override
    public String doSharding(final Collection<String> availableTargetNames, final PreciseShardingValue<LocalDateTime> preciseShardingValue) {
        String logicTableName = preciseShardingValue.getLogicTableName();

        log.info(">>>>>>>>>> 【INFO】精确分片,节点配置表名:{}", availableTargetNames);

        LocalDateTime dateTime = preciseShardingValue.getValue();
        String resultTableName = logicTableName + TABLE_SPLIT_SYMBOL + dateTime.format(TABLE_SHARD_TIME_FORMATTER);

        // A single entry is treated as "not initialised yet" (presumably only the configured
        // logic table is present), so the cache is replaced with the shard tables found in the database.
        if (availableTargetNames.size() == 1) {
            List<String> allTableNameBySchema = ShardingAlgorithmTool.getAllTableNameBySchema(logicTableName);
            availableTargetNames.clear();
            availableTargetNames.addAll(allTableNameBySchema);
            autoTablesAmount = allTableNameBySchema.size();
        }

        return getShardingTableAndCreate(logicTableName, resultTableName, availableTargetNames);
    }

    /**
     * Routes a range condition (for example {@code BETWEEN}, {@code >=}, {@code <=}) to every monthly
     * table whose month intersects the range.
     *
     * <p>A missing bound falls back to the earliest / latest month found among the cached shard
     * table names. Both bounds are normalised to the start of their month before iterating, so the
     * month containing the upper bound is always included regardless of the day-of-month or
     * time-of-day of either bound.
     *
     * @param availableTargetNames cached physical table names; extended in place when tables are created
     * @param rangeShardingValue   the value range and its logic table name
     * @return the set of table names to query
     * @throws IllegalArgumentException when a bound is missing and no shard table name is cached
     */
    @Override
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final RangeShardingValue<LocalDateTime> rangeShardingValue) {
        String logicTableName = rangeShardingValue.getLogicTableName();

        log.info(">>>>>>>>>> 【INFO】范围分片,节点配置表名:{}", availableTargetNames);

        Range<LocalDateTime> valueRange = rangeShardingValue.getValueRange();
        LocalDateTime min = valueRange.hasLowerBound() ? valueRange.lowerEndpoint() : getLowerEndpoint(availableTargetNames);
        LocalDateTime max = valueRange.hasUpperBound() ? valueRange.upperEndpoint() : getUpperEndpoint(availableTargetNames);

        // Compare whole months only: stepping a raw timestamp by one month can jump past the
        // upper bound while still inside the upper bound's month, which would drop that table.
        LocalDateTime cursor = startOfMonth(min);
        LocalDateTime lastMonth = startOfMonth(max);

        Set<String> resultTableNames = new LinkedHashSet<>();
        while (!cursor.isAfter(lastMonth)) {
            resultTableNames.add(logicTableName + TABLE_SPLIT_SYMBOL + cursor.format(TABLE_SHARD_TIME_FORMATTER));
            cursor = cursor.plusMonths(1);
        }
        return getShardingTablesAndCreate(logicTableName, resultTableNames, availableTargetNames);
    }

    /**
     * @return the type identifier of this algorithm
     */
    @Override
    public String getType() {
        return "AUTO_CUSTOM";
    }

    // --------------------------------------------------------------------------------------------------------------
    // Helper methods
    // --------------------------------------------------------------------------------------------------------------

    /**
     * Resolves every candidate table name, creating the tables that do not exist yet.
     *
     * @param logicTableName       logic table name
     * @param resultTableNames     candidate physical table names, e.g. t_user_202201
     * @param availableTargetNames cached physical table names
     * @return the resolved table names (the logic table name stands in for tables that were not created)
     */
    public Set<String> getShardingTablesAndCreate(String logicTableName, Collection<String> resultTableNames, Collection<String> availableTargetNames) {
        return resultTableNames.stream()
                .map(candidate -> getShardingTableAndCreate(logicTableName, candidate, availableTargetNames))
                .collect(Collectors.toSet());
    }

    /**
     * Returns the given physical table name if it is cached; otherwise tries to create the table.
     *
     * @param logicTableName       logic table name
     * @param resultTableName      physical table name, e.g. t_user_202201
     * @param availableTargetNames cached physical table names; the new table is added on success
     * @return {@code resultTableName} when the table is cached or was created, otherwise {@code logicTableName}
     */
    private String getShardingTableAndCreate(String logicTableName, String resultTableName, Collection<String> availableTargetNames) {
        if (availableTargetNames.contains(resultTableName)) {
            return resultTableName;
        }

        // Not cached: ask the tool to create it (it declines for months after the current one).
        boolean isSuccess = ShardingAlgorithmTool.createShardingTable(logicTableName, resultTableName);
        if (!isSuccess) {
            // Fall back to the (empty) logic table.
            return logicTableName;
        }

        availableTargetNames.add(resultTableName);
        autoTablesAmount++;
        return resultTableName;
    }

    /**
     * Finds the earliest shard month among the given table names.
     *
     * @param tableNames cached table names
     * @return the first instant of the earliest shard month
     * @throws IllegalArgumentException when no name carries a valid month suffix
     */
    private LocalDateTime getLowerEndpoint(Collection<String> tableNames) {
        Optional<LocalDateTime> earliest = shardMonthStarts(tableNames).stream().min(LocalDateTime::compareTo);
        if (earliest.isPresent()) {
            return earliest.get();
        }
        log.error(">>>>>>>>>> 【ERROR】获取数据最小分表失败,请稍后重试,tableName:{}", tableNames);
        throw new IllegalArgumentException("获取数据最小分表失败,请稍后重试");
    }

    /**
     * Finds the latest shard month among the given table names.
     *
     * @param tableNames cached table names
     * @return the first instant of the latest shard month
     * @throws IllegalArgumentException when no name carries a valid month suffix
     */
    private LocalDateTime getUpperEndpoint(Collection<String> tableNames) {
        Optional<LocalDateTime> latest = shardMonthStarts(tableNames).stream().max(LocalDateTime::compareTo);
        if (latest.isPresent()) {
            return latest.get();
        }
        log.error(">>>>>>>>>> 【ERROR】获取数据最大分表失败,请稍后重试,tableName:{}", tableNames);
        throw new IllegalArgumentException("获取数据最大分表失败,请稍后重试");
    }

    /**
     * Converts shard table names into the first instant of their month.
     *
     * <p>Only the text after the last separator is parsed; names without a valid {@code yyyyMM}
     * suffix (such as the logic table itself) are skipped instead of failing the parse.
     */
    private static List<LocalDateTime> shardMonthStarts(Collection<String> tableNames) {
        return tableNames.stream()
                .map(name -> name.substring(name.lastIndexOf(TABLE_SPLIT_SYMBOL) + 1))
                .filter(suffix -> suffix.matches(SHARD_SUFFIX_REGEX))
                .map(suffix -> LocalDateTime.parse(suffix + "01 00:00:00", DATE_TIME_FORMATTER))
                .collect(Collectors.toList());
    }

    /**
     * Returns midnight on the first day of the month containing the given timestamp.
     */
    private static LocalDateTime startOfMonth(LocalDateTime value) {
        return value.toLocalDate().withDayOfMonth(1).atStartOfDay();
    }
}

ShardingAlgorithmTool(工具类)

/**
 * <p> @Title ShardingAlgorithmTool
 * <p> @Description 按月分片算法工具
 *
 * @author Tia.chen
 * @date 2023/12/20 14:03
 */
@Slf4j
public class ShardingAlgorithmTool {

    /** 表分片符号,例:t_user_202201 中,分片符号为 "_" */
    private static final String TABLE_SPLIT_SYMBOL = "_";

    /** 数据库配置 */
    private static final Environment ENV = SpringUtil.getApplicationContext().getEnvironment();
    private static final String DATASOURCE_URL = ENV.getProperty("my.sharding.create-table.url");
    private static final String DATASOURCE_USERNAME = ENV.getProperty("my.sharding.create-table.username");
    private static final String DATASOURCE_PASSWORD = ENV.getProperty("my.sharding.create-table.password");

    /** 配置文件路径 */
    private static final String CONFIG_FILE = "sharding-tables.yaml";

    /**
     * 获取所有表名
     * @return 表名集合
     * @param logicTableName 逻辑表
     */
    public static List<String> getAllTableNameBySchema(String logicTableName) {
        List<String> tableNames = new ArrayList<>();
        if (StringUtils.isEmpty(DATASOURCE_URL) || StringUtils.isEmpty(DATASOURCE_USERNAME) || StringUtils.isEmpty(DATASOURCE_PASSWORD)) {
            log.error(">>>>>>>>>> 【ERROR】数据库连接配置有误,请稍后重试,URL:{}, username:{}, password:{}", DATASOURCE_URL, DATASOURCE_USERNAME, DATASOURCE_PASSWORD);
            throw new IllegalArgumentException("数据库连接配置有误,请稍后重试");
        }
        try (Connection conn = DriverManager.getConnection(DATASOURCE_URL, DATASOURCE_USERNAME, DATASOURCE_PASSWORD);
             Statement st = conn.createStatement()) {
            try (ResultSet rs = st.executeQuery("show TABLES like '" + logicTableName + TABLE_SPLIT_SYMBOL + "%'")) {
                while (rs.next()) {
                    String tableName = rs.getString(1);
                    // 匹配分表格式 例:^(t\_contract_\d{6})$
                    if (tableName != null && tableName.matches(String.format("^(%s\\d{6})$", logicTableName + TABLE_SPLIT_SYMBOL))) {
                        tableNames.add(rs.getString(1));
                    }
                }
            }
        } catch (SQLException e) {
            log.error(">>>>>>>>>> 【ERROR】数据库连接失败,请稍后重试,原因:{}", e.getMessage(), e);
            throw new IllegalArgumentException("数据库连接失败,请稍后重试");
        }
        return tableNames;
    }


    // --------------------------------------------------------------------------------------------------------------
    // 私有方法
    // --------------------------------------------------------------------------------------------------------------

    /**
     * 获取数据源
     */
//    private static Map<String, DataSource> getDataSourceMap() {
//        return getActualDataSources().entrySet().stream().filter(entry -> ACTUAL_DATA_SOURCE_NAMES.contains(entry.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
//    }

    /**
     * 获取数据源配置
     */
    private static File getShardingYAMLFile() {
        return new File(Objects.requireNonNull(
                ShardingAlgorithmTool.class.getClassLoader().getResource(CONFIG_FILE), String.format("File `%s` is not existed.", CONFIG_FILE)).getFile());
    }

    /**
     * 刷新ActualDataNodes
     */
    private static void updateShardRuleActualDataNodes(ShardingSphereDataSource dataSource, String logicTableName, String newActualDataNodes) throws NoSuchFieldException {
        // Context manager.
        Field contextManagerField = dataSource.getClass().getDeclaredField("contextManager");
        ContextManager contextManager = (ContextManager) ReflectionUtils.getField(contextManagerField, dataSource);

        // Rule configuration.
        ShardingSphereMetaData shardingSphereMetaData = contextManager
                .getMetaDataContexts()
                .getMetaData();
        Collection<RuleConfiguration> newRuleConfigList = new LinkedList<>();
        Collection<RuleConfiguration> oldRuleConfigList = shardingSphereMetaData
                .getGlobalRuleMetaData()
                .getConfigurations();

        for (RuleConfiguration oldRuleConfig : oldRuleConfigList) {
            if (oldRuleConfig instanceof ShardingRuleConfiguration) {

                // Algorithm provided sharding rule configuration
                ShardingRuleConfiguration oldAlgorithmConfig = (ShardingRuleConfiguration) oldRuleConfig;
                ShardingRuleConfiguration newAlgorithmConfig = new ShardingRuleConfiguration();

                // Sharding table rule configuration Collection
                Collection<ShardingTableRuleConfiguration> newTableRuleConfigList = new LinkedList<>();
                Collection<ShardingTableRuleConfiguration> oldTableRuleConfigList = oldAlgorithmConfig.getTables();

                oldTableRuleConfigList.forEach(oldTableRuleConfig -> {
                    if (logicTableName.equals(oldTableRuleConfig.getLogicTable())) {
                        ShardingTableRuleConfiguration newTableRuleConfig = new ShardingTableRuleConfiguration(oldTableRuleConfig.getLogicTable(), newActualDataNodes);
                        newTableRuleConfig.setTableShardingStrategy(oldTableRuleConfig.getTableShardingStrategy());
                        newTableRuleConfig.setDatabaseShardingStrategy(oldTableRuleConfig.getDatabaseShardingStrategy());
                        newTableRuleConfig.setKeyGenerateStrategy(oldTableRuleConfig.getKeyGenerateStrategy());

                        newTableRuleConfigList.add(newTableRuleConfig);
                    } else {
                        newTableRuleConfigList.add(oldTableRuleConfig);
                    }
                });

                newAlgorithmConfig.setTables(newTableRuleConfigList);
                newAlgorithmConfig.setAutoTables(oldAlgorithmConfig.getAutoTables());
                newAlgorithmConfig.setBindingTableGroups(oldAlgorithmConfig.getBindingTableGroups());
                newAlgorithmConfig.setBroadcastTables(oldAlgorithmConfig.getBroadcastTables());
                newAlgorithmConfig.setDefaultDatabaseShardingStrategy(oldAlgorithmConfig.getDefaultDatabaseShardingStrategy());
                newAlgorithmConfig.setDefaultTableShardingStrategy(oldAlgorithmConfig.getDefaultTableShardingStrategy());
                newAlgorithmConfig.setDefaultKeyGenerateStrategy(oldAlgorithmConfig.getDefaultKeyGenerateStrategy());
                newAlgorithmConfig.setDefaultShardingColumn(oldAlgorithmConfig.getDefaultShardingColumn());
                newAlgorithmConfig.setShardingAlgorithms(oldAlgorithmConfig.getShardingAlgorithms());
                newAlgorithmConfig.setKeyGenerators(oldAlgorithmConfig.getKeyGenerators());

                newRuleConfigList.add(newAlgorithmConfig);
            }
        }

        // update context
        String schemaName = "logic_db";
        contextManager.alterRuleConfiguration(schemaName, newRuleConfigList);
    }

    /**
     * 创建分表2
     * @param logicTableName  逻辑表
     * @param resultTableName 真实表名,例:t_user_202401
     * @return 创建结果(true创建成功,false未创建)
     */
    public static boolean createShardingTable(String logicTableName, String resultTableName) {
        // 根据日期判断,当前月份之后分表不提前创建
        String month = resultTableName.replace(logicTableName + TABLE_SPLIT_SYMBOL,"");
        YearMonth shardingMonth = YearMonth.parse(month, DateTimeFormatter.ofPattern("yyyyMM"));
        if (shardingMonth.isAfter(YearMonth.now())) {
            return false;
        }

        synchronized (logicTableName.intern()) {
            // 缓存中无此表,则建表并添加缓存
            executeSql(Collections.singletonList("CREATE TABLE IF NOT EXISTS `" + resultTableName + "` LIKE `" + logicTableName + "`;"));
        }
        return true;
    }

    /**
     * 执行SQL
     * @param sqlList SQL集合
     */
    private static void executeSql(List<String> sqlList) {
        if (StringUtils.isEmpty(DATASOURCE_URL) || StringUtils.isEmpty(DATASOURCE_USERNAME) || StringUtils.isEmpty(DATASOURCE_PASSWORD)) {
            log.error(">>>>>>>>>> 【ERROR】数据库连接配置有误,请稍后重试,URL:{}, username:{}, password:{}", DATASOURCE_URL, DATASOURCE_USERNAME, DATASOURCE_PASSWORD);
            throw new IllegalArgumentException("数据库连接配置有误,请稍后重试");
        }
        try (Connection conn = DriverManager.getConnection(DATASOURCE_URL, DATASOURCE_USERNAME, DATASOURCE_PASSWORD)) {
            try (Statement st = conn.createStatement()) {
                conn.setAutoCommit(false);
                for (String sql : sqlList) {
                    st.execute(sql);
                }
            } catch (Exception e) {
                conn.rollback();
                log.error(">>>>>>>>>> 【ERROR】数据表创建执行失败,请稍后重试,原因:{}", e.getMessage(), e);
                throw new IllegalArgumentException("数据表创建执行失败,请稍后重试");
            }
        } catch (SQLException e) {
            log.error(">>>>>>>>>> 【ERROR】数据库连接失败,请稍后重试,原因:{}", e.getMessage(), e);
            throw new IllegalArgumentException("数据库连接失败,请稍后重试");
        }
    }

}

使用

测试类

/**
 * Spring Boot integration tests that exercise saving and range-querying users through
 * {@code UserService}.
 */
@SpringBootTest
class SpringbootDemoApplicationTests {

    private final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    @Autowired
    private UserService userService;

    /**
     * Saves a batch of two users whose timestamps fall in two different months (2022-01 and 2022-02).
     */
    @Test
    void saveTest() {
        LocalDateTime january = LocalDateTime.parse("2022-01-01 00:00:00", DATE_TIME_FORMATTER);
        LocalDateTime february = LocalDateTime.parse("2022-02-01 00:00:00", DATE_TIME_FORMATTER);

        List<User> batch = new ArrayList<>(3);
        batch.add(new User("Tia.chen_1", "123456", 10, january, january));
        batch.add(new User("Tia.chen_2", "123456", 11, february, february));

        userService.saveBatch(batch);
    }

    /**
     * Queries the first page (up to 100 rows) of users whose create time lies inside a fixed
     * range, then prints the resulting page info.
     */
    @Test
    void listTest() {
        PageHelper.startPage(1, 100);

        LocalDateTime rangeStart = LocalDateTime.parse("2023-01-01 00:00:00", DATE_TIME_FORMATTER);
        LocalDateTime rangeEnd = LocalDateTime.parse("2025-07-28 23:59:59", DATE_TIME_FORMATTER);
        LambdaQueryWrapper<User> createdWithinRange = new LambdaQueryWrapper<User>()
                .between(User::getCreateTime, rangeStart, rangeEnd);

        List<User> matches = userService.list(createdWithinRange);
        PageInfo<User> pageInfo = new PageInfo<>(matches);

        System.out.println(">>>>>>>>>> 【Result】<<<<<<<<<< ");
        System.out.println(pageInfo);
    }
}

代码仓库 https://gitee.com/cbin0526/springboot-sharding-dome.git