001package gudusoft.gsqlparser.dlineage;
002
003import gudusoft.gsqlparser.*;
004import gudusoft.gsqlparser.dlineage.dataflow.model.Option;
005import gudusoft.gsqlparser.dlineage.dataflow.model.SqlInfo;
006import gudusoft.gsqlparser.dlineage.statistics.*;
007import gudusoft.gsqlparser.nodes.TJoin;
008import gudusoft.gsqlparser.nodes.TJoinItem;
009import gudusoft.gsqlparser.nodes.TJoinItemList;
010import gudusoft.gsqlparser.nodes.TJoinList;
011import gudusoft.gsqlparser.stmt.*;
012import gudusoft.gsqlparser.util.Logger;
013import gudusoft.gsqlparser.util.LoggerFactory;
014import gudusoft.gsqlparser.util.SQLUtil;
015import gudusoft.gsqlparser.util.json.JSON;
016
017import java.io.File;
018import java.util.*;
019import java.util.regex.Matcher;
020import java.util.regex.Pattern;
021import java.util.stream.Collectors;
022
023public class SQLFileStatistics {
024    private static final Logger logger = LoggerFactory.getLogger(SQLFileStatistics.class);
025
026    private SqlInfo[] sqlInfos;
027    private Option option;
028    private List<FileStatistics> fileStatisticsList = new ArrayList<>();
029
030    public SQLFileStatistics(SqlInfo[] sqlInfos, Option option) {
031        if (sqlInfos == null) {
032            this.sqlInfos = new SqlInfo[0];
033        } else {
034            this.sqlInfos = new SqlInfo[sqlInfos.length];
035            for (int i = 0; i < sqlInfos.length; i++) {
036                if (sqlInfos[i] == null) {
037                    this.sqlInfos[i] = new SqlInfo();
038                    this.sqlInfos[i].setSql("");
039                    this.sqlInfos[i].setOriginIndex(i);
040                } else {
041                    this.sqlInfos[i] = sqlInfos[i];
042                    this.sqlInfos[i].setOriginIndex(i);
043                }
044            }
045        }
046        this.option = option;
047    }
048
049    /**
050     * 合并另一个SQLFileStatistics对象的分析结果到当前对象
051     *
052     * @param other 要合并的SQLFileStatistics对象
053     */
054    public void merge(SQLFileStatistics other) {
055        if (other == null || other.fileStatisticsList == null || other.fileStatisticsList.isEmpty()) {
056            return;
057        }
058
059        // 创建一个当前文件统计的映射,便于快速查找
060        Map<String, FileStatistics> currentFileStatsMap = new HashMap<>();
061        for (FileStatistics stats : this.fileStatisticsList) {
062            currentFileStatsMap.put(stats.getFile(), stats);
063        }
064
065        // 合并另一个对象的文件统计结果
066        for (FileStatistics otherStats : other.fileStatisticsList) {
067            if (currentFileStatsMap.containsKey(otherStats.getFile())) {
068                // 如果当前对象已经有该文件的统计结果,合并它们
069                FileStatistics currentStats = currentFileStatsMap.get(otherStats.getFile());
070                mergeFileStatistics(currentStats, otherStats);
071            } else {
072                // 如果当前对象没有该文件的统计结果,直接添加
073                this.fileStatisticsList.add(otherStats);
074                currentFileStatsMap.put(otherStats.getFile(), otherStats);
075            }
076        }
077    }
078
079    /**
080     * 合并两个FileStatistics对象
081     *
082     * @param target 目标FileStatistics对象,合并后结果存储在这里
083     * @param source 源FileStatistics对象,合并后不会改变
084     */
085    private void mergeFileStatistics(FileStatistics target, FileStatistics source) {
086        // 合并计数类型的统计
087        target.setTotal_statements(target.getTotal_statements() + source.getTotal_statements());
088        target.setSelect_count(target.getSelect_count() + source.getSelect_count());
089        target.setInsert_count(target.getInsert_count() + source.getInsert_count());
090        target.setUpdate_count(target.getUpdate_count() + source.getUpdate_count());
091        target.setDelete_count(target.getDelete_count() + source.getDelete_count());
092        target.setMerge_count(target.getMerge_count() + source.getMerge_count());
093        target.setCreate_table_count(target.getCreate_table_count() + source.getCreate_table_count());
094        target.setCreate_temp_table_count(target.getCreate_temp_table_count() + source.getCreate_temp_table_count());
095        target.setCtas_count(target.getCtas_count() + source.getCtas_count());
096        target.setCreate_view_count(target.getCreate_view_count() + source.getCreate_view_count());
097        target.setCreate_temp_view_count(target.getCreate_temp_view_count() + source.getCreate_temp_view_count());
098        target.setDrop_count(target.getDrop_count() + source.getDrop_count());
099        target.setTruncate_count(target.getTruncate_count() + source.getTruncate_count());
100        target.setParse_error_count(target.getParse_error_count() + source.getParse_error_count());
101        target.setJoin_count(target.getJoin_count() + source.getJoin_count());
102
103        // 合并子查询深度(取最大值)
104        if (source.getSubquery_depth() > target.getSubquery_depth()) {
105            target.setSubquery_depth(source.getSubquery_depth());
106        }
107
108        // 合并其他计数
109        target.setCte_count(target.getCte_count() + source.getCte_count());
110        target.setCase_count(target.getCase_count() + source.getCase_count());
111        target.setUnion_count(target.getUnion_count() + source.getUnion_count());
112        target.setWindow_function_count(target.getWindow_function_count() + source.getWindow_function_count());
113        target.setAggregate_function_count(target.getAggregate_function_count() + source.getAggregate_function_count());
114        target.setWhere_predicate_count(target.getWhere_predicate_count() + source.getWhere_predicate_count());
115
116        // 合并JOIN类型统计
117        for (Map.Entry<String, Integer> entry : source.getJoin_types().entrySet()) {
118            String joinType = entry.getKey();
119            int count = entry.getValue();
120            target.getJoin_types().put(joinType, target.getJoin_types().getOrDefault(joinType, 0) + count);
121        }
122
123        // 合并表引用
124        target.addTableReferences(source.getDistinct_table_references());
125
126        // 合并创建的对象
127        for (String object : source.getObjects_created()) {
128            if (!target.getObjects_created().contains(object)) {
129                target.addObject_created(object);
130            }
131        }
132
133        // 合并读取的对象
134        for (String object : source.getObjects_read()) {
135            target.addObject_read(object);
136        }
137    }
138
139    public SQLFileStatistics(String sqlContent, Option option) {
140        SqlInfo[] sqlInfos = new SqlInfo[1];
141        SqlInfo info = new SqlInfo();
142        info.setSql(sqlContent);
143        info.setOriginIndex(0);
144        sqlInfos[0] = info;
145        this.sqlInfos = sqlInfos;
146        this.option = option;
147    }
148
149    public SQLFileStatistics(File[] sqlFiles, Option option) {
150        SqlInfo[] sqlInfos = new SqlInfo[sqlFiles.length];
151        for (int i = 0; i < sqlFiles.length; i++) {
152            SqlInfo info = new SqlInfo();
153            info.setSql(SQLUtil.getFileContent(sqlFiles[i]));
154            info.setFileName(sqlFiles[i].getName());
155            info.setFilePath(sqlFiles[i].getAbsolutePath());
156            info.setOriginIndex(0);
157            sqlInfos[i] = info;
158        }
159        this.sqlInfos = sqlInfos;
160        this.option = option;
161    }
162
163    /**
164     * 从拆分后的文件名中提取原始文件名
165     * 拆分文件名格式:原文件名_文件索引_起始行号_结束行号.扩展名
166     *
167     * @param filename 拆分后的文件名
168     * @return 原始文件名,如果不是拆分文件则返回原文件名
169     */
170    private String getOriginalFileName(String filename) {
171        if (filename == null || filename.isEmpty()) {
172            return filename;
173        }
174
175        // 使用正则表达式匹配拆分文件名格式
176        // 格式:原文件名_数字_数字_数字.扩展名
177        String regex = "^(.*?)_\\d+_\\d+_\\d+\\.(.*)$";
178        Pattern pattern = Pattern.compile(regex);
179        Matcher matcher = pattern.matcher(filename);
180
181        if (matcher.matches()) {
182            // 如果匹配成功,返回原始文件名(原文件名.扩展名)
183            return matcher.group(1) + "." + matcher.group(2);
184        }
185
186        // 如果不是拆分文件,返回原文件名
187        return filename;
188    }
189
190    /**
191     * 从文件路径中提取原始文件名
192     *
193     * @param filePath 文件路径
194     * @return 原始文件名,如果不是拆分文件则返回原文件路径
195     */
196    private String getOriginalFilePath(String filePath) {
197        if (filePath == null || filePath.isEmpty()) {
198            return filePath;
199        }
200
201        File file = new File(filePath);
202        String filename = file.getName();
203        String originalFilename = getOriginalFileName(filename);
204
205        if (originalFilename.equals(filename)) {
206            // 如果不是拆分文件,返回原文件路径
207            return filePath;
208        }
209
210        // 如果是拆分文件,返回原始文件路径
211        return new File(file.getParent(), originalFilename).getAbsolutePath();
212    }
213
214    public synchronized String generateFileStatistics() {
215        if (sqlInfos == null) {
216            return JSON.toJSONString(Collections.emptyMap());
217        }
218
219        // 按文件分组统计 - 识别并合并拆分文件
220        Map<String, List<SqlInfo>> fileSqlInfoMap = Arrays.stream(sqlInfos)
221                .filter(info -> info != null && info.getSql() != null && !info.getSql().trim().isEmpty())
222                .collect(Collectors.groupingBy(info -> {
223                    if (info.getFilePath() != null && !info.getFilePath().isEmpty()) {
224                        // 如果有文件路径,检查是否是拆分文件
225                        return getOriginalFilePath(info.getFilePath());
226                    } else if (info.getFileName() != null && !info.getFileName().isEmpty()) {
227                        // 如果只有文件名,检查是否是拆分文件
228                        return getOriginalFileName(info.getFileName());
229                    } else {
230                        return "anonymous.sql";
231                    }
232                }));
233
234        // 对每个文件进行统计
235        // 缓存供应商类型,避免重复调用
236        EDbVendor vendor = option.getVendor();
237
238        // 预创建访问者对象,避免重复创建
239        SubqueryDepthVisitor subqueryDepthVisitor = new SubqueryDepthVisitor();
240        CaseVisitor caseVisitor = new CaseVisitor();
241        WindowFunctionVisitor windowFunctionVisitor = new WindowFunctionVisitor();
242        AggregateFunctionVisitor aggregateFunctionVisitor = new AggregateFunctionVisitor();
243        TableReferenceVisitor tableReferenceVisitor = new TableReferenceVisitor();
244        WherePredicateVisitor wherePredicateVisitor = new WherePredicateVisitor();
245
246        // 预创建SQL解析器,避免重复创建
247        TGSqlParser sqlparser = new TGSqlParser(vendor);
248
249        // 遍历每个文件
250        for (Map.Entry<String, List<SqlInfo>> entry : fileSqlInfoMap.entrySet()) {
251            String filePath = entry.getKey();
252            List<SqlInfo> fileSqlInfos = entry.getValue();
253
254            // 预分配FileStatistics对象,设置合理的初始容量
255            FileStatistics fileStats = new FileStatistics(filePath);
256            
257            // 统计文件大小和总行数
258            File file = new File(filePath);
259            collectSQLFileInfo(file, fileStats, filePath);
260
261            // 解析文件中的所有SQL语句
262            for (SqlInfo sqlInfo : fileSqlInfos) {
263                // 重置解析器状态并设置新的SQL文本
264                sqlparser.sqltext = sqlInfo.getSql();
265
266                int parseResult = sqlparser.parse();
267                if (parseResult != 0) {
268                    logger.warn("文件: " + filePath + " SQL解析错误: " + sqlparser.getErrormessage());
269                    // 使用sqlparser.getErrorCount()获取解析错误数量
270                    fileStats.setParse_error_count(fileStats.getParse_error_count() + sqlparser.getErrorCount());
271                }
272                // 获取语句列表,避免重复调用
273                int statementCount = sqlparser.sqlstatements.size();
274
275                // 统计每个SQL语句
276                for (int i = 0; i < statementCount; i++) {
277                    TCustomSqlStatement stmt = sqlparser.sqlstatements.get(i);
278                    fileStats.incrementTotal_statements();
279
280                    // 统计语句类型
281                    countStatementType(fileStats, stmt);
282
283                    // 统计JOIN相关指标
284                    countJoins(fileStats, stmt);
285
286                    // 统计子查询深度 - 重置访问者状态
287                    subqueryDepthVisitor.reset();
288                    stmt.acceptChildren(subqueryDepthVisitor);
289                    int subqueryDepth = subqueryDepthVisitor.getMaxDepth();
290                    if (subqueryDepth > fileStats.getSubquery_depth()) {
291                        fileStats.setSubquery_depth(subqueryDepth);
292                    }
293
294                    // 统计CTE数量
295                    countCTEs(fileStats, stmt);
296
297                    // 统计CASE表达式数量 - 重置访问者状态
298                    caseVisitor.reset();
299                    stmt.acceptChildren(caseVisitor);
300                    fileStats.addCase_count(caseVisitor.getCaseCount());
301
302                    // 统计UNION数量
303                    countUnions(fileStats, stmt);
304
305                    // 统计窗口函数数量 - 重置访问者状态
306                    windowFunctionVisitor.reset();
307                    stmt.acceptChildren(windowFunctionVisitor);
308                    fileStats.addWindow_function_count(windowFunctionVisitor.getWindowFunctionCount());
309
310                    // 统计聚合函数数量 - 重置访问者状态
311                    aggregateFunctionVisitor.reset();
312                    stmt.acceptChildren(aggregateFunctionVisitor);
313                    fileStats.addAggregate_function_count(aggregateFunctionVisitor.getAggregateFunctionCount());
314
315                    // 统计表引用 - 重置访问者状态
316                    tableReferenceVisitor.reset();
317                    stmt.acceptChildren(tableReferenceVisitor);
318                    fileStats.addTableReferences(tableReferenceVisitor.getTableReferences());
319
320                    // 对于SELECT语句,收集读取的表
321                    if (stmt instanceof TSelectSqlStatement) {
322                        tableReferenceVisitor.getTableReferences().forEach(table -> {
323                            fileStats.addObject_read(table);
324                        });
325                    }
326
327                    // 统计WHERE/HAVING条件数 - 重置访问者状态
328                    if (stmt.getWhereClause() != null) {
329                        wherePredicateVisitor.reset();
330                        stmt.getWhereClause().acceptChildren(wherePredicateVisitor);
331                        fileStats.addWhere_predicate_count(wherePredicateVisitor.getPredicateCount());
332                    }
333                    if (stmt.getStatements() != null) {
334                        for (int j = 0; j < stmt.getStatements().size(); j++) {
335                            TCustomSqlStatement subStmt = stmt.getStatements().get(j);
336                            if (subStmt.getWhereClause() == null) {
337                                continue;
338                            }
339                            wherePredicateVisitor.reset();
340                            subStmt.getWhereClause().acceptChildren(wherePredicateVisitor);
341                            fileStats.addWhere_predicate_count(wherePredicateVisitor.getPredicateCount());
342                        }
343                    }
344                }
345            }
346
347            fileStatisticsList.add(fileStats);
348        }
349
350        // 生成JSON输出
351        Map result = new LinkedHashMap();
352        List fileStatsArray = new ArrayList();
353
354        for (FileStatistics stat : fileStatisticsList) {
355            fileStatsArray.add(stat.toJSON());
356        }
357
358        result.put("file_statistics", fileStatsArray);
359        return JSON.toJSONString(result);
360    }
361
362    private void collectSQLFileInfo(File file, FileStatistics fileStats, String filePath) {
363        if (file.exists()) {
364            // 统计文件大小
365            fileStats.setFile_size(file.length());
366
367            // 统计总行数
368            try (java.io.BufferedReader reader = new java.io.BufferedReader(new java.io.FileReader(file))) {
369                int lineCount = 0;
370                while (reader.readLine() != null) {
371                    lineCount++;
372                }
373                fileStats.setLine_count(lineCount);
374            } catch (java.io.IOException e) {
375                logger.warn("统计文件总行数错误: " + filePath + ", 错误信息: " + e.getMessage());
376            }
377        }
378    }
379
380    private void countStatementType(FileStatistics fileStats, TCustomSqlStatement stmt) {
381        if (stmt instanceof TSelectSqlStatement) {
382            fileStats.incrementSelect_count();
383        } else if (stmt instanceof TInsertSqlStatement) {
384            fileStats.incrementInsert_count();
385        } else if (stmt instanceof TUpdateSqlStatement) {
386            fileStats.incrementUpdate_count();
387        } else if (stmt instanceof TDeleteSqlStatement) {
388            fileStats.incrementDelete_count();
389        } else if (stmt instanceof TMergeSqlStatement) {
390            fileStats.incrementMerge_count();
391        } else if (stmt instanceof TCreateTableSqlStatement) {
392            TCreateTableSqlStatement createTableStmt = (TCreateTableSqlStatement) stmt;
393            // 检查是否是临时表
394            if (createTableStmt.getTableKinds() != null && (
395                    createTableStmt.getTableKinds().contains(ETableKind.etkTemporary) ||
396                            createTableStmt.getTableKinds().contains(ETableKind.etkTemp) ||
397                            createTableStmt.getTableKinds().contains(ETableKind.etkLocalTemporary) ||
398                            createTableStmt.getTableKinds().contains(ETableKind.etkLocalTemp) ||
399                            createTableStmt.getTableKinds().contains(ETableKind.etkGlobalTemporary) ||
400                            createTableStmt.getTableKinds().contains(ETableKind.etkGlobalTemp))) {
401                fileStats.incrementCreate_temp_table_count();
402            } else {
403                fileStats.incrementCreate_table_count();
404            }
405            // 检查是否是CTAS
406            if (createTableStmt.getSubQuery() != null) {
407                fileStats.incrementCtas_count();
408            }
409
410            fileStats.addObject_created(createTableStmt.getTargetTable().toString());
411        } else if (stmt instanceof TCreateViewSqlStatement) {
412            TCreateViewSqlStatement createViewStmt = (TCreateViewSqlStatement) stmt;
413            // 检查是否是临时视图
414            boolean isTempView = false;
415            if (createViewStmt.getTableKind() != null) {
416                ETableKind tableKind = createViewStmt.getTableKind();
417                isTempView = (tableKind == ETableKind.etkTemporary ||
418                        tableKind == ETableKind.etkTemp ||
419                        tableKind == ETableKind.etkLocalTemporary ||
420                        tableKind == ETableKind.etkLocalTemp ||
421                        tableKind == ETableKind.etkGlobalTemporary ||
422                        tableKind == ETableKind.etkGlobalTemp);
423            }
424
425            if (isTempView) {
426                fileStats.incrementCreate_temp_view_count();
427            } else {
428                fileStats.incrementCreate_view_count();
429            }
430
431            fileStats.addObject_created(createViewStmt.getViewName().toString());
432        } else if (stmt.sqlstatementtype.toString().toLowerCase().startsWith("sstdrop")) {
433            fileStats.incrementDrop_count();
434        } else if (stmt instanceof TTruncateStatement) {
435            fileStats.incrementTruncate_count();
436        }
437    }
438
439    private void countJoins(FileStatistics fileStats, TCustomSqlStatement stmt) {
440        // 使用集合跟踪已处理的连接,避免重复计数
441        Set<TJoin> processedJoins = new HashSet<>();
442        // 递归处理所有语句及其子语句
443        processStatementsForJoins(fileStats, stmt, processedJoins);
444    }
445
446    /**
447     * 递归处理所有语句及其子语句,统计连接信息
448     *
449     * @param fileStats      统计信息对象
450     * @param stmt           当前处理的语句
451     * @param processedJoins 已处理的连接集合,用于去重
452     */
453    private void processStatementsForJoins(FileStatistics fileStats, TCustomSqlStatement stmt, Set<TJoin> processedJoins) {
454        if (stmt == null) {
455            return;
456        }
457
458        // 处理当前语句的连接
459        TJoinList joins = stmt.getJoins();
460        if (joins != null && joins.size() > 0) {
461            for (int i = 0; i < joins.size(); i++) {
462                TJoin join = joins.getJoin(i);
463                // 检查连接是否已经处理过
464                if (!processedJoins.contains(join)) {
465                    processedJoins.add(join);
466                    TJoinItemList joinItems = join.getJoinItems();
467                    if (joinItems != null && joinItems.size() > 0) {
468                        for (int j = 0; j < joinItems.size(); j++) {
469                            TJoinItem joinItem = joinItems.getJoinItem(j);
470                            EJoinType joinType = joinItem.getJoinType();
471                            if (joinType != null) {
472                                fileStats.incrementJoin_count();
473
474                                // 统计JOIN类型,将具体类型归类到五种主要类型:INNER、LEFT、RIGHT、FULL、CROSS
475                                String joinTypeStr = joinType.toString().toLowerCase();
476                                String mainJoinType = "";
477
478                                // 根据具体JOIN类型确定主要类型
479                                switch (joinType) {
480                                    case inner:
481                                    case join:
482                                    case natural_inner:
483                                        mainJoinType = "inner";
484                                        break;
485                                    case left:
486                                    case leftouter:
487                                    case natural_left:
488                                    case natural_leftouter:
489                                    case leftsemi:
490                                    case leftanti:
491                                        mainJoinType = "left";
492                                        break;
493                                    case right:
494                                    case rightouter:
495                                    case natural_right:
496                                    case natural_rightouter:
497                                        mainJoinType = "right";
498                                        break;
499                                    case full:
500                                    case fullouter:
501                                    case natural_full:
502                                    case natural_fullouter:
503                                        mainJoinType = "full";
504                                        break;
505                                    case cross:
506                                        mainJoinType = "cross";
507                                        break;
508                                    default:
509                                        // 对于其他类型,保留原始类型
510                                        mainJoinType = joinTypeStr;
511                                        break;
512                                }
513
514                                fileStats.addJoin_type(mainJoinType);
515                            }
516                        }
517                    }
518                }
519            }
520        }
521
522        // 处理子语句
523        TStatementList statements = stmt.getStatements();
524        if (statements != null && statements.size() > 0) {
525            for (TCustomSqlStatement subStmt : statements) {
526                processStatementsForJoins(fileStats, subStmt, processedJoins);
527            }
528        }
529    }
530
531    private void countCTEs(FileStatistics fileStats, TCustomSqlStatement stmt) {
532        if (stmt.getCteList() != null) {
533            fileStats.addCte_count(stmt.getCteList().size());
534        }
535    }
536
537    private void countUnions(FileStatistics fileStats, TCustomSqlStatement stmt) {
538        if (stmt instanceof TSelectSqlStatement) {
539            TSelectSqlStatement selectStmt = (TSelectSqlStatement) stmt;
540            if (selectStmt.getSetOperatorType() != ESetOperatorType.none) {
541                // Count this union operation
542                fileStats.incrementUnion_count();
543
544                // Recursively process left and right statements
545                countUnions(fileStats, selectStmt.getLeftStmt());
546                countUnions(fileStats, selectStmt.getRightStmt());
547            }
548        }
549    }
550
551
552}