diff --git a/include/HttpServer.h b/include/HttpServer.h index baf0f93..7c8b411 100644 --- a/include/HttpServer.h +++ b/include/HttpServer.h @@ -14,7 +14,7 @@ class HttpServer { ~HttpServer(); - std::atomic stop_flag; + std::atomic stop_flag {}; struct Request { std::string_view route; diff --git a/src/HttpServer.cpp b/src/HttpServer.cpp index 069e8e9..82e7a04 100644 --- a/src/HttpServer.cpp +++ b/src/HttpServer.cpp @@ -52,7 +52,7 @@ void HttpServer::listen(int port) { struct sockaddr_storage incoming_addr {}; socklen_t addr_size {sizeof(incoming_addr)}; - int conn_file_descriptor = accept(listener_fd, (struct sockaddr*)&incoming_addr, &addr_size); + int conn_file_descriptor = accept(listener_fd, reinterpret_cast(&incoming_addr), &addr_size); if (conn_file_descriptor == -1) { // If we're stopping, accept failures are expected; don't spam logs. if (stop_flag.load()) break; @@ -143,12 +143,14 @@ void HttpServer::handle_client() { // std::cout << "route: " << route << '\n'; // get body - size_t req_body_start = path.find("\r\n\r\n") + 4; - if (req_body_start == std::string_view::npos) { + size_t req_body_delimiter = path.find("\r\n\r\n"); + if (req_body_delimiter == std::string_view::npos) { close (conn_fd); std::cerr << "Invalid request formatting: the start of the request body is malformed\n"; + continue; } + size_t req_body_start = req_body_delimiter + 4; std::string_view req_body = path.substr(req_body_start, path.size() - req_body_start); // std::cout << "body: " << req_body << '\n'; diff --git a/test/HttpServerTest.cpp b/test/HttpServerTest.cpp index beb4dc2..2af9e60 100644 --- a/test/HttpServerTest.cpp +++ b/test/HttpServerTest.cpp @@ -4,27 +4,21 @@ #include #include -class HttpServerTest : public ::testing::Test {}; +class HttpServerTest : public ::testing::Test { +public: + static constexpr int port {8081}; +}; TEST(ServerTest, ConstructorDestructorTest) { HttpServer server {}; } -TEST(HttpServerTest, ServerStartsAndAcceptsRequests) { - HttpServer server {}; - - // Start server in non-blocking mode - server.start_listening(8080); - - // Send real HTTP request using curl - int result = system("curl -s http://localhost:8080 > /dev/null"); - - EXPECT_EQ(result, 0); -} - TEST(HttpServerTest, AcceptsHttpRequest) { - HttpServer server; - server.start_listening(8081); + HttpServer server {}; + server.get_mapping("/", [](const HttpServer::Request&, HttpServer::Response& res) { + res.body = "test"; + }); + server.start_listening(HttpServerTest::port); std::this_thread::sleep_for(std::chrono::milliseconds(1)); @@ -32,12 +26,16 @@ TEST(HttpServerTest, AcceptsHttpRequest) { sockaddr_in addr{}; addr.sin_family = AF_INET; - addr.sin_port = htons(8081); + addr.sin_port = htons(HttpServerTest::port); addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - ASSERT_EQ(connect(sock, (sockaddr*)&addr, sizeof(addr)), 0); + ASSERT_EQ(connect(sock, reinterpret_cast(&addr), sizeof(addr)), 0); - const char* request = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + const char* request = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" + "Host: localhost\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 0\r\n" + "\r\n"; send(sock, request, strlen(request), 0); char buffer[1024]; @@ -53,7 +51,7 @@ TEST(HttpServerTest, AcceptGetRequest) { server.get_mapping("/hello", [](const HttpServer::Request&, HttpServer::Response& res){ res.body = "hello, world"; }); - server.start_listening(8082); + server.start_listening(HttpServerTest::port); std::this_thread::sleep_for(std::chrono::milliseconds(1)); @@ -61,10 +59,10 @@ TEST(HttpServerTest, AcceptGetRequest) { sockaddr_in addr{}; addr.sin_family = AF_INET; - addr.sin_port = htons(8082); + addr.sin_port = htons(HttpServerTest::port); addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - ASSERT_EQ(connect(sock, (sockaddr*)&addr, sizeof(addr)), 0); + ASSERT_EQ(connect(sock, reinterpret_cast(&addr), sizeof(addr)), 0); const char* request = "GET /hello HTTP/1.1\r\n\r\n"; send(sock, request, strlen(request), 0); @@ -84,7 +82,7 @@ TEST(HttpServerTest, IgnoreGetReqBody) { server.get_mapping("/hello", [](const HttpServer::Request& req, HttpServer::Response& res){ res.body = req.body; }); - server.start_listening(8082); + server.start_listening(HttpServerTest::port); std::this_thread::sleep_for(std::chrono::milliseconds(1)); @@ -92,10 +90,10 @@ TEST(HttpServerTest, IgnoreGetReqBody) { sockaddr_in addr{}; addr.sin_family = AF_INET; - addr.sin_port = htons(8082); + addr.sin_port = htons(HttpServerTest::port); addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - ASSERT_EQ(connect(sock, (sockaddr*)&addr, sizeof(addr)), 0); + ASSERT_EQ(connect(sock, reinterpret_cast(&addr), sizeof(addr)), 0); const char* request = "GET /hello HTTP/1.1\r\n\r\nhello, world"; send(sock, request, strlen(request), 0); @@ -113,36 +111,44 @@ TEST(HttpServerTest, IgnoreGetReqBody) { } TEST(HttpServerTest, DoesntIgnorePostReqBody) { - HttpServer server {}; - server.post_mapping("/post-foo", [](const HttpServer::Request& req, HttpServer::Response& res){ - res.body = req.body; - }); - server.start_listening(8082); + try { + HttpServer server {}; + server.post_mapping("/foo", [](const HttpServer::Request& req, HttpServer::Response& res){ + res.body = req.body; + }); + server.start_listening(HttpServerTest::port); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); - int sock = socket(AF_INET, SOCK_STREAM, 0); + int sock = socket(AF_INET, SOCK_STREAM, 0); - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(8082); - addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(HttpServerTest::port); + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - ASSERT_EQ(connect(sock, (sockaddr*)&addr, sizeof(addr)), 0); + ASSERT_EQ(connect(sock, reinterpret_cast(&addr), sizeof(addr)), 0); - const char* request = "POST /post-foo HTTP/1.1\r\n\r\nhello, world"; - send(sock, request, strlen(request), 0); + std::string request = "POST /foo HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 5\r\n" + "\r\n" + "hello"; - char buffer[1024] {}; - int bytes = recv(sock, buffer, sizeof(buffer), 0); - std::string result = std::string(buffer); + send(sock, request.c_str(), request.size(), 0); - EXPECT_GT(bytes, 0); + char buffer[1024] {}; + int bytes = recv(sock, buffer, sizeof(buffer), 0); + std::string result = std::string(buffer); - // Should find "hello, world" as setting the request body - ASSERT_TRUE(result.find("hello, world") != std::string::npos); + EXPECT_GT(bytes, 0); + ASSERT_TRUE(result.find("hello") != std::string::npos); - close(sock); + close(sock); + } catch (const std::exception& e) { + FAIL() << "Exception occurred: " << e.what(); + } } TEST(HttpServerTest, AllUniqueReqMethods) { @@ -177,13 +183,13 @@ TEST(HttpServerTest, AllUniqueReqMethods) { res.body = "8"; }); - server.start_listening(8083); + server.start_listening(HttpServerTest::port); std::this_thread::sleep_for(std::chrono::milliseconds(1)); sockaddr_in addr{}; addr.sin_family = AF_INET; - addr.sin_port = htons(8083); + addr.sin_port = htons(HttpServerTest::port); addr.sin_addr.s_addr = inet_addr("127.0.0.1"); const std::string methods[9] = { "GET", "POST", "PUT", "PATCH", "OPTIONS", "HEAD", "DELETE", "CONNECT", "TRACE" }; @@ -195,7 +201,7 @@ TEST(HttpServerTest, AllUniqueReqMethods) { "\r\n"; int listener_fd = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_EQ(connect(listener_fd, (sockaddr*)&addr, sizeof(addr)), 0); + ASSERT_EQ(connect(listener_fd, reinterpret_cast(&addr), sizeof(addr)), 0); send(listener_fd, request.c_str(), request.size(), 0); char buffer[1024] {}; @@ -207,3 +213,99 @@ TEST(HttpServerTest, AllUniqueReqMethods) { ASSERT_TRUE(close(listener_fd) != -1); } } + +TEST(HttpServerTest, HandleNonExistentGetRoute) { + HttpServer server {}; + server.start_listening(HttpServerTest::port); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(HttpServerTest::port); + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + + std::string request = "GET /foo HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 0\r\n" + "\r\n"; + + int listener_fd = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_EQ(connect(listener_fd, reinterpret_cast(&addr), sizeof(addr)), 0); + send(listener_fd, request.c_str(), request.size(), 0); + + char buffer[1024] {}; + int bytes = recv(listener_fd, buffer, sizeof(buffer), 0); + std::string result = std::string(buffer); + + EXPECT_GT(bytes, 0); + ASSERT_TRUE(result.find("404 Not Found") != std::string::npos); + ASSERT_TRUE(close(listener_fd) != -1); +} + +/* + * This test covers a different branch than HandleNonExistentGetRoute + * because POST requests can handle the request body + */ +TEST(HttpServerTest, HandleNonExistentPostRoute) { + HttpServer server {}; + server.start_listening(HttpServerTest::port); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(HttpServerTest::port); + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + + std::string request = "POST /foo HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 0\r\n" + "\r\n"; + + int listener_fd = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_EQ(connect(listener_fd, reinterpret_cast(&addr), sizeof(addr)), 0); + send(listener_fd, request.c_str(), request.size(), 0); + + char buffer[1024] {}; + int bytes = recv(listener_fd, buffer, sizeof(buffer), 0); + std::string result = std::string(buffer); + + EXPECT_GT(bytes, 0); + ASSERT_TRUE(result.find("404 Not Found") != std::string::npos); + ASSERT_TRUE(close(listener_fd) != -1); +} + +TEST(HttpServerTest, HandleNonExistentHttpMethod) { + HttpServer server {}; + server.start_listening(HttpServerTest::port); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(HttpServerTest::port); + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + + std::string request = "FOO /foo HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 0\r\n" + "\r\n"; + + int listener_fd = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_EQ(connect(listener_fd, reinterpret_cast(&addr), sizeof(addr)), 0); + send(listener_fd, request.c_str(), request.size(), 0); + + char buffer[1024] {}; + int bytes = recv(listener_fd, buffer, sizeof(buffer), 0); + std::string result = std::string(buffer); + + EXPECT_GT(bytes, 0); + ASSERT_TRUE(result.find("500 Error") != std::string::npos); + ASSERT_TRUE(close(listener_fd) != -1); +} + +TEST(HttpServerTest, ListenThrowsIfSocketInvalid) { + HttpServer server {}; + EXPECT_THROW(server.listen(-1), std::runtime_error); +}