mirror of
https://github.com/monero-project/monero.git
synced 2026-04-28 11:53:17 -07:00
@@ -128,6 +128,7 @@ namespace net_utils
|
||||
|
||||
void start_handshake();
|
||||
void start_read();
|
||||
void finish_read(size_t bytes_transferred);
|
||||
void start_write();
|
||||
void start_shutdown();
|
||||
void cancel_socket();
|
||||
@@ -139,6 +140,7 @@ namespace net_utils
|
||||
|
||||
void terminate();
|
||||
void on_terminating();
|
||||
void terminate_async();
|
||||
|
||||
bool send(epee::byte_slice message);
|
||||
bool start_internal(
|
||||
@@ -192,6 +194,7 @@ namespace net_utils
|
||||
bool wait_read;
|
||||
bool handle_read;
|
||||
bool cancel_read;
|
||||
bool shutdown_read;
|
||||
|
||||
bool wait_write;
|
||||
bool handle_write;
|
||||
|
||||
@@ -173,7 +173,7 @@ namespace net_utils
|
||||
return;
|
||||
m_state.timers.general.wait_expire = true;
|
||||
auto self = connection<T>::shared_from_this();
|
||||
m_timers.general.async_wait([this, self](const ec_t & ec){
|
||||
auto on_wait = [this, self] {
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.timers.general.wait_expire = false;
|
||||
if (m_state.timers.general.cancel_expire) {
|
||||
@@ -191,6 +191,9 @@ namespace net_utils
|
||||
interrupt();
|
||||
else if (m_state.status == status_t::INTERRUPTED)
|
||||
terminate();
|
||||
};
|
||||
m_timers.general.async_wait([this, self, on_wait](const ec_t & ec){
|
||||
boost::asio::post(m_strand, on_wait);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -244,27 +247,7 @@ namespace net_utils
|
||||
)
|
||||
) {
|
||||
m_state.ssl.enabled = false;
|
||||
m_state.socket.handle_read = true;
|
||||
boost::asio::post(
|
||||
connection_basic::strand_,
|
||||
[this, self, bytes_transferred]{
|
||||
bool success = m_handler.handle_recv(
|
||||
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
|
||||
bytes_transferred
|
||||
);
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.socket.handle_read = false;
|
||||
if (m_state.status == status_t::INTERRUPTED)
|
||||
on_interrupted();
|
||||
else if (m_state.status == status_t::TERMINATING)
|
||||
on_terminating();
|
||||
else if (!success)
|
||||
interrupt();
|
||||
else {
|
||||
start_read();
|
||||
}
|
||||
}
|
||||
);
|
||||
finish_read(bytes_transferred);
|
||||
}
|
||||
else {
|
||||
m_state.ssl.detected = true;
|
||||
@@ -324,7 +307,7 @@ namespace net_utils
|
||||
void connection<T>::start_read()
|
||||
{
|
||||
if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read ||
|
||||
m_state.socket.handle_read
|
||||
m_state.socket.handle_read || m_state.socket.shutdown_read
|
||||
) {
|
||||
return;
|
||||
}
|
||||
@@ -348,7 +331,7 @@ namespace net_utils
|
||||
if (duration > duration_t{}) {
|
||||
m_timers.throttle.in.expires_after(duration);
|
||||
m_state.timers.throttle.in.wait_expire = true;
|
||||
m_timers.throttle.in.async_wait([this, self](const ec_t &ec){
|
||||
auto on_wait = [this, self](const ec_t &ec){
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.timers.throttle.in.wait_expire = false;
|
||||
if (m_state.timers.throttle.in.cancel_expire) {
|
||||
@@ -357,8 +340,16 @@ namespace net_utils
|
||||
}
|
||||
else if (ec.value())
|
||||
interrupt();
|
||||
else
|
||||
};
|
||||
m_timers.throttle.in.async_wait([this, self, on_wait](const ec_t &ec){
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
const bool error_status = m_state.timers.throttle.in.cancel_expire || ec.value();
|
||||
if (error_status)
|
||||
boost::asio::post(m_strand, std::bind(on_wait, ec));
|
||||
else {
|
||||
m_state.timers.throttle.in.wait_expire = false;
|
||||
start_read();
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
@@ -394,33 +385,7 @@ namespace net_utils
|
||||
m_conn_context.m_recv_cnt += bytes_transferred;
|
||||
start_timer(get_timeout_from_bytes_read(bytes_transferred), true);
|
||||
}
|
||||
|
||||
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
|
||||
// which is listening for reads/writes. This avoids a circular dep.
|
||||
// handle_recv can queue many writes, and `m_strand` will process those
|
||||
// writes until the connection terminates without deadlocking waiting
|
||||
// for handle_recv.
|
||||
m_state.socket.handle_read = true;
|
||||
boost::asio::post(
|
||||
connection_basic::strand_,
|
||||
[this, self, bytes_transferred]{
|
||||
bool success = m_handler.handle_recv(
|
||||
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
|
||||
bytes_transferred
|
||||
);
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.socket.handle_read = false;
|
||||
if (m_state.status == status_t::INTERRUPTED)
|
||||
on_interrupted();
|
||||
else if (m_state.status == status_t::TERMINATING)
|
||||
on_terminating();
|
||||
else if (!success)
|
||||
interrupt();
|
||||
else {
|
||||
start_read();
|
||||
}
|
||||
}
|
||||
);
|
||||
finish_read(bytes_transferred);
|
||||
}
|
||||
};
|
||||
if (!m_state.ssl.enabled)
|
||||
@@ -446,6 +411,62 @@ namespace net_utils
|
||||
);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void connection<T>::finish_read(size_t bytes_transferred)
|
||||
{
|
||||
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
|
||||
// which is listening for reads/writes. This avoids a circular dep.
|
||||
// handle_recv can queue many writes, and `m_strand` will process those
|
||||
// writes until the connection terminates without deadlocking waiting
|
||||
// for handle_recv.
|
||||
m_state.socket.handle_read = true;
|
||||
auto self = connection<T>::shared_from_this();
|
||||
boost::asio::post(
|
||||
connection_basic::strand_,
|
||||
[this, self, bytes_transferred]{
|
||||
bool success = m_handler.handle_recv(
|
||||
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
|
||||
bytes_transferred
|
||||
);
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
const bool error_status = m_state.status == status_t::INTERRUPTED
|
||||
|| m_state.status == status_t::TERMINATING
|
||||
|| !success;
|
||||
if (!error_status) {
|
||||
m_state.socket.handle_read = false;
|
||||
start_read();
|
||||
return;
|
||||
}
|
||||
boost::asio::post(
|
||||
m_strand,
|
||||
[this, self, success]{
|
||||
// expect error_status == true
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.socket.handle_read = false;
|
||||
if (m_state.status == status_t::INTERRUPTED)
|
||||
on_interrupted();
|
||||
else if (m_state.status == status_t::TERMINATING)
|
||||
on_terminating();
|
||||
else if (!success) {
|
||||
ec_t ec;
|
||||
if (m_state.socket.wait_write) {
|
||||
// Allow the already queued writes time to finish, but no more new reads
|
||||
connection_basic::socket_.next_layer().shutdown(
|
||||
socket_t::shutdown_receive,
|
||||
ec
|
||||
);
|
||||
m_state.socket.shutdown_read = true;
|
||||
}
|
||||
if (!m_state.socket.wait_write || ec.value()) {
|
||||
interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void connection<T>::start_write()
|
||||
{
|
||||
@@ -477,7 +498,7 @@ namespace net_utils
|
||||
if (duration > duration_t{}) {
|
||||
m_timers.throttle.out.expires_after(duration);
|
||||
m_state.timers.throttle.out.wait_expire = true;
|
||||
m_timers.throttle.out.async_wait([this, self](const ec_t &ec){
|
||||
auto on_wait = [this, self](const ec_t &ec){
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
m_state.timers.throttle.out.wait_expire = false;
|
||||
if (m_state.timers.throttle.out.cancel_expire) {
|
||||
@@ -486,8 +507,16 @@ namespace net_utils
|
||||
}
|
||||
else if (ec.value())
|
||||
interrupt();
|
||||
else
|
||||
};
|
||||
m_timers.throttle.out.async_wait([this, self, on_wait](const ec_t &ec){
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
const bool error_status = m_state.timers.throttle.out.cancel_expire || ec.value();
|
||||
if (error_status)
|
||||
boost::asio::post(m_strand, std::bind(on_wait, ec));
|
||||
else {
|
||||
m_state.timers.throttle.out.wait_expire = false;
|
||||
start_write();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -535,7 +564,12 @@ namespace net_utils
|
||||
m_state.data.write.total_bytes -=
|
||||
std::min(m_state.data.write.total_bytes, byte_count);
|
||||
m_state.condition.notify_all();
|
||||
start_write();
|
||||
if (m_state.data.write.queue.empty() && m_state.socket.shutdown_read) {
|
||||
// All writes have been sent and reads shutdown already, connection can be closed
|
||||
interrupt();
|
||||
} else {
|
||||
start_write();
|
||||
}
|
||||
}
|
||||
};
|
||||
if (!m_state.ssl.enabled)
|
||||
@@ -764,6 +798,17 @@ namespace net_utils
|
||||
m_state.status = status_t::WASTED;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void connection<T>::terminate_async()
|
||||
{
|
||||
// synchronize with intermediate writes on `m_strand`
|
||||
auto self = connection<T>::shared_from_this();
|
||||
boost::asio::post(m_strand, [this, self] {
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
terminate();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool connection<T>::send(epee::byte_slice message)
|
||||
{
|
||||
@@ -816,12 +861,7 @@ namespace net_utils
|
||||
);
|
||||
m_state.data.write.wait_consume = false;
|
||||
if (!success) {
|
||||
// synchronize with intermediate writes on `m_strand`
|
||||
auto self = connection<T>::shared_from_this();
|
||||
boost::asio::post(m_strand, [this, self] {
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
terminate();
|
||||
});
|
||||
terminate_async();
|
||||
return false;
|
||||
}
|
||||
else
|
||||
@@ -1095,7 +1135,7 @@ namespace net_utils
|
||||
std::lock_guard<std::mutex> guard(m_state.lock);
|
||||
if (m_state.status != status_t::RUNNING)
|
||||
return false;
|
||||
terminate();
|
||||
terminate_async();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -202,3 +202,71 @@ TEST(http_server, private_ip_limit)
|
||||
failed |= bool(error);
|
||||
EXPECT_TRUE(failed);
|
||||
}
|
||||
|
||||
TEST(http_server, read_then_close)
|
||||
{
|
||||
namespace http = boost::beast::http;
|
||||
|
||||
http_server server{};
|
||||
server.dummy_size = 200000;
|
||||
server.init(nullptr, "8080");
|
||||
server.run(2, false); // need at least 2 threads to trigger issues
|
||||
|
||||
bool failed_read = false;
|
||||
bool closed_all_connections = true;
|
||||
for (std::size_t j = 0; j < 1000; ++j)
|
||||
{
|
||||
boost::system::error_code error{};
|
||||
boost::asio::io_context context{};
|
||||
boost::asio::ip::tcp::socket stream{context};
|
||||
stream.connect(
|
||||
boost::asio::ip::tcp::endpoint{
|
||||
boost::asio::ip::make_address("127.0.0.1"), 8080
|
||||
},
|
||||
error
|
||||
);
|
||||
EXPECT_FALSE(bool(error));
|
||||
|
||||
http::request<http::string_body> req{http::verb::get, "/dummy", 11};
|
||||
req.set(http::field::host, "127.0.0.1");
|
||||
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
|
||||
req.set(http::field::connection, "close"); // tell server to close connection after sending all data to the client
|
||||
req.body() = make_payload();
|
||||
req.prepare_payload();
|
||||
|
||||
dummy::response payload{};
|
||||
boost::beast::flat_buffer buffer;
|
||||
http::response_parser<http::basic_string_body<char>> parser;
|
||||
parser.body_limit(server.dummy_size + 1024);
|
||||
|
||||
http::write(stream, req, error);
|
||||
EXPECT_FALSE(bool(error));
|
||||
|
||||
http::read(stream, buffer, parser, error);
|
||||
|
||||
// If the read fails, continue the loop still just to make sure the server can handle it
|
||||
failed_read |= bool(error);
|
||||
if (failed_read)
|
||||
continue;
|
||||
failed_read |= !(parser.is_done());
|
||||
if (failed_read)
|
||||
continue;
|
||||
const auto res = parser.release();
|
||||
failed_read |= res.result_int() != 200u
|
||||
|| !(epee::serialization::load_t_from_binary(payload, res.body()))
|
||||
|| (server.dummy_size != std::count(payload.payload.begin(), payload.payload.end(), 'f'));
|
||||
|
||||
// See if the server closes the connection after handling the resp
|
||||
char buf[1];
|
||||
stream.read_some(boost::asio::buffer(buf), error);
|
||||
closed_all_connections &= error == boost::asio::error::eof;
|
||||
}
|
||||
|
||||
// The client should have been able to read all data sent by the server across all requests
|
||||
EXPECT_FALSE(failed_read);
|
||||
|
||||
// The server should have closed all connections
|
||||
EXPECT_TRUE(closed_all_connections);
|
||||
|
||||
server.send_stop_signal();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user