Merge pull request #10278

7d1f779 p2p: connection patches (j-berman)
This commit is contained in:
tobtoht
2026-02-03 19:51:50 +00:00
3 changed files with 173 additions and 62 deletions

View File

@@ -128,6 +128,7 @@ namespace net_utils
void start_handshake();
void start_read();
void finish_read(size_t bytes_transferred);
void start_write();
void start_shutdown();
void cancel_socket();
@@ -139,6 +140,7 @@ namespace net_utils
void terminate();
void on_terminating();
void terminate_async();
bool send(epee::byte_slice message);
bool start_internal(
@@ -192,6 +194,7 @@ namespace net_utils
bool wait_read;
bool handle_read;
bool cancel_read;
bool shutdown_read;
bool wait_write;
bool handle_write;

View File

@@ -173,7 +173,7 @@ namespace net_utils
return;
m_state.timers.general.wait_expire = true;
auto self = connection<T>::shared_from_this();
m_timers.general.async_wait([this, self](const ec_t & ec){
auto on_wait = [this, self] {
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.general.wait_expire = false;
if (m_state.timers.general.cancel_expire) {
@@ -191,6 +191,9 @@ namespace net_utils
interrupt();
else if (m_state.status == status_t::INTERRUPTED)
terminate();
};
m_timers.general.async_wait([this, self, on_wait](const ec_t & ec){
boost::asio::post(m_strand, on_wait);
});
}
@@ -244,27 +247,7 @@ namespace net_utils
)
) {
m_state.ssl.enabled = false;
m_state.socket.handle_read = true;
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success)
interrupt();
else {
start_read();
}
}
);
finish_read(bytes_transferred);
}
else {
m_state.ssl.detected = true;
@@ -324,7 +307,7 @@ namespace net_utils
void connection<T>::start_read()
{
if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read ||
m_state.socket.handle_read
m_state.socket.handle_read || m_state.socket.shutdown_read
) {
return;
}
@@ -348,7 +331,7 @@ namespace net_utils
if (duration > duration_t{}) {
m_timers.throttle.in.expires_after(duration);
m_state.timers.throttle.in.wait_expire = true;
m_timers.throttle.in.async_wait([this, self](const ec_t &ec){
auto on_wait = [this, self](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.throttle.in.wait_expire = false;
if (m_state.timers.throttle.in.cancel_expire) {
@@ -357,8 +340,16 @@ namespace net_utils
}
else if (ec.value())
interrupt();
else
};
m_timers.throttle.in.async_wait([this, self, on_wait](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.timers.throttle.in.cancel_expire || ec.value();
if (error_status)
boost::asio::post(m_strand, std::bind(on_wait, ec));
else {
m_state.timers.throttle.in.wait_expire = false;
start_read();
}
});
return;
}
@@ -394,33 +385,7 @@ namespace net_utils
m_conn_context.m_recv_cnt += bytes_transferred;
start_timer(get_timeout_from_bytes_read(bytes_transferred), true);
}
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
// which is listening for reads/writes. This avoids a circular dep.
// handle_recv can queue many writes, and `m_strand` will process those
// writes until the connection terminates without deadlocking waiting
// for handle_recv.
m_state.socket.handle_read = true;
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success)
interrupt();
else {
start_read();
}
}
);
finish_read(bytes_transferred);
}
};
if (!m_state.ssl.enabled)
@@ -446,6 +411,62 @@ namespace net_utils
);
}
template<typename T>
/// @brief Completion path shared by the SSL-detection and normal read
/// handlers: hands `bytes_transferred` bytes of the read buffer to the
/// protocol handler, then either re-arms the next read or begins
/// graceful shutdown/termination.
/// Caller must hold `m_state.lock`; `handle_read` is set before posting
/// so start_read()/shutdown logic can see a recv is in flight.
void connection<T>::finish_read(size_t bytes_transferred)
{
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
// which is listening for reads/writes. This avoids a circular dep.
// handle_recv can queue many writes, and `m_strand` will process those
// writes until the connection terminates without deadlocking waiting
// for handle_recv.
m_state.socket.handle_read = true;
// Keep the connection alive until both posted handlers have run.
auto self = connection<T>::shared_from_this();
boost::asio::post(
connection_basic::strand_,
[this, self, bytes_transferred]{
// handle_recv runs WITHOUT m_state.lock held (it may block queuing writes).
bool success = m_handler.handle_recv(
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.status == status_t::INTERRUPTED
|| m_state.status == status_t::TERMINATING
|| !success;
if (!error_status) {
// Fast path: still RUNNING and handler succeeded — clear the
// in-flight flag and queue the next read directly.
m_state.socket.handle_read = false;
start_read();
return;
}
// Error path: hop to `m_strand` so interrupt/terminate is ordered
// after any writes the handler just queued there.
boost::asio::post(
m_strand,
[this, self, success]{
// expect error_status == true
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.socket.handle_read = false;
if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
else if (m_state.status == status_t::TERMINATING)
on_terminating();
else if (!success) {
ec_t ec;
if (m_state.socket.wait_write) {
// Allow the already queued writes time to finish, but no more new reads
// (half-close: shut down only the receive side; the write
// completion handler will interrupt once the queue drains).
connection_basic::socket_.next_layer().shutdown(
socket_t::shutdown_receive,
ec
);
m_state.socket.shutdown_read = true;
}
if (!m_state.socket.wait_write || ec.value()) {
// No writes pending (or shutdown failed) — tear down now.
interrupt();
}
}
}
);
}
);
}
template<typename T>
void connection<T>::start_write()
{
@@ -477,7 +498,7 @@ namespace net_utils
if (duration > duration_t{}) {
m_timers.throttle.out.expires_after(duration);
m_state.timers.throttle.out.wait_expire = true;
m_timers.throttle.out.async_wait([this, self](const ec_t &ec){
auto on_wait = [this, self](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
m_state.timers.throttle.out.wait_expire = false;
if (m_state.timers.throttle.out.cancel_expire) {
@@ -486,8 +507,16 @@ namespace net_utils
}
else if (ec.value())
interrupt();
else
};
m_timers.throttle.out.async_wait([this, self, on_wait](const ec_t &ec){
std::lock_guard<std::mutex> guard(m_state.lock);
const bool error_status = m_state.timers.throttle.out.cancel_expire || ec.value();
if (error_status)
boost::asio::post(m_strand, std::bind(on_wait, ec));
else {
m_state.timers.throttle.out.wait_expire = false;
start_write();
}
});
}
}
@@ -535,7 +564,12 @@ namespace net_utils
m_state.data.write.total_bytes -=
std::min(m_state.data.write.total_bytes, byte_count);
m_state.condition.notify_all();
start_write();
if (m_state.data.write.queue.empty() && m_state.socket.shutdown_read) {
// All writes have been sent and reads shutdown already, connection can be closed
interrupt();
} else {
start_write();
}
}
};
if (!m_state.ssl.enabled)
@@ -764,6 +798,17 @@ namespace net_utils
m_state.status = status_t::WASTED;
}
template<typename T>
/// @brief Terminate the connection asynchronously: hop onto `m_strand`
/// so termination is ordered after any writes already queued there.
/// The captured shared_ptr keeps `*this` alive until the handler runs.
void connection<T>::terminate_async()
{
// synchronize with intermediate writes on `m_strand`
auto keep_alive = connection<T>::shared_from_this();
boost::asio::post(
m_strand,
[this, keep_alive]{
std::lock_guard<std::mutex> lock(m_state.lock);
terminate();
}
);
}
template<typename T>
bool connection<T>::send(epee::byte_slice message)
{
@@ -816,12 +861,7 @@ namespace net_utils
);
m_state.data.write.wait_consume = false;
if (!success) {
// synchronize with intermediate writes on `m_strand`
auto self = connection<T>::shared_from_this();
boost::asio::post(m_strand, [this, self] {
std::lock_guard<std::mutex> guard(m_state.lock);
terminate();
});
terminate_async();
return false;
}
else
@@ -1095,7 +1135,7 @@ namespace net_utils
std::lock_guard<std::mutex> guard(m_state.lock);
if (m_state.status != status_t::RUNNING)
return false;
terminate();
terminate_async();
return true;
}

