// MyBatis-based code generator

package com.demo.utils;

import org.apache.poi.hssf.usermodel.HSSFWorkbook;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.util.WorkbookUtil;

import java.io.*;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

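/**
 * EntityUtil: reads the tables of a MySQL schema and, for each table, generates an
 * entity (DTO) class, a mapper interface, a service class and a MyBatis mapper XML file.
 */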
public class EntityUtil {

    // Database column type keywords; processType() looks for these as substrings
    // of the column type and maps them to Java type names.
    private final String type_char = "char";
    private final String type_date = "date";
    private final String type_timestamp = "timestamp";
    private final String type_int = "int";
    private final String type_bigint = "bigint";
    private final String type_text = "text";
    private final String type_bit = "bit";
    private final String type_decimal = "decimal";
    private final String type_blob = "blob";
    private final String type_Double = "double";

    // Module name
    private final String moduleName = "demo"; // name of the target module (adjust to your own module - be sure to change this)

    // Output directory configuration
    private final String basePath = "D:\\demo\\work\\mybatis\\"; // base output directory
    private final String bean_path = basePath + moduleName + "/dtos"; // where DTO classes are written
    private final String mapper_path = basePath + moduleName + "/mappers"; // where mapper interfaces are written
    private final String xml_path = basePath + moduleName + "/mappers"; // where mapper XML files are written (same directory as mapper_path)
    private final String service_path = basePath + moduleName + "/services"; // where service classes are written
    private final String mybatis_path = basePath + moduleName + "/mybatis"; // where the type-alias XML (mybatis.xml) is written

    // Package configuration for the generated sources
    private final String basePackage = "com."; // base package prefix
    private final String bean_package = basePackage + moduleName + ".dtos";
    private final String mapper_package = basePackage + moduleName + ".mappers";
    private final String service_package = basePackage + moduleName + ".services";

    // Database connection settings
    private final String url = "jdbc:mysql://127.0.0.1:3306/gred?characterEncoding=utf8"; // JDBC connection string
    private final String user = "root"; // database user name
    private final String password = "root"; // database password
    private final String driverName = "com.mysql.jdbc.Driver"; // JDBC driver class

    // Other state

    // Written into the @author tag of generated class comments.
    private final String author = "BoomGred";

    // Per-table working state: tableName and tableComment are assigned in generate(),
    // the derived names are assigned in processTable().
    private String tableName = null;
    private String beanName = null;
    private String mapperName = null;
    private String XMLMapperName = null;
    private String serviceName = null;
    private String tableComment = null;

    // Connection opened by init() and closed at the end of generate().
    private Connection conn = null;

    /**
     * Command-line entry point: runs one full generation pass and prints the stack
     * trace of any checked failure instead of propagating it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        EntityUtil generator = new EntityUtil();
        try {
            generator.generate();
            // Automatically opening the output directory is not implemented.
        } catch (IOException ioFailure) {
            ioFailure.printStackTrace();
        } catch (SQLException sqlFailure) {
            sqlFailure.printStackTrace();
        } catch (ClassNotFoundException missingDriver) {
            missingDriver.printStackTrace();
        }
    }

    /**
     * Loads the JDBC driver class and opens the database connection stored in {@code conn}.
     *
     * @throws ClassNotFoundException if the class named by {@code driverName} cannot be loaded
     * @throws SQLException           if the connection cannot be opened
     */
    private void init() throws ClassNotFoundException, SQLException {
        Class.forName(this.driverName);
        Connection opened = DriverManager.getConnection(this.url, this.user, this.password);
        this.conn = opened;
    }

    /**
     * Lists the tables of the connected schema via {@code show tables}.
     *
     * @return the table names, in the order the database returned them
     * @throws SQLException if the query fails
     */
    private List<String> getTables() throws SQLException {
        List<String> tables = new ArrayList<String>();
        PreparedStatement pstate = conn.prepareStatement("show tables");
        try {
            ResultSet results = pstate.executeQuery();
            try {
                while (results.next()) {
                    tables.add(results.getString(1));
                }
            } finally {
                // Release the cursor even when reading a row fails.
                results.close();
            }
        } finally {
            pstate.close();
        }
        return tables;
    }

    /**
     * Derives the generated type names for a table and stores them in
     * {@code beanName}, {@code mapperName}, {@code XMLMapperName} and {@code serviceName}.
     * <p>
     * A name containing underscores is lower-cased and converted to UpperCamelCase
     * ({@code sys_user} becomes {@code SysUser}); a name without underscores only has its
     * first character upper-cased, so existing camel-case table names are preserved.
     * Empty segments produced by leading or consecutive underscores are skipped.
     *
     * @param table the table name as returned by the database
     */
    private void processTable(String table) {
        StringBuilder sb = new StringBuilder(table.length());
        if (table.indexOf("_") != -1) {
            // Locale.ROOT keeps the casing independent of the JVM default locale.
            String[] parts = table.toLowerCase(Locale.ROOT).split("_");
            for (int i = 0; i < parts.length; i++) {
                String part = parts[i].trim();
                if (part.length() == 0) {
                    // "a__b" or "_a" yields an empty segment; substring(0, 1) would throw on it.
                    continue;
                }
                sb.append(part.substring(0, 1).toUpperCase(Locale.ROOT)).append(part.substring(1));
            }
        } else if (table.length() > 0) {
            sb.append(table.substring(0, 1).toUpperCase(Locale.ROOT)).append(table.substring(1));
        }
        beanName = sb.toString();
        System.out.println(beanName);
        mapperName = beanName + "Mapper";
        XMLMapperName = beanName + "_mapper";
        serviceName = beanName + "Service";
    }

    /**
     * Maps a database column type to the Java type name written into the generated entity.
     * <p>
     * Only the type name in front of any parenthesised part is examined, so the length in
     * {@code varchar(32)} or the value list of an {@code enum(...)} cannot influence the match.
     * The checks run in a fixed order because several keywords are substrings of others
     * ({@code bigint} must be tested before {@code int}, {@code timestamp} before {@code time}).
     *
     * @param type the column type as reported by the database, e.g. {@code int(11) unsigned}
     * @return the Java type name; {@code "Object"} when the type is {@code null} or not recognised,
     *         so that the generated source never contains the literal {@code null} as a type
     */
    private String processType(String type) {
        if (type == null) {
            return "Object";
        }
        String sqlType = type.toLowerCase(Locale.ROOT);
        int paren = sqlType.indexOf('(');
        if (paren >= 0) {
            sqlType = sqlType.substring(0, paren);
        }

        if (sqlType.contains(type_char)) {
            return "String";
        } else if (sqlType.contains(type_bigint)) {
            return "Long";
        } else if (sqlType.contains(type_int)) {
            return "Integer";
        } else if (sqlType.contains(type_date)) {
            // Also matches "datetime".
            return "java.util.Date";
        } else if (sqlType.contains(type_text)) {
            return "String";
        } else if (sqlType.contains(type_timestamp)) {
            return "java.sql.Timestamp";
        } else if (sqlType.contains(type_bit)) {
            return "Boolean";
        } else if (sqlType.contains(type_decimal)) {
            return "java.math.BigDecimal";
        } else if (sqlType.contains(type_blob)) {
            return "byte[]";
        } else if (sqlType.contains(type_Double)) {
            return "Double";
        } else if (sqlType.contains("float")) {
            return "Float";
        } else if (sqlType.contains("time")) {
            // "timestamp" and "datetime" were already handled above.
            return "java.sql.Time";
        } else if (sqlType.contains("binary")) {
            return "byte[]";
        } else if (sqlType.contains("enum") || sqlType.contains("set") || sqlType.contains("json")) {
            return "String";
        }
        return "Object";
    }

    /**
     * Converts an underscore-separated column name to a lowerCamelCase field name
     * ({@code create_time} becomes {@code createTime}).
     * <p>
     * The first segment is copied unchanged and the case of the input is otherwise
     * preserved; only the first character of each later segment is upper-cased.
     * Empty segments produced by consecutive underscores are skipped.
     *
     * @param field the column name
     * @return the field name; the input itself when it consists only of underscores
     */
    private String processField(String field) {
        String[] parts = field.split("_");
        if (parts.length == 0) {
            // "_".split("_") returns an empty array, so there is nothing to convert.
            return field;
        }
        StringBuilder sb = new StringBuilder(field.length());
        sb.append(parts[0]);
        for (int i = 1; i < parts.length; i++) {
            String part = parts[i].trim();
            if (part.length() == 0) {
                // substring(0, 1) would throw on the empty segment of "a__b".
                continue;
            }
            sb.append(part.substring(0, 1).toUpperCase(Locale.ROOT)).append(part.substring(1));
        }
        return sb.toString();
    }

    /**
     * Lower-cases the first character of an entity class name; the result is used as the
     * MyBatis type alias of that entity.
     *
     * @param beanName the entity class name
     * @return the name with its first character in lower case; an empty name is returned unchanged
     */
    private String processResultMapId(String beanName) {
        if (beanName.length() == 0) {
            return beanName;
        }
        // Locale.ROOT keeps the result independent of the JVM default locale.
        return beanName.substring(0, 1).toLowerCase(Locale.ROOT) + beanName.substring(1);
    }

    /**
     * Writes two line breaks followed by a class-level comment block containing the
     * given text and an {@code @author} tag. No line break is written after the block.
     *
     * @param bw   the writer to append to
     * @param text the description placed inside the comment
     * @return the same writer, for chaining
     * @throws IOException if writing fails
     */
    private BufferedWriter buildClassComment(BufferedWriter bw, String text) throws IOException {
        String[] commentLines = {
                "/**",
                " * ",
                " * " + text,
                " * @author " + author,
                " **/"
        };
        bw.newLine();
        for (String commentLine : commentLines) {
            // Every comment line is preceded, not followed, by a line break.
            bw.newLine();
            bw.write(commentLine);
        }
        return bw;
    }

    /**
     * Writes a tab-indented method-level comment block containing the given text.
     * Each line of the block is preceded by a line break; none is written after it.
     *
     * @param bw   the writer to append to
     * @param text the description placed inside the comment
     * @return the same writer, for chaining
     * @throws IOException if writing fails
     */
    private BufferedWriter buildMethodComment(BufferedWriter bw, String text) throws IOException {
        String[] commentLines = {
                "\t/**",
                "\t * ",
                "\t * " + text,
                "\t * ",
                "\t **/"
        };
        for (String commentLine : commentLines) {
            bw.newLine();
            bw.write(commentLine);
        }
        return bw;
    }

    /**
     * Generates the entity (DTO) source file for the current table and adds a sheet
     * describing its fields to the given workbook.
     * <p>
     * The class is written to {@code bean_path} as {@code beanName + ".java"} in UTF-8,
     * matching the encoding used for the mapper and service sources. Besides one private
     * field with getter and setter per column, a {@code lastUpdTime} timestamp field is
     * appended unless a column already maps to that field name.
     *
     * @param columns  column names of the table
     * @param types    database types of the columns, parallel to {@code columns}
     * @param comments column comments, parallel to {@code columns}
     * @param excel    workbook that receives one sheet (field name, type, comment) for this table
     * @throws IOException if the source file cannot be written
     */
    private void buildEntityBean(List<String> columns, List<String> types, List<String> comments, HSSFWorkbook excel)
            throws IOException {
        // Table comments are free text; sanitise the sheet name so characters that are
        // not allowed in Excel sheet names do not make createSheet() fail.
        Sheet sheet = excel.createSheet(WorkbookUtil.createSafeSheetName(beanName + "(" + tableComment + ")"));
        Row header = sheet.createRow(0);
        header.createCell(0).setCellValue("字段名");
        header.createCell(1).setCellValue("類型");
        header.createCell(2).setCellValue("字段註釋");

        File folder = new File(bean_path);
        if (!folder.exists()) {
            folder.mkdirs();
        }

        // Resolve every field name and type once; they are needed for the field
        // declarations, the accessors and the spreadsheet rows.
        int size = columns.size();
        String[] fieldNames = new String[size];
        String[] fieldTypes = new String[size];
        boolean hasLastUpdTime = false;
        for (int i = 0; i < size; i++) {
            fieldNames[i] = processField(columns.get(i));
            fieldTypes[i] = processType(types.get(i));
            if ("lastUpdTime".equals(fieldNames[i])) {
                hasLastUpdTime = true;
            }
        }

        File beanFile = new File(bean_path, beanName + ".java");
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(beanFile), "utf-8"));
        try {
            bw.write("package " + bean_package + ";");
            bw.newLine();
            bw.write("import java.io.Serializable;");
            bw.newLine();
            buildClassComment(bw, tableComment);
            bw.newLine();
            bw.write("@SuppressWarnings(\"serial\")");
            bw.newLine();
            bw.write("public class " + beanName + " implements Serializable {");
            bw.newLine();
            bw.newLine();

            for (int i = 0; i < size; i++) {
                String comment = comments.get(i);
                bw.write("\tprivate " + fieldTypes[i] + " " + fieldNames[i] + ";");
                // A line break inside the column comment would end the "//" comment early
                // and leave the rest of the text as uncompilable code.
                bw.write("  //" + (comment == null ? "" : comment.replaceAll("[\\r\\n]+", " ")));
                bw.newLine();

                Row row = sheet.createRow(i + 1);
                row.createCell(0).setCellValue(fieldNames[i]);
                row.createCell(1).setCellValue(fieldTypes[i]);
                row.createCell(2).setCellValue(comment);
            }
            if (!hasLastUpdTime) {
                // Declaring it a second time would make the generated class fail to compile.
                bw.write("\tprivate java.sql.Timestamp lastUpdTime; ");
                bw.write("  //上一次的更新時間");
                bw.newLine();
            }

            // Setter and getter for every column-backed field.
            for (int i = 0; i < size; i++) {
                String property = fieldNames[i];
                String accessorSuffix = property.substring(0, 1).toUpperCase(Locale.ROOT) + property.substring(1);
                bw.newLine();
                bw.write("\tpublic void set" + accessorSuffix + "(" + fieldTypes[i] + " " + property + "){");
                bw.newLine();
                bw.write("\t\tthis." + property + " = " + property + ";");
                bw.newLine();
                bw.write("\t}");
                bw.newLine();
                bw.newLine();
                bw.write("\tpublic " + fieldTypes[i] + " get" + accessorSuffix + "(){");
                bw.newLine();
                bw.write("\t\treturn this." + property + ";");
                bw.newLine();
                bw.write("\t}");
                bw.newLine();
            }

            bw.newLine();
            if (!hasLastUpdTime) {
                bw.write("\tpublic java.sql.Timestamp getLastUpdTime() {\n" +
                        "        return lastUpdTime;\n" +
                        "    }");
                bw.newLine();
                bw.newLine();
                bw.write("\tpublic void setLastUpdTime(java.sql.Timestamp lastUpdTime) {\n" +
                        "        this.lastUpdTime = lastUpdTime;\n" +
                        "    }");
                bw.newLine();
                bw.newLine();
            }
            bw.write("}");
            bw.newLine();
        } finally {
            // close() flushes; doing it here also releases the file when writing fails.
            bw.close();
        }

        // Fit the three columns to their content.
        sheet.autoSizeColumn(0);
        sheet.autoSizeColumn(1);
        sheet.autoSizeColumn(2);
    }

    /**
     * Generates the mapper interface source for the current table: an empty interface
     * {@code mapperName<T extends beanName>} extending {@code IMapper<T>}, written in
     * UTF-8 to {@code mapper_path}.
     *
     * @throws IOException if the source file cannot be written
     */
    private void buildMapper() throws IOException {
        File folder = new File(mapper_path);
        if (!folder.exists()) {
            folder.mkdirs();
        }

        File mapperFile = new File(mapper_path, mapperName + ".java");
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(mapperFile), "utf-8"));
        try {
            bw.write("package " + mapper_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import " + bean_package + "." + beanName + ";");
            bw.newLine();
            bw.write("import com.demo.base.IMapper;");
            buildClassComment(bw, mapperName + "數據庫操做接口類");
            bw.newLine();
            bw.newLine();
            bw.write("public interface " + mapperName + "<T extends " + beanName + "> extends IMapper<T> {");
            bw.write("}");
        } finally {
            // close() flushes; doing it here also releases the file when writing fails.
            bw.close();
        }
    }

    /**
     * Generates the Spring service class source for the current table, written in UTF-8
     * to {@code service_path}. The class extends {@code BaseServiceSupport<beanName>},
     * has the table's mapper injected, returns it from {@code getMapper()} and reports
     * {@code "uuid"} from {@code getPK()}.
     *
     * @throws IOException if the source file cannot be written
     */
    private void buildService() throws IOException {
        File folder = new File(service_path);
        if (!folder.exists()) {
            folder.mkdirs();
        }

        // Name of the injected mapper field: the mapper type name with a lower-case first letter.
        String mapperField = mapperName.substring(0, 1).toLowerCase(Locale.ROOT) + mapperName.substring(1);

        File serviceFile = new File(service_path, serviceName + ".java");
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(serviceFile), "utf-8"));
        try {
            bw.write("package " + service_package + ";");
            bw.newLine();
            bw.newLine();
            bw.write("import " + bean_package + "." + beanName + ";");
            bw.newLine();
            bw.write("import " + mapper_package + "." + mapperName + ";");
            bw.newLine();

            bw.write("import com.demo.base.ResponseBean;");
            bw.newLine();
            bw.write("import com.demo.base.BaseServiceSupport;");
            bw.newLine();
            bw.write("import com.demo.base.IMapper;");
            bw.newLine();
            bw.write("import org.springframework.beans.factory.annotation.Autowired;");
            bw.newLine();
            bw.write("import org.springframework.stereotype.Service;");
            buildClassComment(bw, serviceName + " 業務處理類");
            bw.newLine();
            bw.write("@Service");
            bw.newLine();
            bw.write("public class " + serviceName + " extends BaseServiceSupport<" + beanName + "> {");
            bw.newLine();
            bw.write("    @Autowired");
            bw.newLine();
            bw.write("    private " + mapperName + "<" + beanName + "> " + mapperField + ";");
            bw.newLine();
            bw.newLine();
            bw.write("    @Override");
            bw.newLine();
            bw.write("    public IMapper<" + beanName + "> getMapper() {");
            bw.newLine();
            bw.write("        return " + mapperField + ";");
            bw.newLine();
            bw.write("    }");

            bw.newLine();
            bw.newLine();
            bw.write("    @Override");
            bw.newLine();
            bw.write("    public String getPK() {");
            bw.newLine();
            bw.write("        return \"uuid\";");
            bw.newLine();
            bw.write("    }");
            bw.newLine();

            bw.write("}");
        } finally {
            // close() flushes; doing it here also releases the file when writing fails.
            bw.close();
        }
    }

    /**
     * Generates the MyBatis mapper XML for the current table, written to {@code xml_path}
     * as {@code XMLMapperName + ".xml"}. The statements themselves come from
     * {@link #buildSQL(BufferedWriter, List, List)}.
     * <p>
     * The file is written in UTF-8 so that its bytes agree with the
     * {@code encoding="UTF-8"} announced in the XML declaration.
     *
     * @param columns  column names of the table
     * @param types    database types of the columns, parallel to {@code columns}
     * @param comments column comments (not used by this method)
     * @throws IOException if the file cannot be written
     */
    private void buildMapperXml(List<String> columns, List<String> types, List<String> comments) throws IOException {
        File folder = new File(xml_path);
        if (!folder.exists()) {
            folder.mkdirs();
        }

        File mapperXmlFile = new File(xml_path, XMLMapperName + ".xml");
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(mapperXmlFile), "utf-8"));
        try {
            bw.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
            bw.newLine();
            bw.write("<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\" ");
            bw.newLine();
            bw.write("    \"http://mybatis.org/dtd/mybatis-3-mapper.dtd\">");
            bw.newLine();
            bw.write("<mapper namespace=\"" + mapper_package + "." + mapperName + "\">");
            bw.newLine();
            bw.newLine();

            buildSQL(bw, columns, types);

            bw.write("</mapper>");
        } finally {
            // close() flushes; doing it here also releases the file when writing fails.
            bw.close();
        }
    }

    /**
     * Generates {@code mybatis.xml} in {@code mybatis_path}: a MyBatis configuration
     * document whose {@code <typeAliases>} section registers one alias per table,
     * mapping the lower-camel-case entity name to the fully qualified entity class.
     * <p>
     * The document uses the MyBatis configuration DTD and a {@code <configuration>}
     * root so that the output is well-formed, and is written in UTF-8 to match its
     * XML declaration.
     *
     * @param tables names of all tables to register
     * @throws IOException if the file cannot be written
     */
    private void buildAliasXml(List<String> tables) throws IOException {
        File folder = new File(mybatis_path);
        if (!folder.exists()) {
            folder.mkdirs();
        }

        File aliasXmlFile = new File(mybatis_path, "mybatis.xml");
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(aliasXmlFile), "utf-8"));
        try {
            bw.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
            bw.newLine();
            bw.write("<!DOCTYPE configuration PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\" ");
            bw.newLine();
            bw.write("    \"http://mybatis.org/dtd/mybatis-3-config.dtd\">");
            bw.newLine();
            bw.write("<configuration>");
            bw.newLine();
            bw.write("    <typeAliases>");
            bw.newLine();

            for (String table : tables) {
                String aliasBean = aliasBeanName(table);
                if (aliasBean.length() == 0) {
                    // A name made only of underscores yields no usable class name.
                    continue;
                }
                bw.write("        <typeAlias type=\"" + bean_package + "." + aliasBean + "\" alias=\""
                        + processResultMapId(aliasBean) + "\"/>");
                bw.newLine();
            }

            bw.write("    </typeAliases>");
            bw.newLine();
            bw.write("</configuration>");
            bw.newLine();
        } finally {
            // close() flushes; doing it here also releases the file when writing fails.
            bw.close();
        }
    }

    /**
     * Converts a table name to the entity class name used for its alias, following the
     * same rule as {@code processTable}: names with underscores are lower-cased and
     * camel-cased, other names only get an upper-case first character.
     */
    private String aliasBeanName(String table) {
        StringBuilder sb = new StringBuilder(table.length());
        if (table.indexOf("_") != -1) {
            String[] parts = table.toLowerCase(Locale.ROOT).split("_");
            for (int i = 0; i < parts.length; i++) {
                String part = parts[i].trim();
                if (part.length() == 0) {
                    // Skip the empty segment produced by leading or consecutive underscores.
                    continue;
                }
                sb.append(part.substring(0, 1).toUpperCase(Locale.ROOT)).append(part.substring(1));
            }
        } else if (table.length() > 0) {
            sb.append(table.substring(0, 1).toUpperCase(Locale.ROOT)).append(table.substring(1));
        }
        return sb.toString();
    }

    /**
     * Writes the statement section of the mapper XML for the current table:
     * {@code insert}, {@code update}, {@code delete}, {@code getInfoByUuid},
     * {@code getInfoByMap}, the shared {@code sqlForList} fragment, {@code count},
     * {@code query} and {@code queryPage}.
     * <p>
     * The first column is used as the key in the delete and get-by-id statements and is
     * left out of the update SET list and of the {@code sqlForList} conditions. The insert
     * and update statements refer to a property named {@code uuid}, and the list queries
     * order by a {@code create_time} column.
     *
     * @param bw      writer positioned inside the {@code <mapper>} element
     * @param columns column names of the table
     * @param types   database types of the columns (not used by this method)
     * @throws IOException if writing fails
     */
    private void buildSQL(BufferedWriter bw, List<String> columns, List<String> types) throws IOException {
        final int columnCount = columns.size();
        final int lastIndex = columnCount - 1;

        bw.newLine();

        // ----- insert: every column, in table order
        emitLine(bw, "\t<!-- 添加 -->");
        emitLine(bw, "\t<insert id=\"insert\" parameterType=\"" + processResultMapId(beanName) + "\" useGeneratedKeys=\"true\" keyProperty=\"uuid\">");
        emitLine(bw, "\t\t INSERT INTO " + tableName);
        bw.write(" \t\t\t(");
        for (int col = 0; col < columnCount; col++) {
            bw.write(columns.get(col));
            if (col < lastIndex) {
                bw.write(",");
            }
        }
        emitLine(bw, ") ");
        emitLine(bw, "\t\t\t VALUES ");
        bw.write(" \t\t\t(");
        for (int col = 0; col < columnCount; col++) {
            bw.write("#{" + processField(columns.get(col)) + "}");
            if (col < lastIndex) {
                bw.write(",");
            }
        }
        emitLine(bw, ") ");
        emitLine(bw, "\t</insert>");
        bw.newLine();

        // ----- update: only properties that are non-null, first column excluded
        emitLine(bw, "\t<!-- 修 改-->");
        emitLine(bw, "\t<update id=\"update\" parameterType=\"" + processResultMapId(beanName) + "\">");
        emitLine(bw, "\t\t UPDATE " + tableName);
        emitLine(bw, "\t\t <trim prefix=\"SET\" suffixOverrides=\",\" suffix=\"WHERE uuid = #{uuid}\"> ");
        for (int col = 1; col < columnCount; col++) {
            String property = processField(columns.get(col));
            emitLine(bw, "\t\t\t<if test=\"" + property + " != null\">"
                    + columns.get(col) + " = #{" + property + "},"
                    + "</if>");
        }
        emitLine(bw, "\t\t </trim> ");
        emitLine(bw, "\t</update>");
        bw.newLine();

        // ----- delete by the first column
        emitLine(bw, "\t<!--刪除:根據主鍵ID刪除-->");
        emitLine(bw, "\t<delete id=\"delete\" parameterType=\"" + processResultMapId(beanName) + "\">");
        emitLine(bw, "\t\t DELETE FROM " + tableName);
        emitLine(bw, "\t\t WHERE " + columns.get(0) + " = #{" + processField(columns.get(0)) + "}");
        emitLine(bw, "\t</delete>");
        bw.newLine();

        // ----- select one row by the first column
        emitLine(bw, "\t<!-- 查詢(根據主鍵ID查詢) -->");
        emitLine(bw, "\t<select id=\"getInfoByUuid\" resultType=\"" + processResultMapId(beanName) + "\">");
        emitLine(bw, "\t\t SELECT");
        emitLine(bw, "\t\t *");
        emitLine(bw, "\t\t FROM " + tableName);
        emitLine(bw, "\t\t WHERE " + columns.get(0) + " = #{" + processField(columns.get(0)) + "}");
        emitLine(bw, "\t</select>");
        bw.newLine();

        // ----- select the newest row matching the given conditions
        emitLine(bw, "\t<!-- 查詢(根據map查詢) -->");
        emitLine(bw, "\t<select id=\"getInfoByMap\" resultType=\"" + processResultMapId(beanName) + "\">");
        emitLine(bw, "\t\t SELECT");
        emitLine(bw, "\t\t *");
        emitLine(bw, "\t\t <include refid=\"sqlForList\"/> ");
        emitLine(bw, "\t\t order by create_time desc limit 1 ");
        emitLine(bw, "\t</select>");
        bw.newLine();

        // ----- shared FROM/WHERE fragment: one optional equality test per column after the first
        bw.newLine();
        emitLine(bw, "\t<!-- 獲取列表的通用SQL-->");
        emitLine(bw, "\t<sql id=\"sqlForList\">");
        emitLine(bw, "\t\t FROM " + tableName);
        emitLine(bw, "\t\t<where>");
        for (int col = 1; col < columnCount; col++) {
            String property = processField(columns.get(col));
            emitLine(bw, "\t\t\t<if test=\"" + property + " != null\">"
                    + "\tand\t" + columns.get(col) + " = #{" + property + "}"
                    + "</if>");
        }
        emitLine(bw, "\t\t</where>");
        emitLine(bw, "\t</sql>");

        // ----- count
        bw.newLine();
        emitLine(bw, "\t<!-- 統計-->");
        emitLine(bw, "\t<select id=\"count\" resultType=\"int\">");
        emitLine(bw, "\t\tselect");
        emitLine(bw, "\t\t\tcount(*)");
        emitLine(bw, "\t\t<include refid=\"sqlForList\"/>");
        emitLine(bw, "\t</select>");

        // ----- list and paged list: identical statements apart from comment text and id
        String[][] listStatements = {
                {"獲取list", "query"},
                {"分頁查詢", "queryPage"}
        };
        for (String[] statement : listStatements) {
            bw.newLine();
            emitLine(bw, "\t<!-- " + statement[0] + "-->");
            emitLine(bw, "\t<select id=\"" + statement[1] + "\" resultType=\"" + processResultMapId(beanName) + "\">");
            emitLine(bw, "\t\tselect");
            emitLine(bw, "\t\t\t*");
            emitLine(bw, "\t\t<include refid=\"sqlForList\"/>");
            emitLine(bw, "\t\t order by create_time desc ");
            emitLine(bw, "\t</select>");
        }
    }

    /**
     * Writes the text followed by a line separator.
     */
    private void emitLine(BufferedWriter bw, String text) throws IOException {
        bw.write(text);
        bw.newLine();
    }

    /**
     * Reads the comment of every table in the connected schema via {@code show table status}.
     *
     * @return a map from table name to table comment
     * @throws SQLException if the query fails
     */
    private Map<String, String> getTableComment() throws SQLException {
        Map<String, String> maps = new HashMap<String, String>();
        PreparedStatement pstate = conn.prepareStatement("show table status");
        try {
            ResultSet results = pstate.executeQuery();
            try {
                while (results.next()) {
                    maps.put(results.getString("NAME"), results.getString("COMMENT"));
                }
            } finally {
                // Release the cursor even when reading a row fails.
                results.close();
            }
        } finally {
            pstate.close();
        }
        return maps;
    }

    /**
     * Runs one full generation pass: connects to the database, and for every table
     * reads its columns and generates the entity class, mapper interface, service class
     * and mapper XML. Afterwards the type-alias XML for all tables is written and the
     * field documentation workbook is saved as {@code <moduleName>.xls} in the module's
     * output directory.
     * <p>
     * The connection and all statements are closed even when a step fails.
     *
     * @throws ClassNotFoundException if the JDBC driver class cannot be loaded
     * @throws SQLException           if a database operation fails
     * @throws IOException            if a generated file cannot be written
     */
    public void generate() throws ClassNotFoundException, SQLException, IOException {
        HSSFWorkbook excel = new HSSFWorkbook();

        init();
        try {
            List<String> tables = getTables();
            Map<String, String> tableComments = getTableComment();

            for (String table : tables) {
                List<String> columns = new ArrayList<String>();
                List<String> types = new ArrayList<String>();
                List<String> comments = new ArrayList<String>();

                // Identifiers cannot be bound as statement parameters, so quote the name:
                // an unquoted reserved word such as "order" would be a syntax error.
                String quotedTable = "`" + table.replace("`", "``") + "`";
                PreparedStatement pstate = conn.prepareStatement("show full fields from " + quotedTable);
                try {
                    ResultSet results = pstate.executeQuery();
                    try {
                        while (results.next()) {
                            columns.add(results.getString("FIELD"));
                            types.add(results.getString("TYPE"));
                            comments.add(results.getString("COMMENT"));
                        }
                    } finally {
                        results.close();
                    }
                } finally {
                    pstate.close();
                }

                tableName = table;
                processTable(table);
                tableComment = tableComments.get(tableName);
                buildEntityBean(columns, types, comments, excel);
                buildMapper();
                buildService();
                buildMapperXml(columns, types, comments);
            }

            buildAliasXml(tables);

            // The workbook is only populated in memory by buildEntityBean(); persist it,
            // otherwise the field documentation is discarded when this method returns.
            if (!tables.isEmpty()) {
                File excelFile = new File(basePath + moduleName, moduleName + ".xls");
                File excelFolder = excelFile.getParentFile();
                if (excelFolder != null && !excelFolder.exists()) {
                    excelFolder.mkdirs();
                }
                FileOutputStream excelOut = new FileOutputStream(excelFile);
                try {
                    excel.write(excelOut);
                } finally {
                    excelOut.close();
                }
            }
        } finally {
            conn.close();
        }
    }
}
// Related articles
// Related tags / search