001package gudusoft.gsqlparser.dlineage.graph.utils;
002
import gudusoft.gsqlparser.util.Logger;
import gudusoft.gsqlparser.util.LoggerFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.StringReader;
import java.io.StringWriter;
import java.util.regex.Pattern;
020
021public class XMLUtil {
022    private static final Logger logger = LoggerFactory.getLogger(XMLUtil.class);
023    private static final TransformerFactory transformerFactory = TransformerFactory.newInstance();
024
025    public static Document parseXmlString(String xmlString) {
026        if (xmlString == null) return null;
027
028        try (StringReader reader = new StringReader(xmlString.replaceAll("<data key=\"d\\d+\"/>", ""))) {
029            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
030            DocumentBuilder builder = factory.newDocumentBuilder();
031            return builder.parse(new InputSource(reader));
032        } catch (Exception e) {
033            logger.error("Parse XML string failed.", e);
034            return null;
035        }
036    }
037
038    public static boolean containsNode(Document document, String nodeName) {
039        if (document == null || nodeName == null) return false;
040        Element root = document.getDocumentElement();
041        return containsNodeRecursive(root, nodeName);
042    }
043
044    private static boolean containsNodeRecursive(Element element, String nodeName) {
045        if (element.getNodeName().equals(nodeName)) return true;
046
047        NodeList children = element.getChildNodes();
048        for (int i = 0; i < children.getLength(); i++) {
049            Node node = children.item(i);
050            if (node.getNodeType() == Node.ELEMENT_NODE) {
051                if (containsNodeRecursive((Element) node, nodeName)) return true;
052            }
053        }
054        return false;
055    }
056
057    public static String documentToString(Document document) {
058        if (document == null) return null;
059        try (StringWriter writer = new StringWriter()) {
060            Transformer transformer = transformerFactory.newTransformer();
061            transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "no");
062            transformer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");
063            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
064            transformer.transform(new DOMSource(document), new StreamResult(writer));
065            return writer.toString();
066        } catch (Exception e) {
067            logger.error("Convert document to string failed.", e);
068            return null;
069        }
070    }
071
072    public static boolean containsText(Document document, String searchText) {
073        if (document == null || searchText == null) return false;
074        String xml = documentToString(document);
075        return xml != null && xml.contains(searchText);
076    }
077
078    public static String getFirstNodeLabelTextRecursive(Node node) {
079        if (node == null || node.getNodeType() != Node.ELEMENT_NODE) return null;
080
081        Element element = (Element) node;
082        NodeList nodeLabels = element.getElementsByTagName("NodeLabel");
083        if (nodeLabels.getLength() > 0) {
084            return nodeLabels.item(0).getTextContent().trim();
085        }
086
087        NodeList children = element.getChildNodes();
088        for (int i = 0; i < children.getLength(); i++) {
089            String result = getFirstNodeLabelTextRecursive(children.item(i));
090            if (result != null) return result;
091        }
092        return null;
093    }
094}