diff --git a/source/handy_http_primitives/request.d b/source/handy_http_primitives/request.d index 30f4f07..31d3a3f 100644 --- a/source/handy_http_primitives/request.d +++ b/source/handy_http_primitives/request.d @@ -1,6 +1,6 @@ module handy_http_primitives.request; -import streams : InputStream; +import streams; import std.traits : EnumMembers; import handy_http_primitives.optional; @@ -14,15 +14,139 @@ struct ServerHttpRequest { /// The HTTP version of the request. const HttpVersion httpVersion = HttpVersion.V1; /// The remote address of the client that sent this request. - const ClientAddress clientAddress; + const ClientAddress clientAddress = ClientAddress.unknown; /// The HTTP verb used in the request. const string method = HttpMethod.GET; - /// The URL that was requested. + /// The URL that was requested, excluding any query parameters. const string url = ""; /// A case-insensitive map of all request headers. const(string[][string]) headers; + /// A list of all URL query parameters. + const QueryParameter[] queryParams; /// The underlying stream used to read the body from the request. InputStream!ubyte inputStream; + + /** + * Gets a header as the specified type, or returns the default value if the + * header doesn't exist or cannot be converted to the desired type. + * Params: + * headerName = The name of the header to get, case-sensitive. + * defaultValue = The default value to return if the header doesn't exist + * or is invalid. + * Returns: The header value. + */ + T getHeaderAs(T)(string headerName, T defaultValue = T.init) const { + import std.conv : to, ConvException; + if (headerName !in headers || headers[headerName].length == 0) return defaultValue; + try { + return to!T(headers[headerName][0]); + } catch (ConvException e) { + return defaultValue; + } + } + + /** + * Gets a query parameter with a given name, as the specified type, or + * returns the default value if the parameter doesn't exist. + * Params: + * paramName = The name of the parameter to get. + * defaultValue = The default value to return if the parameter doesn't + * exist or is invalid. + * Returns: The parameter value. + */ + T getParamAs(T)(string paramName, T defaultValue = T.init) const { + import std.conv : to, ConvException; + foreach (ref param; queryParams) { + if (param.key == paramName) { + foreach (string value; param.values) { + try { + return value.to!T; + } catch (ConvException e) { + continue; + } + } + // No value could be converted, short-circuit now. + return defaultValue; + } + } + return defaultValue; + } + + /** + * Reads the body of this request and transfers it to the given output + * stream, limited by the request's "Content-Length" unless you choose to + * allow infinite reading. If the request includes a header for + * "Transfer-Encoding: chunked", then it will wrap the input stream in one + * which decodes HTTP chunked-encoding first. + * Params: + * outputStream = The output stream to transfer data to. + * allowInfiniteRead = Whether to allow reading the request even if the + * Content-Length header is missing or invalid. Use + * with caution! + * Returns: Either the number of bytes read, or a stream error. + */ + StreamResult readBody(S)(ref S outputStream, bool allowInfiniteRead = false) if (isByteOutputStream!S) { + import std.algorithm : min; + import std.string : toLower; + const long contentLength = getHeaderAs!long("Content-Length", -1); + if (contentLength < 0 && !allowInfiniteRead) { + return StreamResult(0); + } + InputStream!ubyte sIn; + if ("Transfer-Encoding" in headers && toLower(headers["Transfer-Encoding"][0]) == "chunked") { + sIn = inputStreamObjectFor(chunkedEncodingInputStreamFor(inputStream)); + } else { + sIn = inputStream; + } + ulong bytesRead = 0; + ubyte[8192] buffer; + while (contentLength == -1 || bytesRead < contentLength) { + const ulong bytesToRead = (contentLength == -1) + ? buffer.length + : min(contentLength - bytesRead, buffer.length); + StreamResult readResult = sIn.readFromStream(buffer[0 .. bytesToRead]); + if (readResult.hasError) { + return readResult; + } + if (readResult.count == 0) break; + + StreamResult writeResult = outputStream.writeToStream(buffer[0 .. readResult.count]); + if (writeResult.hasError) { + return writeResult; + } + if (writeResult.count != readResult.count) { + return StreamResult(StreamError("Failed to write all bytes that were read to the output stream.", 1)); + } + bytesRead += writeResult.count; + } + + return StreamResult(cast(uint) bytesRead); + } + + /** + * Reads the request's body into a new byte array. + * Params: + * allowInfiniteRead = Whether to allow reading even without a valid + * Content-Length header. + * Returns: The byte array. + */ + ubyte[] readBodyAsBytes(bool allowInfiniteRead = false) { + auto sOut = byteArrayOutputStream(); + StreamResult r = readBody(sOut, allowInfiniteRead); + if (r.hasError) throw new Exception(cast(string) r.error.message); + return sOut.toArray(); + } + + /** + * Reads the request's body into a new string. + * Params: + * allowInfiniteRead = Whether to allow reading even without a valid + * Content-Length header. + * Returns: The string content. + */ + string readBodyAsString(bool allowInfiniteRead = false) { + return cast(string) readBodyAsBytes(allowInfiniteRead); + } } /** @@ -54,35 +178,6 @@ public enum HttpMethod : string { PATCH = "PATCH" } -/** - * Attempts to parse an HttpMethod from a string. - * Params: - * s = The string to parse. - * Returns: An optional which may contain an HttpMethod, if one was parsed. - */ -Optional!HttpMethod parseHttpMethod(string s) { - // TODO: Remove this function now that we're using plain string HTTP methods. - import std.uni : toUpper; - import std.string : strip; - static foreach (m; EnumMembers!HttpMethod) { - if (s == m) return Optional!HttpMethod.of(m); - } - const cleanStr = strip(toUpper(s)); - static foreach (m; EnumMembers!HttpMethod) { - if (cleanStr == m) return Optional!HttpMethod.of(m); - } - return Optional!HttpMethod.empty; -} - -unittest { - assert(parseHttpMethod("GET") == Optional!HttpMethod.of(HttpMethod.GET)); - assert(parseHttpMethod("get") == Optional!HttpMethod.of(HttpMethod.GET)); - assert(parseHttpMethod(" geT ") == Optional!HttpMethod.of(HttpMethod.GET)); - assert(parseHttpMethod("PATCH") == Optional!HttpMethod.of(HttpMethod.PATCH)); - assert(parseHttpMethod(" not a method!") == Optional!HttpMethod.empty); - assert(parseHttpMethod("") == Optional!HttpMethod.empty); -} - /// Stores a single query parameter's key and values. struct QueryParameter { string key; @@ -104,6 +199,8 @@ QueryParameter[] parseQueryParameters(string url) { string paramsStr = url[paramsStartIdx + 1 .. $]; QueryParameter[] params; + import std.array : RefAppender, appender; // TODO: Get rid of stdlib usage of std.array! + RefAppender!(QueryParameter[]) app = appender(¶ms); size_t idx = 0; while (idx < paramsStr.length) { // First, isolate the text up to the next '&' separator. @@ -145,7 +242,7 @@ QueryParameter[] parseQueryParameters(string url) { } // Otherwise, add a new query parameter. if (!keyExists) { - params ~= QueryParameter(key, [val]); + app ~= QueryParameter(key, [val]); } // Advance our current index pointer to the start of the next query parameter. // (past the '&' character separating query parameters)