001package gudusoft.gsqlparser.demos.sqlguard;
002
003import com.sun.net.httpserver.HttpExchange;
004import com.sun.net.httpserver.HttpHandler;
005import com.sun.net.httpserver.HttpServer;
006
007import java.io.ByteArrayOutputStream;
008import java.io.IOException;
009import java.io.InputStream;
010import java.io.OutputStream;
011import java.net.HttpURLConnection;
012import java.net.InetAddress;
013import java.net.InetSocketAddress;
014import java.net.URL;
015import java.nio.charset.StandardCharsets;
016import java.util.concurrent.ArrayBlockingQueue;
017import java.util.concurrent.ThreadPoolExecutor;
018import java.util.concurrent.TimeUnit;
019
020public class SqlGuardHttpServer {
021    public static final int MAX_BODY_BYTES = 1572864;
022    private final String host;
023    private final int port;
024    private final SqlGuardService service;
025    private HttpServer server;
026
027    public SqlGuardHttpServer(String host, int port, SqlGuardService service) {
028        this.host = host == null ? "127.0.0.1" : host;
029        this.port = port;
030        this.service = service;
031    }
032
033    public void start() throws IOException {
034        InetAddress bindAddress = InetAddress.getByName(host);
035        if (!bindAddress.isLoopbackAddress()) {
036            throw new IOException("SQL Guard worker refuses non-loopback bind address: " + host);
037        }
038        server = HttpServer.create(new InetSocketAddress(bindAddress, port), 16);
039        server.createContext("/healthz", new Health());
040        server.createContext("/check", new Check());
041        server.setExecutor(new ThreadPoolExecutor(
042                4, 4, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<Runnable>(32)));
043        server.start();
044    }
045
046    public void stop(int delay) {
047        if (server != null) server.stop(delay);
048    }
049
050    public InetSocketAddress address() {
051        return server.getAddress();
052    }
053
054    class Health implements HttpHandler {
055        public void handle(HttpExchange x) throws IOException {
056            if (!"GET".equals(x.getRequestMethod())) {
057                send(x, 405, "{\"ok\":false}");
058                return;
059            }
060            send(x, 200, "{\"ok\":true,\"service\":\"sql-guard-worker\"}");
061        }
062    }
063
064    class Check implements HttpHandler {
065        public void handle(HttpExchange x) throws IOException {
066            if (!"POST".equals(x.getRequestMethod())) {
067                send(x, 405, "{\"ok\":false}");
068                return;
069            }
070            String body;
071            try {
072                body = readLimited(x.getRequestBody(), MAX_BODY_BYTES);
073            } catch (IOException e) {
074                send(x, 413, SqlGuardResponse.error(null, "REQUEST_TOO_LARGE", "Request body too large.").toJson());
075                return;
076            }
077            SqlGuardResponse r;
078            try {
079                r = service.check(SqlGuardRequest.fromJson(body));
080            } catch (Exception e) {
081                r = SqlGuardResponse.error(null, "INVALID_REQUEST", "Invalid JSON request.");
082            }
083            send(x, r.ok ? 200 : 400, r.toJson());
084        }
085    }
086
087    static String readLimited(InputStream in, int max) throws IOException {
088        ByteArrayOutputStream out = new ByteArrayOutputStream();
089        byte[] buf = new byte[8192];
090        int total = 0;
091        int n;
092        while ((n = in.read(buf)) != -1) {
093            if (total + n > max) {
094                throw new IOException("too large");
095            }
096            total += n;
097            out.write(buf, 0, n);
098        }
099        return new String(out.toByteArray(), StandardCharsets.UTF_8);
100    }
101
102    static void send(HttpExchange x, int status, String body) throws IOException {
103        byte[] b = body.getBytes(StandardCharsets.UTF_8);
104        x.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8");
105        x.sendResponseHeaders(status, b.length);
106        OutputStream os = x.getResponseBody();
107        try {
108            os.write(b);
109        } finally {
110            os.close();
111        }
112    }
113
114    public static String httpGet(String url) throws IOException {
115        HttpURLConnection c = (HttpURLConnection) new URL(url).openConnection();
116        c.setRequestMethod("GET");
117        return readAll(c);
118    }
119
120    public static String httpPost(String url, String body) throws IOException {
121        HttpURLConnection c = (HttpURLConnection) new URL(url).openConnection();
122        c.setRequestMethod("POST");
123        c.setDoOutput(true);
124        c.setRequestProperty("Content-Type", "application/json");
125        OutputStream os = c.getOutputStream();
126        try {
127            os.write(body.getBytes(StandardCharsets.UTF_8));
128        } finally {
129            os.close();
130        }
131        return readAll(c);
132    }
133
134    private static String readAll(HttpURLConnection c) throws IOException {
135        InputStream in = c.getResponseCode() >= 400 ? c.getErrorStream() : c.getInputStream();
136        if (in == null) {
137            return "";
138        }
139        try {
140            ByteArrayOutputStream out = new ByteArrayOutputStream();
141            byte[] buf = new byte[4096];
142            int n;
143            while ((n = in.read(buf)) != -1) out.write(buf, 0, n);
144            return new String(out.toByteArray(), StandardCharsets.UTF_8);
145        } finally {
146            in.close();
147        }
148    }
149}