p2p: connection patches

- Make sure the server sends a complete response when the client
includes the "Connection: close" header.
- Make sure the server terminates in `m_strand` to avoid
concurrent socket closure and ops processing.
This commit is contained in:
j-berman
2026-01-08 16:10:36 -08:00
parent 4ce39e0c14
commit ee9e4a49ba
3 changed files with 173 additions and 62 deletions
@@ -128,6 +128,7 @@ namespace net_utils
void start_handshake();
void start_read();
void finish_read(size_t bytes_transferred);
void start_write();
void start_shutdown();
void cancel_socket();
@@ -139,6 +140,7 @@ namespace net_utils
void terminate();
void on_terminating();
void terminate_async();
bool send(epee::byte_slice message);
bool start_internal(
@@ -192,6 +194,7 @@ namespace net_utils
bool wait_read;
bool handle_read;
bool cancel_read;
bool shutdown_read;
bool wait_write;
bool handle_write;
+102 -62
View File
@@ -171,7 +171,7 @@ namespace net_utils
return;
m_state.timers.general.wait_expire = true;
auto self = connection<T>::shared_from_this();
m_timers.general.async_wait([this, self](const ec_t & ec){
auto on_wait = [this, self] {
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.general.wait_expire = false;
if (m_state.timers.general.cancel_expire) {
@@ -189,6 +189,9 @@ namespace net_utils
interrupt();
else if (m_state.status == status_t::INTERRUPTED)
terminate();
};
m_timers.general.async_wait([this, self, on_wait](const ec_t & ec){
boost::asio::post(m_strand, on_wait);
});
}
@@ -242,27 +245,7 @@ namespace net_utils
)
) {
m_state.ssl.enabled = false;
m_state.socket.handle_read = true;
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success)
interrupt();
else {
start_read();
}
}
);
finish_read(bytes_transferred);
}
else {
m_state.ssl.detected = true;
@@ -322,7 +305,7 @@ namespace net_utils
void connection<T>::start_read()
{
if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read ||
m_state.socket.handle_read
m_state.socket.handle_read || m_state.socket.shutdown_read
) {
return;
}
@@ -346,7 +329,7 @@ namespace net_utils
if (duration > duration_t{}) {
m_timers.throttle.in.expires_after(duration);
m_state.timers.throttle.in.wait_expire = true;
m_timers.throttle.in.async_wait([this, self](const ec_t &ec){
auto on_wait = [this, self](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.throttle.in.wait_expire = false;
if (m_state.timers.throttle.in.cancel_expire) {
@@ -355,8 +338,16 @@ namespace net_utils
}
else if (ec.value())
interrupt();
else
};
m_timers.throttle.in.async_wait([this, self, on_wait](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.timers.throttle.in.cancel_expire || ec.value();
if (error_status)
boost::asio::post(m_strand, std::bind(on_wait, ec));
else {
m_state.timers.throttle.in.wait_expire = false;
start_read();
}
});
return;
}
@@ -392,33 +383,7 @@ namespace net_utils
m_conn_context.m_recv_cnt += bytes_transferred;
start_timer(get_timeout_from_bytes_read(bytes_transferred), true);
}
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
// which is listening for reads/writes. This avoids a circular dep.
// handle_recv can queue many writes, and `m_strand` will process those
// writes until the connection terminates without deadlocking waiting
// for handle_recv.
m_state.socket.handle_read = true;
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success)
interrupt();
else {
start_read();
}
}
);
finish_read(bytes_transferred);
}
};
if (!m_state.ssl.enabled)
@@ -444,6 +409,62 @@ namespace net_utils
);
}
template<typename T>
void connection<T>::finish_read(size_t bytes_transferred)
{
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
// which is listening for reads/writes. This avoids a circular dep.
// handle_recv can queue many writes, and `m_strand` will process those
// writes until the connection terminates without deadlocking waiting
// for handle_recv.
m_state.socket.handle_read = true;
auto self = connection<T>::shared_from_this();
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.status == status_t::INTERRUPTED
|| m_state.status == status_t::TERMINATING
|| !success;
if (!error_status) {
m_state.socket.handle_read = false;
start_read();
return;
}
boost::asio::post(
m_strand,
[this, self, success]{
// expect error_status == true
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success) {
ec_t ec;
if (m_state.socket.wait_write) {
// Allow the already queued writes time to finish, but no more new reads
connection_basic::socket_.next_layer().shutdown(
socket_t::shutdown_receive,
ec
);
m_state.socket.shutdown_read = true;
}
if (!m_state.socket.wait_write || ec.value()) {
interrupt();
}
}
}
);
}
);
}
template<typename T>
void connection<T>::start_write()
{
@@ -475,7 +496,7 @@ namespace net_utils
if (duration > duration_t{}) {
m_timers.throttle.out.expires_after(duration);
m_state.timers.throttle.out.wait_expire = true;
m_timers.throttle.out.async_wait([this, self](const ec_t &ec){
auto on_wait = [this, self](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.throttle.out.wait_expire = false;
if (m_state.timers.throttle.out.cancel_expire) {
@@ -484,8 +505,16 @@ namespace net_utils
}
else if (ec.value())
interrupt();
else
};
m_timers.throttle.out.async_wait([this, self, on_wait](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.timers.throttle.out.cancel_expire || ec.value();
if (error_status)
boost::asio::post(m_strand, std::bind(on_wait, ec));
else {
m_state.timers.throttle.out.wait_expire = false;
start_write();
}
});
}
}
@@ -533,7 +562,12 @@ namespace net_utils
m_state.data.write.total_bytes -=
std::min(m_state.data.write.total_bytes, byte_count);
m_state.condition.notify_all();
start_write();
if (m_state.data.write.queue.empty() && m_state.socket.shutdown_read) {
// All writes have been sent and reads shutdown already, connection can be closed
interrupt();
} else {
start_write();
}
}
};
if (!m_state.ssl.enabled)
@@ -762,6 +796,17 @@ namespace net_utils
m_state.status = status_t::WASTED;
}
template<typename T>
void connection<T>::terminate_async()
{
// synchronize with intermediate writes on `m_strand`
auto self = connection<T>::shared_from_this();
boost::asio::post(m_strand, [this, self] {
std::lock_guard<std::mutex> guard(m_state.lock);
terminate();
});
}
template<typename T>
bool connection<T>::send(epee::byte_slice message)
{
@@ -814,12 +859,7 @@ namespace net_utils
);
m_state.data.write.wait_consume = false;
if (!success) {
// synchronize with intermediate writes on `m_strand`
auto self = connection<T>::shared_from_this();
boost::asio::post(m_strand, [this, self] {
std::lock_guard<std::mutex> guard(m_state.lock);
terminate();
});
terminate_async();
return false;
}
else
@@ -1093,7 +1133,7 @@ namespace net_utils
std::lock_guard<std::mutex> guard(m_state.lock);
if (m_state.status != status_t::RUNNING)
return false;
terminate();
terminate_async();
return true;
}
+68
View File
@@ -198,3 +198,71 @@ TEST(http_server, private_ip_limit)
failed |= bool(error);
EXPECT_TRUE(failed);
}
TEST(http_server, read_then_close)
{
namespace http = boost::beast::http;
http_server server{};
server.dummy_size = 200000;
server.init(nullptr, "8080");
server.run(2, false); // need at least 2 threads to trigger issues
bool failed_read = false;
bool closed_all_connections = true;
for (std::size_t j = 0; j < 1000; ++j)
{
boost::system::error_code error{};
boost::asio::io_context context{};
boost::asio::ip::tcp::socket stream{context};
stream.connect(
boost::asio::ip::tcp::endpoint{
boost::asio::ip::make_address("127.0.0.1"), 8080
},
error
);
EXPECT_FALSE(bool(error));
http::request<http::string_body> req{http::verb::get, "/dummy", 11};
req.set(http::field::host, "127.0.0.1");
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
req.set(http::field::connection, "close"); // tell server to close connection after sending all data to the client
req.body() = make_payload();
req.prepare_payload();
dummy::response payload{};
boost::beast::flat_buffer buffer;
http::response_parser<http::basic_string_body<char>> parser;
parser.body_limit(server.dummy_size + 1024);
http::write(stream, req, error);
EXPECT_FALSE(bool(error));
http::read(stream, buffer, parser, error);
// If the read fails, continue the loop still just to make sure the server can handle it
failed_read |= bool(error);
if (failed_read)
continue;
failed_read |= !(parser.is_done());
if (failed_read)
continue;
const auto res = parser.release();
failed_read |= res.result_int() != 200u
|| !(epee::serialization::load_t_from_binary(payload, res.body()))
|| (server.dummy_size != std::count(payload.payload.begin(), payload.payload.end(), 'f'));
// See if the server closes the connection after handling the resp
char buf[1];
stream.read_some(boost::asio::buffer(buf), error);
closed_all_connections &= error == boost::asio::error::eof;
}
// The client should have been able to read all data sent by the server across all requests
EXPECT_FALSE(failed_read);
// The server should have closed all connections
EXPECT_TRUE(closed_all_connections);
server.send_stop_signal();
}