001package gudusoft.gsqlparser.demos.sqlguard; 002 003import com.sun.net.httpserver.HttpExchange; 004import com.sun.net.httpserver.HttpHandler; 005import com.sun.net.httpserver.HttpServer; 006 007import java.io.ByteArrayOutputStream; 008import java.io.IOException; 009import java.io.InputStream; 010import java.io.OutputStream; 011import java.net.HttpURLConnection; 012import java.net.InetAddress; 013import java.net.InetSocketAddress; 014import java.net.URL; 015import java.nio.charset.StandardCharsets; 016import java.util.concurrent.ArrayBlockingQueue; 017import java.util.concurrent.ThreadPoolExecutor; 018import java.util.concurrent.TimeUnit; 019 020public class SqlGuardHttpServer { 021 public static final int MAX_BODY_BYTES = 1572864; 022 private final String host; 023 private final int port; 024 private final SqlGuardService service; 025 private HttpServer server; 026 027 public SqlGuardHttpServer(String host, int port, SqlGuardService service) { 028 this.host = host == null ? "127.0.0.1" : host; 029 this.port = port; 030 this.service = service; 031 } 032 033 public void start() throws IOException { 034 InetAddress bindAddress = InetAddress.getByName(host); 035 if (!bindAddress.isLoopbackAddress()) { 036 throw new IOException("SQL Guard worker refuses non-loopback bind address: " + host); 037 } 038 server = HttpServer.create(new InetSocketAddress(bindAddress, port), 16); 039 server.createContext("/healthz", new Health()); 040 server.createContext("/check", new Check()); 041 server.setExecutor(new ThreadPoolExecutor( 042 4, 4, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<Runnable>(32))); 043 server.start(); 044 } 045 046 public void stop(int delay) { 047 if (server != null) server.stop(delay); 048 } 049 050 public InetSocketAddress address() { 051 return server.getAddress(); 052 } 053 054 class Health implements HttpHandler { 055 public void handle(HttpExchange x) throws IOException { 056 if (!"GET".equals(x.getRequestMethod())) { 057 send(x, 405, "{\"ok\":false}"); 058 return; 059 } 060 send(x, 200, "{\"ok\":true,\"service\":\"sql-guard-worker\"}"); 061 } 062 } 063 064 class Check implements HttpHandler { 065 public void handle(HttpExchange x) throws IOException { 066 if (!"POST".equals(x.getRequestMethod())) { 067 send(x, 405, "{\"ok\":false}"); 068 return; 069 } 070 String body; 071 try { 072 body = readLimited(x.getRequestBody(), MAX_BODY_BYTES); 073 } catch (IOException e) { 074 send(x, 413, SqlGuardResponse.error(null, "REQUEST_TOO_LARGE", "Request body too large.").toJson()); 075 return; 076 } 077 SqlGuardResponse r; 078 try { 079 r = service.check(SqlGuardRequest.fromJson(body)); 080 } catch (Exception e) { 081 r = SqlGuardResponse.error(null, "INVALID_REQUEST", "Invalid JSON request."); 082 } 083 send(x, r.ok ? 200 : 400, r.toJson()); 084 } 085 } 086 087 static String readLimited(InputStream in, int max) throws IOException { 088 ByteArrayOutputStream out = new ByteArrayOutputStream(); 089 byte[] buf = new byte[8192]; 090 int total = 0; 091 int n; 092 while ((n = in.read(buf)) != -1) { 093 if (total + n > max) { 094 throw new IOException("too large"); 095 } 096 total += n; 097 out.write(buf, 0, n); 098 } 099 return new String(out.toByteArray(), StandardCharsets.UTF_8); 100 } 101 102 static void send(HttpExchange x, int status, String body) throws IOException { 103 byte[] b = body.getBytes(StandardCharsets.UTF_8); 104 x.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); 105 x.sendResponseHeaders(status, b.length); 106 OutputStream os = x.getResponseBody(); 107 try { 108 os.write(b); 109 } finally { 110 os.close(); 111 } 112 } 113 114 public static String httpGet(String url) throws IOException { 115 HttpURLConnection c = (HttpURLConnection) new URL(url).openConnection(); 116 c.setRequestMethod("GET"); 117 return readAll(c); 118 } 119 120 public static String httpPost(String url, String body) throws IOException { 121 HttpURLConnection c = (HttpURLConnection) new URL(url).openConnection(); 122 c.setRequestMethod("POST"); 123 c.setDoOutput(true); 124 c.setRequestProperty("Content-Type", "application/json"); 125 OutputStream os = c.getOutputStream(); 126 try { 127 os.write(body.getBytes(StandardCharsets.UTF_8)); 128 } finally { 129 os.close(); 130 } 131 return readAll(c); 132 } 133 134 private static String readAll(HttpURLConnection c) throws IOException { 135 InputStream in = c.getResponseCode() >= 400 ? c.getErrorStream() : c.getInputStream(); 136 if (in == null) { 137 return ""; 138 } 139 try { 140 ByteArrayOutputStream out = new ByteArrayOutputStream(); 141 byte[] buf = new byte[4096]; 142 int n; 143 while ((n = in.read(buf)) != -1) out.write(buf, 0, n); 144 return new String(out.toByteArray(), StandardCharsets.UTF_8); 145 } finally { 146 in.close(); 147 } 148 } 149}