View File

@@ -202,3 +202,71 @@ TEST(http_server, private_ip_limit)
failed |= bool(error);
EXPECT_TRUE(failed);
}
// Regression test for the half-close path: when the client sends
// `Connection: close`, the server must finish sending the full response
// before closing, and must then actually close the connection.
TEST(http_server, read_then_close)
{
namespace http = boost::beast::http;
http_server server{};
server.dummy_size = 200000;
server.init(nullptr, "8080");
server.run(2, false); // need at least 2 threads to trigger issues
bool failed_read = false;
bool closed_all_connections = true;
// Repeat many times to catch the race between write-drain and close.
for (std::size_t j = 0; j < 1000; ++j)
{
boost::system::error_code error{};
boost::asio::io_context context{};
boost::asio::ip::tcp::socket stream{context};
stream.connect(
boost::asio::ip::tcp::endpoint{
boost::asio::ip::make_address("127.0.0.1"), 8080
},
error
);
EXPECT_FALSE(bool(error));
http::request<http::string_body> req{http::verb::get, "/dummy", 11};
req.set(http::field::host, "127.0.0.1");
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
req.set(http::field::connection, "close"); // tell server to close connection after sending all data to the client
req.body() = make_payload();
req.prepare_payload();
dummy::response payload{};
boost::beast::flat_buffer buffer;
http::response_parser<http::basic_string_body<char>> parser;
// Response body is dummy_size bytes; allow a little slack for headers.
parser.body_limit(server.dummy_size + 1024);
http::write(stream, req, error);
EXPECT_FALSE(bool(error));
http::read(stream, buffer, parser, error);
// If the read fails, continue the loop still just to make sure the server can handle it
failed_read |= bool(error);
if (failed_read)
continue;
failed_read |= !(parser.is_done());
if (failed_read)
continue;
// Verify status 200, a deserializable body, and the expected
// payload content (dummy_size 'f' characters).
const auto res = parser.release();
failed_read |= res.result_int() != 200u
|| !(epee::serialization::load_t_from_binary(payload, res.body()))
|| (server.dummy_size != std::count(payload.payload.begin(), payload.payload.end(), 'f'));
// See if the server closes the connection after handling the resp
char buf[1];
stream.read_some(boost::asio::buffer(buf), error);
// EOF here means the server closed cleanly after the response.
closed_all_connections &= error == boost::asio::error::eof;
}
// The client should have been able to read all data sent by the server across all requests
EXPECT_FALSE(failed_read);
// The server should have closed all connections
EXPECT_TRUE(closed_all_connections);
server.send_stop_signal();
}