diff options
author | George Hazan <ghazan@miranda.im> | 2022-11-30 17:48:47 +0300 |
---|---|---|
committer | George Hazan <ghazan@miranda.im> | 2022-11-30 17:48:47 +0300 |
commit | 0ece30dc7c0e34b4c5911969b8fa99c33c6d023c (patch) | |
tree | 671325d3fec09b999411e4e3ab84ef8259261818 /protocols/Telegram/tdlib/td/tdnet | |
parent | 46c53ffc6809c67e4607e99951a2846c382b63b2 (diff) |
Telegram: update for TDLIB
Diffstat (limited to 'protocols/Telegram/tdlib/td/tdnet')
39 files changed, 2153 insertions, 1074 deletions
diff --git a/protocols/Telegram/tdlib/td/tdnet/CMakeLists.txt b/protocols/Telegram/tdlib/td/tdnet/CMakeLists.txt index 823ed027d6..e14f3500f9 100644 --- a/protocols/Telegram/tdlib/td/tdnet/CMakeLists.txt +++ b/protocols/Telegram/tdlib/td/tdnet/CMakeLists.txt @@ -1,4 +1,10 @@ -cmake_minimum_required(VERSION 3.0.2 FATAL_ERROR) +if ((CMAKE_MAJOR_VERSION LESS 3) OR (CMAKE_VERSION VERSION_LESS "3.0.2")) + message(FATAL_ERROR "CMake >= 3.0.2 is required") +endif() + +if (NOT DEFINED CMAKE_INSTALL_LIBDIR) + set(CMAKE_INSTALL_LIBDIR "lib") +endif() if (NOT OPENSSL_FOUND) find_package(OpenSSL REQUIRED) @@ -14,11 +20,14 @@ set(TDNET_SOURCE td/net/HttpFile.cpp td/net/HttpInboundConnection.cpp td/net/HttpOutboundConnection.cpp + td/net/HttpProxy.cpp td/net/HttpQuery.cpp td/net/HttpReader.cpp td/net/Socks5.cpp - td/net/SslFd.cpp + td/net/SslCtx.cpp + td/net/SslStream.cpp td/net/TcpListener.cpp + td/net/TransparentProxy.cpp td/net/Wget.cpp td/net/GetHostByNameActor.h @@ -29,26 +38,53 @@ set(TDNET_SOURCE td/net/HttpHeaderCreator.h td/net/HttpInboundConnection.h td/net/HttpOutboundConnection.h + td/net/HttpProxy.h td/net/HttpQuery.h td/net/HttpReader.h td/net/NetStats.h td/net/Socks5.h - td/net/SslFd.h + td/net/SslCtx.h + td/net/SslStream.h td/net/TcpListener.h + td/net/TransparentProxy.h td/net/Wget.h ) +if (APPLE_WATCH) + set(TDNET_SOURCE + ${TDNET_SOURCE} + td/net/DarwinHttp.mm + td/net/DarwinHttp.h + ) + set_source_files_properties(td/net/DarwinHttp.mm PROPERTIES COMPILE_FLAGS -fobjc-arc) +endif() + #RULES #LIBRARIES add_library(tdnet STATIC ${TDNET_SOURCE}) target_include_directories(tdnet PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>) -target_include_directories(tdnet SYSTEM PUBLIC $<BUILD_INTERFACE:${OPENSSL_INCLUDE_DIR}>) -target_link_libraries(tdnet PUBLIC tdutils tdactor ${OPENSSL_LIBRARIES} PRIVATE ${CMAKE_DL_LIBS} ${ZLIB_LIBRARIES}) +target_include_directories(tdnet SYSTEM PRIVATE $<BUILD_INTERFACE:${OPENSSL_INCLUDE_DIR}>) +target_link_libraries(tdnet PUBLIC tdutils tdactor) +if (NOT EMSCRIPTEN) + target_link_libraries(tdnet PRIVATE ${OPENSSL_SSL_LIBRARY}) +endif() +target_link_libraries(tdnet PRIVATE ${OPENSSL_CRYPTO_LIBRARY} ${CMAKE_DL_LIBS} ${ZLIB_LIBRARIES}) + +if (WIN32) + if (MINGW) + target_link_libraries(tdnet PRIVATE ws2_32 mswsock crypt32) + else() + target_link_libraries(tdnet PRIVATE ws2_32 Mswsock Crypt32) + endif() +endif() + +if (APPLE_WATCH) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + target_link_libraries(tdnet PRIVATE ${FOUNDATION_LIBRARY}) +endif() install(TARGETS tdnet EXPORT TdTargets - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib - RUNTIME DESTINATION bin - INCLUDES DESTINATION include + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" ) diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.h b/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.h new file mode 100644 index 0000000000..42c0a5b000 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.h @@ -0,0 +1,21 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/utils/buffer.h" +#include "td/utils/Promise.h" +#include "td/utils/Slice.h" + +namespace td { + +class DarwinHttp { + public: + static void get(CSlice url, Promise<BufferSlice> promise); + static void post(CSlice url, Slice data, Promise<BufferSlice> promise); +}; + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.mm b/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.mm new file mode 100644 index 0000000000..f8e285433f --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/DarwinHttp.mm @@ -0,0 +1,68 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/net/DarwinHttp.h" + +#include "td/utils/logging.h" +#include "td/utils/SliceBuilder.h" + +#import <Foundation/Foundation.h> + +namespace td { + +namespace { +NSString *to_ns_string(CSlice slice) { + return [NSString stringWithUTF8String:slice.c_str()]; +} + +NSData *to_ns_data(Slice data) { + return [NSData dataWithBytes:static_cast<const void *>(data.data()) length:data.size()]; +} + +auto http_get(CSlice url) { + auto nsurl = [NSURL URLWithString:to_ns_string(url)]; + auto request = [NSURLRequest requestWithURL:nsurl]; + return request; +} + +auto http_post(CSlice url, Slice data) { + auto nsurl = [NSURL URLWithString:to_ns_string(url)]; + auto request = [NSMutableURLRequest requestWithURL:nsurl]; + [request setHTTPMethod:@"POST"]; + [request setHTTPBody:to_ns_data(data)]; + [request setValue:@"keep-alive" forHTTPHeaderField:@"Connection"]; + [request setValue:@"" forHTTPHeaderField:@"Host"]; + [request setValue:to_ns_string(PSLICE() << data.size()) forHTTPHeaderField:@"Content-Length"]; + [request setValue:@"application/x-www-form-urlencoded" forHTTPHeaderField:@"Content-Type"]; + return request; +} + +void http_send(NSURLRequest *request, Promise<BufferSlice> promise) { + __block auto callback = std::move(promise); + NSURLSessionDataTask *dataTask = + [NSURLSession.sharedSession + dataTaskWithRequest:request + completionHandler: + ^(NSData *data, NSURLResponse *response, NSError *error) { + if (error == nil) { + callback.set_value(BufferSlice(Slice((const char *)([data bytes]), [data length]))); + } else { + callback.set_error(Status::Error(static_cast<int32>([error code]), "HTTP request failed")); + } + }]; + [dataTask resume]; +} +} // namespace + +void DarwinHttp::get(CSlice url, Promise<BufferSlice> promise) { + return http_send(http_get(url), std::move(promise)); +} + +void DarwinHttp::post(CSlice url, Slice data, Promise<BufferSlice> promise) { + return http_send(http_post(url, data), std::move(promise)); +} + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.cpp index b6cdcca0f0..8a720083c0 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.cpp @@ -1,48 +1,216 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/GetHostByNameActor.h" +#include "td/net/HttpQuery.h" +#include "td/net/SslCtx.h" +#include "td/net/Wget.h" + +#include "td/utils/common.h" +#include "td/utils/JsonBuilder.h" #include "td/utils/logging.h" +#include "td/utils/misc.h" +#include "td/utils/Slice.h" +#include "td/utils/SliceBuilder.h" #include "td/utils/Time.h" namespace td { -GetHostByNameActor::GetHostByNameActor(int32 ok_timeout, int32 error_timeout) - : ok_timeout_(ok_timeout), error_timeout_(error_timeout) { +namespace detail { + +class GoogleDnsResolver final : public Actor { + public: + GoogleDnsResolver(std::string host, bool prefer_ipv6, Promise<IPAddress> promise) + : host_(std::move(host)), prefer_ipv6_(prefer_ipv6), promise_(std::move(promise)) { + } + + private: + std::string host_; + bool prefer_ipv6_; + Promise<IPAddress> promise_; + ActorOwn<Wget> wget_; + double begin_time_ = 0; + + void start_up() final { + auto r_address = IPAddress::get_ip_address(host_); + if (r_address.is_ok()) { + promise_.set_value(r_address.move_as_ok()); + return stop(); + } + + const int timeout = 10; + const int ttl = 3; + begin_time_ = Time::now(); + auto wget_promise = PromiseCreator::lambda([actor_id = actor_id(this)](Result<unique_ptr<HttpQuery>> r_http_query) { + send_closure(actor_id, &GoogleDnsResolver::on_result, std::move(r_http_query)); + }); + wget_ = create_actor<Wget>( + "GoogleDnsResolver", std::move(wget_promise), + PSTRING() << "https://dns.google/resolve?name=" << url_encode(host_) << "&type=" << (prefer_ipv6_ ? 28 : 1), + std::vector<std::pair<string, string>>({{"Host", "dns.google"}}), timeout, ttl, prefer_ipv6_, + SslCtx::VerifyPeer::Off); + } + + static Result<IPAddress> get_ip_address(Result<unique_ptr<HttpQuery>> r_http_query) { + TRY_RESULT(http_query, std::move(r_http_query)); + + auto get_ip_address = [](JsonValue &answer) -> Result<IPAddress> { + auto &array = answer.get_array(); + if (array.empty()) { + return Status::Error("Failed to parse DNS result: Answer is an empty array"); + } + if (array[0].type() != JsonValue::Type::Object) { + return Status::Error("Failed to parse DNS result: Answer[0] is not an object"); + } + auto &answer_0 = array[0].get_object(); + TRY_RESULT(ip_str, get_json_object_string_field(answer_0, "data", false)); + IPAddress ip; + TRY_STATUS(ip.init_host_port(ip_str, 0)); + return ip; + }; + if (!http_query->get_arg("Answer").empty()) { + TRY_RESULT(answer, json_decode(http_query->get_arg("Answer"))); + if (answer.type() != JsonValue::Type::Array) { + return Status::Error("Expected JSON array"); + } + return get_ip_address(answer); + } else { + TRY_RESULT(json_value, json_decode(http_query->content_)); + if (json_value.type() != JsonValue::Type::Object) { + return Status::Error("Failed to parse DNS result: not an object"); + } + TRY_RESULT(answer, get_json_object_field(json_value.get_object(), "Answer", JsonValue::Type::Array, false)); + return get_ip_address(answer); + } + } + + void on_result(Result<unique_ptr<HttpQuery>> r_http_query) { + auto end_time = Time::now(); + auto result = get_ip_address(std::move(r_http_query)); + VLOG(dns_resolver) << "Init IPv" << (prefer_ipv6_ ? "6" : "4") << " host = " << host_ << " in " + << end_time - begin_time_ << " seconds to " + << (result.is_ok() ? (PSLICE() << result.ok()) : CSlice("[invalid]")); + promise_.set_result(std::move(result)); + stop(); + } +}; + +class NativeDnsResolver final : public Actor { + public: + NativeDnsResolver(std::string host, bool prefer_ipv6, Promise<IPAddress> promise) + : host_(std::move(host)), prefer_ipv6_(prefer_ipv6), promise_(std::move(promise)) { + } + + private: + std::string host_; + bool prefer_ipv6_; + Promise<IPAddress> promise_; + + void start_up() final { + IPAddress ip; + auto begin_time = Time::now(); + auto status = ip.init_host_port(host_, 0, prefer_ipv6_); + auto end_time = Time::now(); + VLOG(dns_resolver) << "Init host = " << host_ << " in " << end_time - begin_time << " seconds to " << ip; + if (status.is_error()) { + promise_.set_error(std::move(status)); + } else { + promise_.set_value(std::move(ip)); + } + stop(); + } +}; + +} // namespace detail + +int VERBOSITY_NAME(dns_resolver) = VERBOSITY_NAME(DEBUG); + +GetHostByNameActor::GetHostByNameActor(Options options) : options_(std::move(options)) { + CHECK(!options_.resolver_types.empty()); } -void GetHostByNameActor::run(std::string host, int port, td::Promise<td::IPAddress> promise) { - auto r_ip = load_ip(std::move(host), port); - promise.set_result(std::move(r_ip)); +void GetHostByNameActor::run(string host, int port, bool prefer_ipv6, Promise<IPAddress> promise) { + auto r_ascii_host = idn_to_ascii(host); + if (r_ascii_host.is_error()) { + return promise.set_error(r_ascii_host.move_as_error()); + } + auto ascii_host = r_ascii_host.move_as_ok(); + if (ascii_host.empty()) { + return promise.set_error(Status::Error("Host is empty")); + } + + auto begin_time = Time::now(); + auto &value = cache_[prefer_ipv6].emplace(ascii_host, Value{{}, begin_time - 1.0}).first->second; + if (value.expires_at > begin_time) { + return promise.set_result(value.get_ip_port(port)); + } + + auto &query_ptr = active_queries_[prefer_ipv6][ascii_host]; + if (query_ptr == nullptr) { + query_ptr = make_unique<Query>(); + } + auto &query = *query_ptr; + query.promises.emplace_back(port, std::move(promise)); + if (query.query.empty()) { + CHECK(query.promises.size() == 1); + query.real_host = std::move(host); + query.begin_time = Time::now(); + run_query(std::move(ascii_host), prefer_ipv6, query); + } } -Result<td::IPAddress> GetHostByNameActor::load_ip(string host, int port) { - auto &value = cache_.emplace(host, Value{{}, 0}).first->second; - auto begin_time = td::Time::now(); - if (value.expire_at > begin_time) { - auto ip = value.ip.clone(); - if (ip.is_ok()) { - ip.ok_ref().set_port(port); - CHECK(ip.ok().get_port() == port); +void GetHostByNameActor::run_query(std::string host, bool prefer_ipv6, Query &query) { + auto promise = PromiseCreator::lambda([actor_id = actor_id(this), host, prefer_ipv6](Result<IPAddress> res) mutable { + send_closure(actor_id, &GetHostByNameActor::on_query_result, std::move(host), prefer_ipv6, std::move(res)); + }); + + CHECK(query.query.empty()); + CHECK(query.pos < options_.resolver_types.size()); + auto resolver_type = options_.resolver_types[query.pos++]; + query.query = [&] { + switch (resolver_type) { + case ResolverType::Native: + return ActorOwn<>(create_actor_on_scheduler<detail::NativeDnsResolver>( + "NativeDnsResolver", options_.scheduler_id, std::move(host), prefer_ipv6, std::move(promise))); + case ResolverType::Google: + return ActorOwn<>(create_actor_on_scheduler<detail::GoogleDnsResolver>( + "GoogleDnsResolver", options_.scheduler_id, std::move(host), prefer_ipv6, std::move(promise))); + default: + UNREACHABLE(); + return ActorOwn<>(); } - return ip; + }(); +} + +void GetHostByNameActor::on_query_result(std::string host, bool prefer_ipv6, Result<IPAddress> result) { + auto query_it = active_queries_[prefer_ipv6].find(host); + CHECK(query_it != active_queries_[prefer_ipv6].end()); + auto &query = *query_it->second; + CHECK(!query.promises.empty()); + CHECK(!query.query.empty()); + + if (result.is_error() && query.pos < options_.resolver_types.size()) { + query.query.reset(); + return run_query(std::move(host), prefer_ipv6, query); } - td::IPAddress ip; - auto status = ip.init_host_port(host, port); - auto end_time = td::Time::now(); - LOG(WARNING) << "Init host = " << host << ", port = " << port << " in " << end_time - begin_time << " seconds to " - << ip; + auto end_time = Time::now(); + VLOG(dns_resolver) << "Init host = " << query.real_host << " in total of " << end_time - query.begin_time + << " seconds to " << (result.is_ok() ? (PSLICE() << result.ok()) : CSlice("[invalid]")); - if (status.is_ok()) { - value = Value{ip, end_time + ok_timeout_}; - return ip; - } else { - value = Value{status.clone(), end_time + error_timeout_}; - return std::move(status); + auto promises = std::move(query.promises); + auto value_it = cache_[prefer_ipv6].find(host); + CHECK(value_it != cache_[prefer_ipv6].end()); + auto cache_timeout = result.is_ok() ? options_.ok_timeout : options_.error_timeout; + value_it->second = Value{std::move(result), end_time + cache_timeout}; + active_queries_[prefer_ipv6].erase(query_it); + + for (auto &promise : promises) { + promise.second.set_result(value_it->second.get_ip_port(promise.first)); } } + } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.h b/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.h index b352a05d18..f864b13e3c 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/GetHostByNameActor.h @@ -1,35 +1,75 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once + #include "td/actor/actor.h" -#include "td/actor/PromiseFuture.h" +#include "td/utils/FlatHashMap.h" +#include "td/utils/logging.h" #include "td/utils/port/IPAddress.h" +#include "td/utils/Promise.h" #include "td/utils/Status.h" -#include <unordered_map> +#include <utility> namespace td { -class GetHostByNameActor final : public td::Actor { + +extern int VERBOSITY_NAME(dns_resolver); + +class GetHostByNameActor final : public Actor { public: - explicit GetHostByNameActor(int32 ok_timeout = CACHE_TIME, int32 error_timeout = ERROR_CACHE_TIME); - void run(std::string host, int port, td::Promise<td::IPAddress> promise); + enum class ResolverType { Native, Google }; + + struct Options { + static constexpr int32 DEFAULT_CACHE_TIME = 60 * 29; // 29 minutes + static constexpr int32 DEFAULT_ERROR_CACHE_TIME = 60 * 5; // 5 minutes + + vector<ResolverType> resolver_types{ResolverType::Native}; + int32 scheduler_id{-1}; + int32 ok_timeout{DEFAULT_CACHE_TIME}; + int32 error_timeout{DEFAULT_ERROR_CACHE_TIME}; + }; + + explicit GetHostByNameActor(Options options); + + void run(std::string host, int port, bool prefer_ipv6, Promise<IPAddress> promise); private: + void on_query_result(std::string host, bool prefer_ipv6, Result<IPAddress> result); + struct Value { - Result<td::IPAddress> ip; - double expire_at; + Result<IPAddress> ip; + double expires_at; + + Value(Result<IPAddress> ip, double expires_at) : ip(std::move(ip)), expires_at(expires_at) { + } + + Result<IPAddress> get_ip_port(int port) const { + auto result = ip.clone(); + if (result.is_ok()) { + result.ok_ref().set_port(port); + } + return result; + } }; - std::unordered_map<string, Value> cache_; - static constexpr int32 CACHE_TIME = 60 * 29; // 29 minutes - static constexpr int32 ERROR_CACHE_TIME = 60 * 5; // 5 minutes + FlatHashMap<string, Value> cache_[2]; - int32 ok_timeout_; - int32 error_timeout_; - Result<td::IPAddress> load_ip(string host, int port) TD_WARN_UNUSED_RESULT; + struct Query { + ActorOwn<> query; + size_t pos = 0; + string real_host; + double begin_time = 0.0; + std::vector<std::pair<int, Promise<IPAddress>>> promises; + }; + FlatHashMap<string, unique_ptr<Query>> active_queries_[2]; + + Options options_; + + void run_query(std::string host, bool prefer_ipv6, Query &query); }; + } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.cpp index 2edd225bfa..259e1fafbd 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -8,76 +8,77 @@ #include "td/utils/find_boundary.h" #include "td/utils/format.h" -#include "td/utils/logging.h" #include "td/utils/misc.h" +#include "td/utils/SliceBuilder.h" #include "td/utils/Status.h" namespace td { -void HttpChunkedByteFlow::loop() { - bool was_updated = false; - size_t need_size; - while (true) { - if (state_ == ReadChunkLength) { +bool HttpChunkedByteFlow::loop() { + bool result = false; + do { + if (state_ == State::ReadChunkLength) { bool ok = find_boundary(input_->clone(), "\r\n", len_); if (len_ > 10) { - return finish(Status::Error(PSLICE() << "Too long length in chunked " - << input_->cut_head(len_).move_as_buffer_slice().as_slice())); + finish(Status::Error(PSLICE() << "Too long length in chunked " + << input_->cut_head(len_).move_as_buffer_slice().as_slice())); + return false; } if (!ok) { - need_size = input_->size() + 1; + set_need_size(input_->size() + 1); break; } auto s_len = input_->cut_head(len_).move_as_buffer_slice(); input_->advance(2); len_ = hex_to_integer<size_t>(s_len.as_slice()); if (len_ > MAX_CHUNK_SIZE) { - return finish(Status::Error(PSLICE() << "Invalid chunk size " << tag("size", len_))); + finish(Status::Error(PSLICE() << "Invalid chunk size " << tag("size", len_))); + return false; } save_len_ = len_; - state_ = ReadChunkContent; + state_ = State::ReadChunkContent; } auto size = input_->size(); auto ready = min(len_, size); - need_size = min(MIN_UPDATE_SIZE, len_ + 2); + auto need_size = min(MIN_UPDATE_SIZE, len_ + 2); if (size < need_size) { + set_need_size(need_size); break; } - total_size_ += ready; - uncommited_size_ += ready; - if (total_size_ > MAX_SIZE) { - return finish(Status::Error(PSLICE() << "Too big query " << tag("size", input_->size()))); + if (total_size_ > MAX_SIZE - ready) { + finish(Status::Error(PSLICE() << "Too big query " << tag("size", input_->size()))); + return false; } + total_size_ += ready; + uncommitted_size_ += ready; output_.append(input_->cut_head(ready)); + result = true; len_ -= ready; - if (uncommited_size_ >= MIN_UPDATE_SIZE) { - uncommited_size_ = 0; - was_updated = true; + if (uncommitted_size_ >= MIN_UPDATE_SIZE) { + uncommitted_size_ = 0; } if (len_ == 0) { if (input_->size() < 2) { - need_size = 2; + set_need_size(2); break; } - input_->cut_head(2); + input_->advance(2); total_size_ += 2; if (save_len_ == 0) { - return finish(Status::OK()); + finish(Status::OK()); + return false; } - state_ = ReadChunkLength; + state_ = State::ReadChunkLength; len_ = 0; } + } while (false); + if (!is_input_active_ && !result) { + finish(Status::Error("Unexpected end of stream")); } - if (was_updated) { - on_output_updated(); - } - if (!is_input_active_) { - return finish(Status::Error("Unexpected end of stream")); - } - set_need_size(need_size); + return result; } } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.h index 9c62c3368e..8e248ebb73 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpChunkedByteFlow.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -8,21 +8,24 @@ #include "td/utils/ByteFlow.h" +#include <limits> + namespace td { class HttpChunkedByteFlow final : public ByteFlowBase { public: - void loop() override; + bool loop() final; private: - static constexpr int MAX_CHUNK_SIZE = 15 << 20; // some reasonable limit - static constexpr int MAX_SIZE = 150 << 20; // some reasonable limit + static constexpr size_t MAX_CHUNK_SIZE = 15 << 20; // some reasonable limit + static constexpr size_t MAX_SIZE = std::numeric_limits<uint32>::max(); // some reasonable limit static constexpr size_t MIN_UPDATE_SIZE = 1 << 14; - enum { ReadChunkLength, ReadChunkContent, OK } state_ = ReadChunkLength; + enum class State { ReadChunkLength, ReadChunkContent, OK }; + State state_ = State::ReadChunkLength; size_t len_ = 0; - size_t save_len_; + size_t save_len_ = 0; size_t total_size_ = 0; - size_t uncommited_size_ = 0; + size_t uncommitted_size_ = 0; }; } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.cpp index 087ee5b790..c1f630ea1a 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.cpp @@ -1,29 +1,40 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/HttpConnectionBase.h" -#include "td/actor/actor.h" - #include "td/net/HttpHeaderCreator.h" +#include "td/utils/common.h" #include "td/utils/logging.h" #include "td/utils/misc.h" +#include "td/utils/port/detail/PollableFd.h" namespace td { namespace detail { -HttpConnectionBase::HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files, - int32 idle_timeout) +HttpConnectionBase::HttpConnectionBase(State state, BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size, + size_t max_files, int32 idle_timeout, int32 slow_scheduler_id) : state_(state) - , stream_connection_(std::move(fd)) + , fd_(std::move(fd)) + , ssl_stream_(std::move(ssl_stream)) , max_post_size_(max_post_size) , max_files_(max_files) - , idle_timeout_(idle_timeout) { + , idle_timeout_(idle_timeout) + , slow_scheduler_id_(slow_scheduler_id) { CHECK(state_ != State::Close); + + if (ssl_stream_) { + read_source_ >> ssl_stream_.read_byte_flow() >> read_sink_; + write_source_ >> ssl_stream_.write_byte_flow() >> write_sink_; + } else { + read_source_ >> read_sink_; + write_source_ >> write_sink_; + } + peer_address_.init_peer_address(fd_).ignore(); } void HttpConnectionBase::live_event() { @@ -33,9 +44,8 @@ void HttpConnectionBase::live_event() { } void HttpConnectionBase::start_up() { - stream_connection_.get_fd().set_observer(this); - subscribe(stream_connection_.get_fd()); - reader_.init(&stream_connection_.input_buffer(), max_post_size_, max_files_); + Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this)); + reader_.init(read_sink_.get_output(), max_post_size_, max_files_); if (state_ == State::Read) { current_query_ = make_unique<HttpQuery>(); } @@ -43,13 +53,16 @@ void HttpConnectionBase::start_up() { yield(); } void HttpConnectionBase::tear_down() { - unsubscribe_before_close(stream_connection_.get_fd()); - stream_connection_.close(); + Scheduler::unsubscribe_before_close(fd_.get_poll_info().get_pollable_fd_ref()); + fd_.close(); } -void HttpConnectionBase::write_next(BufferSlice buffer) { +void HttpConnectionBase::write_next_noflush(BufferSlice buffer) { CHECK(state_ == State::Write); - stream_connection_.output_buffer().append(std::move(buffer)); + write_buffer_.append(std::move(buffer)); +} +void HttpConnectionBase::write_next(BufferSlice buffer) { + write_next_noflush(std::move(buffer)); loop(); } @@ -63,7 +76,7 @@ void HttpConnectionBase::write_ok() { void HttpConnectionBase::write_error(Status error) { CHECK(state_ == State::Write); - LOG(WARNING) << "Close http connection: " << error; + LOG(WARNING) << "Close HTTP connection: " << error; state_ = State::Close; loop(); } @@ -71,7 +84,7 @@ void HttpConnectionBase::write_error(Status error) { void HttpConnectionBase::timeout_expired() { LOG(INFO) << "Idle timeout expired"; - if (stream_connection_.need_flush_write()) { + if (fd_.need_flush_write()) { on_error(Status::Error("Write timeout expired")); } else if (state_ == State::Read) { on_error(Status::Error("Read timeout expired")); @@ -80,74 +93,118 @@ void HttpConnectionBase::timeout_expired() { stop(); } void HttpConnectionBase::loop() { - if (can_read(stream_connection_)) { + if (ssl_stream_) { + //ssl_stream_.read_byte_flow().set_need_size(0); + ssl_stream_.write_byte_flow().reset_need_size(); + } + sync_with_poll(fd_); + if (can_read_local(fd_)) { LOG(DEBUG) << "Can read from the connection"; - auto r = stream_connection_.flush_read(); + auto r = fd_.flush_read(); if (r.is_error()) { if (!begins_with(r.error().message(), "SSL error {336134278")) { // if error is not yet outputed - LOG(INFO) << "flush_read error: " << r.error(); + LOG(INFO) << "Receive flush_read error: " << r.error(); } on_error(Status::Error(r.error().public_message())); return stop(); } } + read_source_.wakeup(); // TODO: read_next even when state_ == State::Write bool want_read = false; + bool can_be_slow = slow_scheduler_id_ == -1; if (state_ == State::Read) { - auto res = reader_.read_next(current_query_.get()); + auto res = reader_.read_next(current_query_.get(), can_be_slow); if (res.is_error()) { + if (res.error().message() == "SLOW") { + LOG(INFO) << "Slow HTTP connection: migrate to " << slow_scheduler_id_; + CHECK(!can_be_slow); + yield(); + migrate(slow_scheduler_id_); + slow_scheduler_id_ = -1; + return; + } live_event(); state_ = State::Write; - LOG(INFO) << res.error(); + if (res.error().code() == 500) { + LOG(WARNING) << "Failed to process an HTTP query: " << res.error(); + } else { + LOG(INFO) << res.error(); + } HttpHeaderCreator hc; hc.init_status_line(res.error().code()); hc.set_content_size(0); - stream_connection_.output_buffer().append(hc.finish().ok()); + write_buffer_.append(hc.finish().ok()); close_after_write_ = true; on_error(Status::Error(res.error().public_message())); } else if (res.ok() == 0) { state_ = State::Write; - LOG(INFO) << "Send query to handler"; + LOG(DEBUG) << "Send query to handler"; live_event(); + current_query_->peer_address_ = peer_address_; on_query(std::move(current_query_)); } else { want_read = true; } } - if (can_write(stream_connection_)) { + write_source_.wakeup(); + + if (can_write_local(fd_)) { LOG(DEBUG) << "Can write to the connection"; - auto r = stream_connection_.flush_write(); + auto r = fd_.flush_write(); if (r.is_error()) { - LOG(INFO) << "flush_write error: " << r.error(); + LOG(INFO) << "Receive flush_write error: " << r.error(); on_error(Status::Error(r.error().public_message())); } - if (close_after_write_ && !stream_connection_.need_flush_write()) { + if (close_after_write_ && !fd_.need_flush_write()) { return stop(); } } - if (stream_connection_.get_fd().has_pending_error()) { - auto pending_error = stream_connection_.get_pending_error(); + Status pending_error; + if (fd_.get_poll_info().get_flags_local().has_pending_error()) { + pending_error = fd_.get_pending_error(); + } + if (pending_error.is_ok() && write_sink_.status().is_error()) { + pending_error = std::move(write_sink_.status()); + } + if (pending_error.is_ok() && read_sink_.status().is_error()) { + pending_error = std::move(read_sink_.status()); + } + if (pending_error.is_error()) { LOG(INFO) << pending_error; if (!close_after_write_) { on_error(Status::Error(pending_error.public_message())); } state_ = State::Close; } - if (can_close(stream_connection_)) { - LOG(INFO) << "Can close the connection"; + + if (can_close_local(fd_)) { + LOG(DEBUG) << "Can close the connection"; state_ = State::Close; } if (state_ == State::Close) { - LOG_IF(INFO, stream_connection_.need_flush_write()) << "Close nonempty connection"; - LOG_IF(INFO, want_read && - (stream_connection_.input_buffer().size() > 0 || current_query_->type_ != HttpQuery::Type::EMPTY)) - << "Close connection while reading request/response"; + if (fd_.need_flush_write()) { + LOG(INFO) << "Close nonempty connection"; + } + if (want_read && (!fd_.input_buffer().empty() || current_query_->type_ != HttpQuery::Type::Empty)) { + LOG(INFO) << "Close connection while reading request/response"; + } return stop(); } } + +void HttpConnectionBase::on_start_migrate(int32 sched_id) { + Scheduler::unsubscribe(fd_.get_poll_info().get_pollable_fd_ref()); +} + +void HttpConnectionBase::on_finish_migrate() { + Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this)); + live_event(); +} + } // namespace detail } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.h index 1d420a3175..95373dc326 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpConnectionBase.h @@ -1,164 +1,77 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once -#include "td/actor/actor.h" - #include "td/net/HttpQuery.h" #include "td/net/HttpReader.h" +#include "td/net/SslStream.h" + +#include "td/actor/actor.h" #include "td/utils/buffer.h" #include "td/utils/BufferedFd.h" -#include "td/utils/port/Fd.h" -#include "td/utils/Slice.h" +#include "td/utils/ByteFlow.h" +#include "td/utils/port/IPAddress.h" +#include "td/utils/port/SocketFd.h" #include "td/utils/Status.h" namespace td { -class FdInterface { - public: - FdInterface() = default; - FdInterface(const FdInterface &) = delete; - FdInterface &operator=(const FdInterface &) = delete; - FdInterface(FdInterface &&) = default; - FdInterface &operator=(FdInterface &&) = default; - virtual ~FdInterface() = default; - virtual const Fd &get_fd() const = 0; - virtual Fd &get_fd() = 0; - virtual int32 get_flags() const = 0; - virtual Status get_pending_error() TD_WARN_UNUSED_RESULT = 0; - - virtual Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT = 0; - virtual Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT = 0; - - virtual void close() = 0; - virtual bool empty() const = 0; -}; - -template <class FdT> -class FdToInterface : public FdInterface { - public: - FdToInterface() = default; - explicit FdToInterface(FdT fd) : fd_(std::move(fd)) { - } - - const Fd &get_fd() const final { - return fd_.get_fd(); - } - Fd &get_fd() final { - return fd_.get_fd(); - } - int32 get_flags() const final { - return fd_.get_flags(); - } - Status get_pending_error() final TD_WARN_UNUSED_RESULT { - return fd_.get_pending_error(); - } - - Result<size_t> write(Slice slice) final TD_WARN_UNUSED_RESULT { - return fd_.write(slice); - } - Result<size_t> read(MutableSlice slice) final TD_WARN_UNUSED_RESULT { - return fd_.read(slice); - } - - void close() final { - fd_.close(); - } - bool empty() const final { - return fd_.empty(); - } - - private: - FdT fd_; -}; - -template <class FdT> -std::unique_ptr<FdInterface> make_fd_interface(FdT fd) { - return make_unique<FdToInterface<FdT>>(std::move(fd)); -} - -class FdProxy { - public: - FdProxy() = default; - explicit FdProxy(std::unique_ptr<FdInterface> fd) : fd_(std::move(fd)) { - } - - const Fd &get_fd() const { - return fd_->get_fd(); - } - Fd &get_fd() { - return fd_->get_fd(); - } - int32 get_flags() const { - return fd_->get_flags(); - } - Status get_pending_error() TD_WARN_UNUSED_RESULT { - return fd_->get_pending_error(); - } - - Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT { - return fd_->write(slice); - } - Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT { - return fd_->read(slice); - } - - void close() { - fd_->close(); - } - bool empty() const { - return fd_->empty(); - } - - private: - std::unique_ptr<FdInterface> fd_; -}; - -template <class FdT> -FdProxy make_fd_proxy(FdT fd) { - return FdProxy(make_fd_interface(std::move(fd))); -} - namespace detail { + class HttpConnectionBase : public Actor { public: + void write_next_noflush(BufferSlice buffer); void write_next(BufferSlice buffer); void write_ok(); void write_error(Status error); protected: enum class State { Read, Write, Close }; - template <class FdT> - HttpConnectionBase(State state, FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout) - : HttpConnectionBase(state, make_fd_proxy(std::move(fd)), max_post_size, max_files, idle_timeout) { - } - HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files, int32 idle_timeout); + HttpConnectionBase(State state, BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size, size_t max_files, + int32 idle_timeout, int32 slow_scheduler_id); private: - using StreamConnection = BufferedFd<FdProxy>; State state_; - StreamConnection stream_connection_; + + BufferedFd<SocketFd> fd_; + IPAddress peer_address_; + SslStream ssl_stream_; + + ByteFlowSource read_source_{&fd_.input_buffer()}; + ByteFlowSink read_sink_; + + ChainBufferWriter write_buffer_; + ChainBufferReader write_buffer_reader_ = write_buffer_.extract_reader(); + ByteFlowSource write_source_{&write_buffer_reader_}; + ByteFlowMoveSink write_sink_{&fd_.output_buffer()}; + size_t max_post_size_; size_t max_files_; int32 idle_timeout_; HttpReader reader_; - HttpQueryPtr current_query_; + unique_ptr<HttpQuery> current_query_; bool close_after_write_ = false; + int32 slow_scheduler_id_{-1}; + void live_event(); - void start_up() override; - void tear_down() override; - void timeout_expired() override; - void loop() override; + void start_up() final; + void tear_down() final; + void timeout_expired() final; + void loop() final; + + void on_start_migrate(int32 sched_id) final; + void on_finish_migrate() final; - virtual void on_query(HttpQueryPtr) = 0; + virtual void on_query(unique_ptr<HttpQuery> query) = 0; virtual void on_error(Status error) = 0; }; + } // namespace detail } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.cpp index ea299b3993..8493bd0997 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -10,7 +10,7 @@ namespace td { -void HttpContentLengthByteFlow::loop() { +bool HttpContentLengthByteFlow::loop() { auto ready_size = input_->size(); if (ready_size > len_) { ready_size = len_; @@ -18,17 +18,19 @@ void HttpContentLengthByteFlow::loop() { auto need_size = min(MIN_UPDATE_SIZE, len_); if (ready_size < need_size) { set_need_size(need_size); - return; + return false; } output_.append(input_->cut_head(ready_size)); len_ -= ready_size; if (len_ == 0) { - return finish(Status::OK()); + finish(Status::OK()); + return false; } if (!is_input_active_) { - return finish(Status::Error("Unexpected end of stream")); + finish(Status::Error("Unexpected end of stream")); + return false; } - on_output_updated(); + return true; } } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.h index 18f86abdb0..6ebb7d339f 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpContentLengthByteFlow.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -15,7 +15,7 @@ class HttpContentLengthByteFlow final : public ByteFlowBase { HttpContentLengthByteFlow() = default; explicit HttpContentLengthByteFlow(size_t len) : len_(len) { } - void loop() override; + bool loop() final; private: static constexpr size_t MIN_UPDATE_SIZE = 1 << 14; diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.cpp index b4f6e6d16b..6ca6975ac8 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.h index 6f35843060..7ba795eae2 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpFile.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -30,7 +30,7 @@ class HttpFile { HttpFile(const HttpFile &) = delete; HttpFile &operator=(const HttpFile &) = delete; - HttpFile(HttpFile &&other) + HttpFile(HttpFile &&other) noexcept : field_name(std::move(other.field_name)) , name(std::move(other.name)) , content_type(std::move(other.content_type)) diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpHeaderCreator.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpHeaderCreator.h index d3e84e5dbf..72cebaf2be 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpHeaderCreator.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpHeaderCreator.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -8,6 +8,7 @@ #include "td/utils/logging.h" #include "td/utils/Slice.h" +#include "td/utils/SliceBuilder.h" #include "td/utils/Status.h" #include "td/utils/StringBuilder.h" @@ -16,27 +17,26 @@ namespace td { class HttpHeaderCreator { public: static constexpr size_t MAX_HEADER = 4096; - HttpHeaderCreator() : sb_({header_, MAX_HEADER}) { + HttpHeaderCreator() : sb_(MutableSlice{header_, MAX_HEADER}) { } void init_ok() { - sb_ = StringBuilder({header_, MAX_HEADER}); + sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER}); sb_ << "HTTP/1.1 200 OK\r\n"; } void init_get(Slice url) { - sb_ = StringBuilder({header_, MAX_HEADER}); + sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER}); sb_ << "GET " << url << " HTTP/1.1\r\n"; } void init_post(Slice url) { - sb_ = StringBuilder({header_, MAX_HEADER}); + sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER}); sb_ << "POST " << url << " HTTP/1.1\r\n"; } void init_error(int code, Slice reason) { - sb_ = StringBuilder({header_, MAX_HEADER}); + sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER}); sb_ << "HTTP/1.1 " << code << " " << reason << "\r\n"; } void init_status_line(int http_status_code) { - sb_ = StringBuilder({header_, MAX_HEADER}); - sb_ << "HTTP/1.1 " << http_status_code << " " << get_status_line(http_status_code) << "\r\n"; + init_error(http_status_code, get_status_line(http_status_code)); } void add_header(Slice key, Slice value) { sb_ << key << ": " << value << "\r\n"; @@ -45,7 +45,7 @@ class HttpHeaderCreator { add_header("Content-Type", type); } void set_content_size(size_t size) { - add_header("Content-Length", to_string(size)); + add_header("Content-Length", PSLICE() << size); } void set_keep_alive() { add_header("Connection", "keep-alive"); @@ -57,7 +57,7 @@ class HttpHeaderCreator { sb_ << content; } if (sb_.is_error()) { - return Status::Error("Too much headers"); + return Status::Error("Too many headers"); } return sb_.as_cslice(); } @@ -86,6 +86,8 @@ class HttpHeaderCreator { return CSlice("Not Modified"); case 307: return CSlice("Temporary Redirect"); + case 308: + return CSlice("Permanent Redirect"); case 400: return CSlice("Bad Request"); case 401: @@ -102,16 +104,28 @@ class HttpHeaderCreator { return CSlice("Request Timeout"); case 409: return CSlice("Conflict"); + case 410: + return CSlice("Gone"); case 411: return CSlice("Length Required"); + case 412: + return CSlice("Precondition Failed"); case 413: return CSlice("Request Entity Too Large"); case 414: return CSlice("Request-URI Too Long"); case 415: return CSlice("Unsupported Media Type"); + case 416: + return CSlice("Range Not Satisfiable"); + case 417: + return CSlice("Expectation Failed"); case 418: return CSlice("I'm a teapot"); + case 421: + return CSlice("Misdirected Request"); + case 426: + return CSlice("Upgrade Required"); case 429: return CSlice("Too Many Requests"); case 431: diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.cpp index 533cdd5407..88a5ca9350 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.cpp @@ -1,22 +1,26 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/HttpInboundConnection.h" -#include "td/utils/logging.h" +#include "td/net/SslStream.h" + +#include "td/utils/common.h" namespace td { -// HttpInboundConnection implementation -HttpInboundConnection::HttpInboundConnection(SocketFd fd, size_t max_post_size, size_t max_files, int32 idle_timeout, - ActorShared<Callback> callback) - : HttpConnectionBase(State::Read, std::move(fd), max_post_size, max_files, idle_timeout) + +HttpInboundConnection::HttpInboundConnection(BufferedFd<SocketFd> fd, size_t max_post_size, size_t max_files, + int32 idle_timeout, ActorShared<Callback> callback, + int32 slow_scheduler_id) + : HttpConnectionBase(State::Read, std::move(fd), SslStream(), max_post_size, max_files, idle_timeout, + slow_scheduler_id) , callback_(std::move(callback)) { } -void HttpInboundConnection::on_query(HttpQueryPtr query) { +void HttpInboundConnection::on_query(unique_ptr<HttpQuery> query) { CHECK(!callback_.empty()); send_closure(callback_, &Callback::handle, std::move(query), ActorOwn<HttpInboundConnection>(actor_id(this))); } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.h index 013b024592..b92e92c9b1 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpInboundConnection.h @@ -1,16 +1,17 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once -#include "td/actor/actor.h" - #include "td/net/HttpConnectionBase.h" #include "td/net/HttpQuery.h" +#include "td/actor/actor.h" + +#include "td/utils/BufferedFd.h" #include "td/utils/port/SocketFd.h" #include "td/utils/Status.h" @@ -20,20 +21,20 @@ class HttpInboundConnection final : public detail::HttpConnectionBase { public: class Callback : public Actor { public: - virtual void handle(HttpQueryPtr query, ActorOwn<HttpInboundConnection> connection) = 0; + virtual void handle(unique_ptr<HttpQuery> query, ActorOwn<HttpInboundConnection> connection) = 0; }; // Inherited interface // void write_next(BufferSlice buffer); // void write_ok(); // void write_error(Status error); - HttpInboundConnection(SocketFd fd, size_t max_post_size, size_t max_files, int32 idle_timeout, - ActorShared<Callback> callback); + HttpInboundConnection(BufferedFd<SocketFd> fd, size_t max_post_size, size_t max_files, int32 idle_timeout, + ActorShared<Callback> callback, int32 slow_scheduler_id = -1); private: - void on_query(HttpQueryPtr query) override; - void on_error(Status error) override; - void hangup() override { + void on_query(unique_ptr<HttpQuery> query) final; + void on_error(Status error) final; + void hangup() final { callback_.release(); stop(); } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.cpp index f6efe7e07a..2bf8073809 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.cpp @@ -1,16 +1,16 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/HttpOutboundConnection.h" -#include "td/utils/logging.h" +#include "td/utils/common.h" namespace td { -// HttpOutboundConnection implementation -void HttpOutboundConnection::on_query(HttpQueryPtr query) { + +void HttpOutboundConnection::on_query(unique_ptr<HttpQuery> query) { CHECK(!callback_.empty()); send_closure(callback_, &Callback::handle, std::move(query)); } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.h index d7496c59c4..ca1f49f7d9 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpOutboundConnection.h @@ -1,16 +1,19 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once -#include "td/actor/actor.h" - #include "td/net/HttpConnectionBase.h" #include "td/net/HttpQuery.h" +#include "td/net/SslStream.h" + +#include "td/actor/actor.h" +#include "td/utils/BufferedFd.h" +#include "td/utils/port/SocketFd.h" #include "td/utils/Status.h" namespace td { @@ -19,13 +22,13 @@ class HttpOutboundConnection final : public detail::HttpConnectionBase { public: class Callback : public Actor { public: - virtual void handle(HttpQueryPtr query) = 0; + virtual void handle(unique_ptr<HttpQuery> query) = 0; virtual void on_connection_error(Status error) = 0; // TODO rename to on_error }; - template <class FdT> - HttpOutboundConnection(FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout, - ActorShared<Callback> callback) - : HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), max_post_size, max_files, idle_timeout) + HttpOutboundConnection(BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size, size_t max_files, + int32 idle_timeout, ActorShared<Callback> callback, int32 slow_scheduler_id = -1) + : HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), std::move(ssl_stream), max_post_size, + max_files, idle_timeout, slow_scheduler_id) , callback_(std::move(callback)) { } // Inherited interface @@ -34,9 +37,9 @@ class HttpOutboundConnection final : public detail::HttpConnectionBase { // void write_error(Status error); private: - void on_query(HttpQueryPtr query) override; - void on_error(Status error) override; - void hangup() override { + void on_query(unique_ptr<HttpQuery> query) final; + void on_error(Status error) final; + void hangup() final { callback_.release(); HttpConnectionBase::hangup(); } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.cpp new file mode 100644 index 0000000000..a676e691f6 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.cpp @@ -0,0 +1,110 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/net/HttpProxy.h" + +#include "td/utils/base64.h" +#include "td/utils/common.h" +#include "td/utils/format.h" +#include "td/utils/logging.h" +#include "td/utils/misc.h" +#include "td/utils/Slice.h" +#include "td/utils/SliceBuilder.h" + +namespace td { + +void HttpProxy::send_connect() { + VLOG(proxy) << "Send CONNECT to proxy"; + CHECK(state_ == State::SendConnect); + state_ = State::WaitConnectResponse; + + string host = PSTRING() << ip_address_.get_ip_host() << ':' << ip_address_.get_port(); + string proxy_authorization; + if (!username_.empty() || !password_.empty()) { + auto userinfo = PSTRING() << username_ << ':' << password_; + proxy_authorization = PSTRING() << "Proxy-Authorization: basic " << base64_encode(userinfo) << "\r\n"; + } + fd_.output_buffer().append(PSLICE() << "CONNECT " << host << " HTTP/1.1\r\n" + << "Host: " << host << "\r\n" + << proxy_authorization << "\r\n"); +} + +Status HttpProxy::wait_connect_response() { + CHECK(state_ == State::WaitConnectResponse); + auto it = fd_.input_buffer().clone(); + VLOG(proxy) << "Receive CONNECT response of size " << it.size(); + if (it.size() < 12 + 1 + 1) { + return Status::OK(); + } + char begin_buf[12]; + MutableSlice begin(begin_buf, 12); + it.advance(12, begin); + if ((begin.substr(0, 10) != "HTTP/1.1 2" && begin.substr(0, 10) != "HTTP/1.0 2") || !is_digit(begin[10]) || + !is_digit(begin[11])) { + char buf[1024]; + size_t len = min(sizeof(buf), it.size()); + it.advance(len, MutableSlice{buf, sizeof(buf)}); + VLOG(proxy) << "Failed to connect: " << format::escaped(Slice(buf, len)); + return Status::Error(PSLICE() << "Failed to connect to " << ip_address_.get_ip_host() << ':' + << ip_address_.get_port()); + } + + size_t total_size = 12; + char c; + MutableSlice c_slice(&c, 1); + while (!it.empty()) { + it.advance(1, c_slice); + total_size++; + if (c == '\n') { + break; + } + } + if (it.empty()) { + return Status::OK(); + } + + char prev = '\n'; + size_t pos = 0; + bool found = false; + while (!it.empty()) { + it.advance(1, c_slice); + total_size++; + if (c == '\n') { + if (pos == 0 || (pos == 1 && prev == '\r')) { + found = true; + break; + } + pos = 0; + } else { + pos++; + } + prev = c; + } + if (!found) { + CHECK(it.empty()); + return Status::OK(); + } + + fd_.input_buffer().advance(total_size); + stop(); + return Status::OK(); +} + +Status HttpProxy::loop_impl() { + switch (state_) { + case State::SendConnect: + send_connect(); + break; + case State::WaitConnectResponse: + TRY_STATUS(wait_connect_response()); + break; + default: + UNREACHABLE(); + } + return Status::OK(); +} + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.h new file mode 100644 index 0000000000..fd9f6233ec --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpProxy.h @@ -0,0 +1,28 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/net/TransparentProxy.h" + +#include "td/utils/Status.h" + +namespace td { + +class HttpProxy final : public TransparentProxy { + public: + using TransparentProxy::TransparentProxy; + + private: + enum class State { SendConnect, WaitConnectResponse } state_ = State::SendConnect; + + void send_connect(); + Status wait_connect_response(); + + Status loop_impl() final; +}; + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.cpp index b4af0eef3f..6cf0028f4d 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.cpp @@ -1,51 +1,67 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/HttpQuery.h" +#include "td/utils/misc.h" + #include <algorithm> namespace td { -Slice HttpQuery::header(Slice key) const { +Slice HttpQuery::get_header(Slice key) const { auto it = std::find_if(headers_.begin(), headers_.end(), [&key](const std::pair<MutableSlice, MutableSlice> &s) { return s.first == key; }); return it == headers_.end() ? Slice() : it->second; } -MutableSlice HttpQuery::arg(Slice key) const { +MutableSlice HttpQuery::get_arg(Slice key) const { auto it = std::find_if(args_.begin(), args_.end(), [&key](const std::pair<MutableSlice, MutableSlice> &s) { return s.first == key; }); return it == args_.end() ? MutableSlice() : it->second; } -std::vector<std::pair<string, string>> HttpQuery::string_args() const { - std::vector<std::pair<string, string>> res; +vector<std::pair<string, string>> HttpQuery::get_args() const { + vector<std::pair<string, string>> res; + res.reserve(args_.size()); for (auto &it : args_) { - res.push_back(std::make_pair(it.first.str(), it.second.str())); + res.emplace_back(it.first.str(), it.second.str()); } return res; } +int HttpQuery::get_retry_after() const { + auto value = get_header("retry-after"); + if (value.empty()) { + return 0; + } + auto r_retry_after = to_integer_safe<int>(value); + if (r_retry_after.is_error()) { + return 0; + } + + return td::max(0, r_retry_after.ok()); +} + StringBuilder &operator<<(StringBuilder &sb, const HttpQuery &q) { switch (q.type_) { - case HttpQuery::Type::EMPTY: + case HttpQuery::Type::Empty: sb << "EMPTY"; return sb; - case HttpQuery::Type::GET: + case HttpQuery::Type::Get: sb << "GET"; break; - case HttpQuery::Type::POST: + case HttpQuery::Type::Post: sb << "POST"; break; - case HttpQuery::Type::RESPONSE: + case HttpQuery::Type::Response: sb << "RESPONSE"; break; } - if (q.type_ == HttpQuery::Type::RESPONSE) { + if (q.type_ == HttpQuery::Type::Response) { sb << ":" << q.code_ << ":" << q.reason_; } else { sb << ":" << q.url_path_; diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.h index acab74ac66..5abd4ae517 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpQuery.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -10,6 +10,7 @@ #include "td/utils/buffer.h" #include "td/utils/common.h" +#include "td/utils/port/IPAddress.h" #include "td/utils/Slice.h" #include "td/utils/StringBuilder.h" @@ -19,28 +20,30 @@ namespace td { class HttpQuery { public: - enum class Type : int8 { EMPTY, GET, POST, RESPONSE }; + enum class Type : int8 { Empty, Get, Post, Response }; - std::vector<BufferSlice> container_; - Type type_; + vector<BufferSlice> container_; + Type type_ = Type::Empty; + int32 code_ = 0; MutableSlice url_path_; - std::vector<std::pair<MutableSlice, MutableSlice>> args_; - int code_; + vector<std::pair<MutableSlice, MutableSlice>> args_; MutableSlice reason_; - bool keep_alive_; - std::vector<std::pair<MutableSlice, MutableSlice>> headers_; - std::vector<HttpFile> files_; + bool keep_alive_ = true; + vector<std::pair<MutableSlice, MutableSlice>> headers_; + vector<HttpFile> files_; MutableSlice content_; - Slice header(Slice key) const; + IPAddress peer_address_; - MutableSlice arg(Slice key) const; + Slice get_header(Slice key) const; - std::vector<std::pair<string, string>> string_args() const; -}; + MutableSlice get_arg(Slice key) const; + + vector<std::pair<string, string>> get_args() const; -using HttpQueryPtr = std::unique_ptr<HttpQuery>; + int get_retry_after() const; +}; StringBuilder &operator<<(StringBuilder &sb, const HttpQuery &q); diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.cpp index 1cfa7666a7..b86039052c 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -17,40 +17,18 @@ #include "td/utils/Parser.h" #include "td/utils/PathView.h" #include "td/utils/port/path.h" +#include "td/utils/SliceBuilder.h" +#include <cstddef> #include <cstring> namespace td { constexpr const char HttpReader::TEMP_DIRECTORY_PREFIX[]; -static size_t urldecode(Slice from, MutableSlice to, bool decode_plus_sign_as_space) { - size_t to_i = 0; - CHECK(to.size() >= from.size()); - for (size_t from_i = 0, n = from.size(); from_i < n; from_i++) { - if (from[from_i] == '%' && from_i + 2 < n) { - int high = hex_to_int(from[from_i + 1]); - int low = hex_to_int(from[from_i + 2]); - if (high < 16 && low < 16) { - to[to_i++] = static_cast<char>(high * 16 + low); - from_i += 2; - continue; - } - } - to[to_i++] = decode_plus_sign_as_space && from[from_i] == '+' ? ' ' : from[from_i]; - } - return to_i; -} - -static MutableSlice urldecode_inplace(MutableSlice str, bool decode_plus_sign_as_space) { - size_t result_size = urldecode(str, str, decode_plus_sign_as_space); - str.truncate(result_size); - return str; -} - void HttpReader::init(ChainBufferReader *input, size_t max_post_size, size_t max_files) { input_ = input; - state_ = ReadHeaders; + state_ = State::ReadHeaders; headers_read_length_ = 0; content_length_ = 0; query_ = nullptr; @@ -60,17 +38,21 @@ void HttpReader::init(ChainBufferReader *input, size_t max_post_size, size_t max total_headers_length_ = 0; } -Result<size_t> HttpReader::read_next(HttpQuery *query) { +Result<size_t> HttpReader::read_next(HttpQuery *query, bool can_be_slow) { if (query_ != query) { CHECK(query_ == nullptr); query_ = query; } size_t need_size = input_->size() + 1; while (true) { - if (state_ != ReadHeaders) { + if (state_ != State::ReadHeaders) { + gzip_flow_.wakeup(); flow_source_.wakeup(); if (flow_sink_.is_ready() && flow_sink_.status().is_error()) { - return Status::Error(400, "Bad Request: " + flow_sink_.status().message().str()); + if (!temp_file_.empty()) { + clean_temporary_file(); + } + return Status::Error(400, PSLICE() << "Bad Request: " << flow_sink_.status().message()); } need_size = flow_source_.get_need_size(); if (need_size == 0) { @@ -78,7 +60,7 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { } } switch (state_) { - case ReadHeaders: { + case State::ReadHeaders: { auto result = split_header(); if (result.is_error() || result.ok() != 0) { return result; @@ -104,12 +86,16 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { if (content_encoding_.empty()) { } else if (content_encoding_ == "gzip" || content_encoding_ == "deflate") { - gzip_flow_ = GzipByteFlow(Gzip::Decode); - gzip_flow_.set_max_output_size(MAX_FILE_SIZE); + gzip_flow_ = GzipByteFlow(Gzip::Mode::Decode); + GzipByteFlow::Options options; + options.write_watermark.low = 0; + options.write_watermark.high = max(max_post_size_, MAX_TOTAL_PARAMETERS_LENGTH + 1); + gzip_flow_.set_options(options); + gzip_flow_.set_max_output_size(MAX_CONTENT_SIZE); *source >> gzip_flow_; source = &gzip_flow_; } else { - LOG(ERROR) << "Unsupported " << tag("content-encoding", content_encoding_); + LOG(WARNING) << "Unsupported " << tag("content-encoding", content_encoding_); return Status::Error(415, "Unsupported Media Type: unsupported content-encoding"); } @@ -117,26 +103,26 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { *source >> flow_sink_; content_ = flow_sink_.get_output(); - if (content_length_ > MAX_CONTENT_SIZE) { + if (content_length_ >= MAX_CONTENT_SIZE) { return Status::Error(413, PSLICE() << "Request Entity Too Large: content length is " << content_length_); } - if (std::strstr(content_type_lowercased_.c_str(), "multipart/form-data")) { - state_ = ReadMultipartFormData; + if (content_type_lowercased_.find("multipart/form-data") != string::npos) { + state_ = State::ReadMultipartFormData; const char *p = std::strstr(content_type_lowercased_.c_str(), "boundary"); if (p == nullptr) { return Status::Error(400, "Bad Request: boundary not found"); } p += 8; - ptrdiff_t offset = p - content_type_lowercased_.c_str(); + std::ptrdiff_t offset = p - content_type_lowercased_.c_str(); p = static_cast<const char *>( std::memchr(content_type_.begin() + offset, '=', content_type_.size() - offset)); if (p == nullptr) { return Status::Error(400, "Bad Request: boundary value not found"); } p++; - const char *end_p = static_cast<const char *>(std::memchr(p, ';', content_type_.end() - p)); + auto end_p = static_cast<const char *>(std::memchr(p, ';', content_type_.end() - p)); if (end_p == nullptr) { end_p = content_type_.end(); } @@ -145,27 +131,32 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { end_p--; } + CHECK(p != nullptr); Slice boundary(p, static_cast<size_t>(end_p - p)); if (boundary.empty() || boundary.size() > MAX_BOUNDARY_LENGTH) { return Status::Error(400, "Bad Request: boundary too big or empty"); } boundary_ = "\r\n--" + boundary.str(); - form_data_parse_state_ = SkipPrologue; + form_data_parse_state_ = FormDataParseState::SkipPrologue; form_data_read_length_ = 0; form_data_skipped_length_ = 0; - } else if (std::strstr(content_type_lowercased_.c_str(), "application/x-www-form-urlencoded") || - std::strstr(content_type_lowercased_.c_str(), "application/json")) { - state_ = ReadArgs; + } else if (content_type_lowercased_.find("application/x-www-form-urlencoded") != string::npos || + content_type_lowercased_.find("application/json") != string::npos) { + state_ = State::ReadArgs; } else { form_data_skipped_length_ = 0; - state_ = ReadContent; + state_ = State::ReadContent; } continue; } - case ReadContent: { + case State::ReadContent: { if (content_->size() > max_post_size_) { - state_ = ReadContentToFile; + state_ = State::ReadContentToFile; + GzipByteFlow::Options options; + options.write_watermark.low = 4 << 20; + options.write_watermark.high = 8 << 20; + gzip_flow_.set_options(options); continue; } if (flow_sink_.is_ready()) { @@ -177,7 +168,10 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { return need_size; } - case ReadContentToFile: { + case State::ReadContentToFile: { + if (!can_be_slow) { + return Status::Error("SLOW"); + } // save content to a file if (temp_file_.empty()) { auto file = open_temp_file("file"); @@ -187,27 +181,32 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { } auto size = content_->size(); - if (size) { + bool restart = false; + if (size > (1 << 20) || flow_sink_.is_ready()) { TRY_STATUS(save_file_part(content_->cut_head(size).move_as_buffer_slice())); + restart = true; } if (flow_sink_.is_ready()) { query_->files_.emplace_back("file", "", content_type_.str(), file_size_, temp_file_name_); close_temp_file(); break; } + if (restart) { + continue; + } return need_size; } - case ReadArgs: { + case State::ReadArgs: { auto size = content_->size(); if (size > MAX_TOTAL_PARAMETERS_LENGTH - total_parameters_length_) { - return Status::Error(413, "Request Entity Too Large: too much parameters"); + return Status::Error(413, "Request Entity Too Large: too many parameters"); } if (flow_sink_.is_ready()) { query_->container_.emplace_back(content_->cut_head(size).move_as_buffer_slice()); Status result; - if (std::strstr(content_type_lowercased_.c_str(), "application/x-www-form-urlencoded")) { + if (content_type_lowercased_.find("application/x-www-form-urlencoded") != string::npos) { result = parse_parameters(query_->container_.back().as_slice()); } else { result = parse_json_parameters(query_->container_.back().as_slice()); @@ -224,10 +223,12 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { return need_size; } - case ReadMultipartFormData: { - TRY_RESULT(result, parse_multipart_form_data()); - if (result) { - break; + case State::ReadMultipartFormData: { + if (!content_->empty() || flow_sink_.is_ready()) { + TRY_RESULT(result, parse_multipart_form_data(can_be_slow)); + if (result) { + break; + } } return need_size; } @@ -244,18 +245,19 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) { // returns Status on wrong request // returns true if parsing has finished // returns false if need more data -Result<bool> HttpReader::parse_multipart_form_data() { +Result<bool> HttpReader::parse_multipart_form_data(bool can_be_slow) { while (true) { - LOG(DEBUG) << "Parsing multipart form data in state " << form_data_parse_state_; + LOG(DEBUG) << "Parsing multipart form data in state " << static_cast<int32>(form_data_parse_state_) + << " with already read length " << form_data_read_length_; switch (form_data_parse_state_) { - case SkipPrologue: + case FormDataParseState::SkipPrologue: if (find_boundary(content_->clone(), {boundary_.c_str() + 2, boundary_.size() - 2}, form_data_read_length_)) { size_t to_skip = form_data_read_length_ + (boundary_.size() - 2); content_->advance(to_skip); form_data_skipped_length_ += to_skip; form_data_read_length_ = 0; - form_data_parse_state_ = ReadPartHeaders; + form_data_parse_state_ = FormDataParseState::ReadPartHeaders; continue; } @@ -263,14 +265,14 @@ Result<bool> HttpReader::parse_multipart_form_data() { form_data_skipped_length_ += form_data_read_length_; form_data_read_length_ = 0; return false; - case ReadPartHeaders: + case FormDataParseState::ReadPartHeaders: if (find_boundary(content_->clone(), "\r\n\r\n", form_data_read_length_)) { total_headers_length_ += form_data_read_length_; if (total_headers_length_ > MAX_TOTAL_HEADERS_LENGTH) { return Status::Error(431, "Request Header Fields Too Large: total headers size exceeded"); } if (form_data_read_length_ == 0) { - // there is no headers at all + // there are no headers at all return Status::Error(400, "Bad Request: headers in multipart/form-data are empty"); } @@ -287,6 +289,7 @@ Result<bool> HttpReader::parse_multipart_form_data() { file_field_name_.clear(); field_content_type_ = "application/octet-stream"; file_name_.clear(); + has_file_name_ = false; CHECK(temp_file_.empty()); temp_file_name_.clear(); @@ -317,35 +320,70 @@ Result<bool> HttpReader::parse_multipart_form_data() { header_value.remove_prefix(10); while (true) { header_value = trim(header_value); - const char *key_end = + const auto *key_end = static_cast<const char *>(std::memchr(header_value.data(), '=', header_value.size())); if (key_end == nullptr) { break; } size_t key_size = key_end - header_value.data(); - auto key = header_value.substr(0, key_size); - key = trim(key); + auto key = trim(header_value.substr(0, key_size)); header_value.remove_prefix(key_size + 1); - const char *value_end = - static_cast<const char *>(std::memchr(header_value.data(), ';', header_value.size())); - size_t value_size; - if (value_end == nullptr) { - value_size = header_value.size(); - } else { - value_size = value_end - header_value.data(); + + while (!header_value.empty() && is_space(header_value[0])) { + header_value.remove_prefix(1); } - auto value = header_value.substr(0, value_size); - value = trim(value); - if (value.size() > 1u && value[0] == '"' && value.back() == '"') { - value = {value.data() + 1, value.size() - 2}; + + MutableSlice value; + if (!header_value.empty() && header_value[0] == '"') { // quoted-string + char *value_end = header_value.data() + 1; + const char *pos = value_end; + while (true) { + if (pos == header_value.data() + header_value.size()) { + return Status::Error(400, "Bad Request: unclosed quoted string in Content-Disposition header"); + } + char c = *pos++; + if (c == '"') { + break; + } + if (c == '\\') { + if (pos == header_value.data() + header_value.size()) { + return Status::Error(400, "Bad Request: wrong escape sequence in Content-Disposition header"); + } + c = *pos++; + } + *value_end++ = c; + } + value = header_value.substr(1, value_end - header_value.data() - 1); + header_value.remove_prefix(pos - header_value.data()); + + while (!header_value.empty() && is_space(header_value[0])) { + header_value.remove_prefix(1); + } + if (!header_value.empty()) { + if (header_value[0] != ';') { + return Status::Error(400, "Bad Request: expected ';' in Content-Disposition header"); + } + header_value.remove_prefix(1); + } + } else { // token + auto value_end = + static_cast<const char *>(std::memchr(header_value.data(), ';', header_value.size())); + if (value_end != nullptr) { + auto value_size = static_cast<size_t>(value_end - header_value.data()); + value = trim(header_value.substr(0, value_size)); + header_value.remove_prefix(value_size + 1); + } else { + value = trim(header_value); + header_value = MutableSlice(); + } } - header_value.remove_prefix(value_size + (header_value.size() > value_size)); if (key == "name") { field_name_ = value; } else if (key == "filename") { file_name_ = value.str(); + has_file_name_ = true; } else { // ignore unknown parts of header } @@ -365,10 +403,10 @@ Result<bool> HttpReader::parse_multipart_form_data() { return Status::Error(400, "Bad Request: field name in multipart/form-data not found"); } - if (!file_name_.empty()) { + if (has_file_name_) { // file if (query_->files_.size() == max_files_) { - return Status::Error(413, "Request Entity Too Large: too much files attached"); + return Status::Error(413, "Request Entity Too Large: too many files attached"); } auto file = open_temp_file(file_name_); if (file.is_error()) { @@ -377,11 +415,11 @@ Result<bool> HttpReader::parse_multipart_form_data() { // don't need to save headers for files file_field_name_ = field_name_.str(); - form_data_parse_state_ = ReadFile; + form_data_parse_state_ = FormDataParseState::ReadFile; } else { // save headers for query parameters. They contain header names query_->container_.push_back(std::move(headers)); - form_data_parse_state_ = ReadPartValue; + form_data_parse_state_ = FormDataParseState::ReadPartValue; } continue; @@ -391,10 +429,10 @@ Result<bool> HttpReader::parse_multipart_form_data() { return Status::Error(431, "Request Header Fields Too Large: total headers size exceeded"); } return false; - case ReadPartValue: + case FormDataParseState::ReadPartValue: if (find_boundary(content_->clone(), boundary_, form_data_read_length_)) { if (total_parameters_length_ + form_data_read_length_ > MAX_TOTAL_PARAMETERS_LENGTH) { - return Status::Error(413, "Request Entity Too Large: too much parameters in form data"); + return Status::Error(413, "Request Entity Too Large: too many parameters in form data"); } query_->container_.emplace_back(content_->cut_head(form_data_read_length_).move_as_buffer_slice()); @@ -416,16 +454,19 @@ Result<bool> HttpReader::parse_multipart_form_data() { query_->args_.emplace_back(field_name_, value); } - form_data_parse_state_ = CheckForLastBoundary; + form_data_parse_state_ = FormDataParseState::CheckForLastBoundary; continue; } CHECK(content_->size() < form_data_read_length_ + boundary_.size()); if (total_parameters_length_ + form_data_read_length_ > MAX_TOTAL_PARAMETERS_LENGTH) { - return Status::Error(413, "Request Entity Too Large: too much parameters in form data"); + return Status::Error(413, "Request Entity Too Large: too many parameters in form data"); } return false; - case ReadFile: { + case FormDataParseState::ReadFile: { + if (!can_be_slow) { + return Status::Error("SLOW"); + } if (find_boundary(content_->clone(), boundary_, form_data_read_length_)) { auto file_part = content_->cut_head(form_data_read_length_).move_as_buffer_slice(); content_->advance(boundary_.size()); @@ -437,7 +478,7 @@ Result<bool> HttpReader::parse_multipart_form_data() { query_->files_.emplace_back(file_field_name_, file_name_, field_content_type_, file_size_, temp_file_name_); close_temp_file(); - form_data_parse_state_ = CheckForLastBoundary; + form_data_parse_state_ = FormDataParseState::CheckForLastBoundary; continue; } @@ -450,7 +491,7 @@ Result<bool> HttpReader::parse_multipart_form_data() { TRY_STATUS(save_file_part(std::move(file_part))); return false; } - case CheckForLastBoundary: { + case FormDataParseState::CheckForLastBoundary: { if (content_->size() < 2) { // need more data return false; @@ -462,13 +503,13 @@ Result<bool> HttpReader::parse_multipart_form_data() { if (x[0] == '-' && x[1] == '-') { content_->advance(2); form_data_skipped_length_ += 2; - form_data_parse_state_ = SkipEpilogue; + form_data_parse_state_ = FormDataParseState::SkipEpilogue; } else { - form_data_parse_state_ = ReadPartHeaders; + form_data_parse_state_ = FormDataParseState::ReadPartHeaders; } continue; } - case SkipEpilogue: { + case FormDataParseState::SkipEpilogue: { size_t size = content_->size(); LOG(DEBUG) << "Skipping epilogue. Have " << size << " bytes"; content_->advance(size); @@ -512,16 +553,20 @@ void HttpReader::process_header(MutableSlice header_name, MutableSlice header_va header_name = trim(header_name); header_value = trim(header_value); // TODO need to remove "\r\n" from value to_lower_inplace(header_name); - LOG(DEBUG) << "process_header [" << header_name << "=>" << header_value << "]"; + LOG(DEBUG) << "Process header [" << header_name << "=>" << header_value << "]"; query_->headers_.emplace_back(header_name, header_value); - // TODO: check if protocol is HTTP/1.1 - query_->keep_alive_ = true; if (header_name == "content-length") { - content_length_ = to_integer<size_t>(header_value); + auto content_length = to_integer<uint64>(header_value); + if (content_length > MAX_CONTENT_SIZE) { + content_length = MAX_CONTENT_SIZE; + } + content_length_ = static_cast<size_t>(content_length); } else if (header_name == "connection") { to_lower_inplace(header_value); if (header_value == "close") { query_->keep_alive_ = false; + } else { + query_->keep_alive_ = true; } } else if (header_name == "content-type") { content_type_ = header_value; @@ -542,7 +587,7 @@ Status HttpReader::parse_url(MutableSlice url) { url_path_size++; } - query_->url_path_ = urldecode_inplace({url.data(), url_path_size}, false); + query_->url_path_ = url_decode_inplace({url.data(), url_path_size}, false); if (url_path_size == url.size() || url[url_path_size] != '?') { return Status::OK(); @@ -553,7 +598,7 @@ Status HttpReader::parse_url(MutableSlice url) { Status HttpReader::parse_parameters(MutableSlice parameters) { total_parameters_length_ += parameters.size(); if (total_parameters_length_ > MAX_TOTAL_PARAMETERS_LENGTH) { - return Status::Error(413, "Request Entity Too Large: too much parameters"); + return Status::Error(413, "Request Entity Too Large: too many parameters"); } LOG(DEBUG) << "Parse parameters: \"" << parameters << "\""; @@ -562,9 +607,9 @@ Status HttpReader::parse_parameters(MutableSlice parameters) { auto key_value = parser.read_till_nofail('&'); parser.skip_nofail('&'); Parser kv_parser(key_value); - auto key = urldecode_inplace(kv_parser.read_till_nofail('='), true); + auto key = url_decode_inplace(kv_parser.read_till_nofail('='), true); kv_parser.skip_nofail('='); - auto value = urldecode_inplace(kv_parser.data(), true); + auto value = url_decode_inplace(kv_parser.data(), true); query_->args_.emplace_back(key, value); } @@ -579,15 +624,27 @@ Status HttpReader::parse_json_parameters(MutableSlice parameters) { total_parameters_length_ += parameters.size(); if (total_parameters_length_ > MAX_TOTAL_PARAMETERS_LENGTH) { - return Status::Error(413, "Request Entity Too Large: too much parameters"); + return Status::Error(413, "Request Entity Too Large: too many parameters"); } - LOG(DEBUG) << "Parse json parameters: \"" << parameters << "\""; + LOG(DEBUG) << "Parse JSON parameters: \"" << parameters << "\""; Parser parser(parameters); parser.skip_whitespaces(); + if (parser.peek_char() == '"') { + auto r_value = json_string_decode(parser); + if (r_value.is_error()) { + return Status::Error(400, PSLICE() << "Bad Request: can't parse string content: " << r_value.error().message()); + } + if (!parser.empty()) { + return Status::Error(400, "Bad Request: extra data after string"); + } + query_->container_.emplace_back("content"); + query_->args_.emplace_back(query_->container_.back().as_slice(), r_value.move_as_ok()); + return Status::OK(); + } parser.skip('{'); if (parser.status().is_error()) { - return Status::Error(400, "Bad Request: json object expected"); + return Status::Error(400, "Bad Request: JSON object expected"); } while (true) { parser.skip_whitespaces(); @@ -603,29 +660,29 @@ Status HttpReader::parse_json_parameters(MutableSlice parameters) { } auto r_key = json_string_decode(parser); if (r_key.is_error()) { - return Status::Error(400, string("Bad Request: can't parse parameter name: ") + r_key.error().message().c_str()); + return Status::Error(400, PSLICE() << "Bad Request: can't parse parameter name: " << r_key.error().message()); } parser.skip_whitespaces(); if (!parser.try_skip(':')) { return Status::Error(400, "Bad Request: can't parse object, ':' expected"); } parser.skip_whitespaces(); - Result<MutableSlice> r_value; - if (parser.peek_char() == '"') { - r_value = json_string_decode(parser); - } else { - const int32 DEFAULT_MAX_DEPTH = 100; - auto begin = parser.ptr(); - auto result = do_json_skip(parser, DEFAULT_MAX_DEPTH); - if (result.is_ok()) { - r_value = MutableSlice(begin, parser.ptr()); + auto r_value = [&]() -> Result<MutableSlice> { + if (parser.peek_char() == '"') { + return json_string_decode(parser); } else { - r_value = result.move_as_error(); + const int32 DEFAULT_MAX_DEPTH = 100; + auto begin = parser.ptr(); + auto result = do_json_skip(parser, DEFAULT_MAX_DEPTH); + if (result.is_ok()) { + return MutableSlice(begin, parser.ptr()); + } else { + return result.move_as_error(); + } } - } + }(); if (r_value.is_error()) { - return Status::Error(400, - string("Bad Request: can't parse parameter value: ") + r_value.error().message().c_str()); + return Status::Error(400, PSLICE() << "Bad Request: can't parse parameter value: " << r_value.error().message()); } query_->args_.emplace_back(r_key.move_as_ok(), r_value.move_as_ok()); @@ -645,12 +702,16 @@ Status HttpReader::parse_head(MutableSlice head) { parser.skip(' '); // GET POST HTTP/1.1 if (type == "GET") { - query_->type_ = HttpQuery::Type::GET; + query_->type_ = HttpQuery::Type::Get; } else if (type == "POST") { - query_->type_ = HttpQuery::Type::POST; + query_->type_ = HttpQuery::Type::Post; } else if (type.size() >= 4 && type.substr(0, 4) == "HTTP") { - if (type == "HTTP/1.1" || type == "HTTP/1.0") { - query_->type_ = HttpQuery::Type::RESPONSE; + if (type == "HTTP/1.1") { + query_->type_ = HttpQuery::Type::Response; + query_->keep_alive_ = true; + } else if (type == "HTTP/1.0") { + query_->type_ = HttpQuery::Type::Response; + query_->keep_alive_ = false; } else { LOG(INFO) << "Unsupported HTTP version: " << type; return Status::Error(505, "HTTP Version Not Supported"); @@ -662,10 +723,11 @@ Status HttpReader::parse_head(MutableSlice head) { query_->args_.clear(); - if (query_->type_ == HttpQuery::Type::RESPONSE) { + if (query_->type_ == HttpQuery::Type::Response) { query_->code_ = to_integer<int32>(parser.read_till(' ')); parser.skip(' '); query_->reason_ = parser.read_till('\r'); + LOG(DEBUG) << "Receive HTTP response " << query_->code_ << " " << query_->reason_; } else { auto url_version = parser.read_till('\r'); auto space_pos = url_version.rfind(' '); @@ -685,12 +747,11 @@ Status HttpReader::parse_head(MutableSlice head) { parser.skip('\n'); content_length_ = 0; - content_type_ = "application/octet-stream"; + content_type_ = Slice("application/octet-stream"); content_type_lowercased_ = content_type_.str(); - transfer_encoding_ = ""; - content_encoding_ = ""; + transfer_encoding_ = Slice(); + content_encoding_ = Slice(); - query_->keep_alive_ = false; query_->headers_.clear(); query_->files_.clear(); query_->content_ = MutableSlice(); @@ -771,24 +832,26 @@ Status HttpReader::try_open_temp_file(Slice directory_name, CSlice desired_file_ Status HttpReader::save_file_part(BufferSlice &&file_part) { file_size_ += narrow_cast<int64>(file_part.size()); if (file_size_ > MAX_FILE_SIZE) { - string file_name = temp_file_name_; - close_temp_file(); - delete_temp_file(file_name); + clean_temporary_file(); return Status::Error( - 413, PSLICE() << "Request Entity Too Large: file is too big to be uploaded " << tag("size", file_size_)); + 413, PSLICE() << "Request Entity Too Large: file of size " << file_size_ << " is too big to be uploaded"); } LOG(DEBUG) << "Save file part of size " << file_part.size() << " to file " << temp_file_name_; auto result_written = temp_file_.write(file_part.as_slice()); if (result_written.is_error() || result_written.ok() != file_part.size()) { - string file_name = temp_file_name_; - close_temp_file(); - delete_temp_file(file_name); - return Status::Error(500, "Internal server error: can't upload the file"); + clean_temporary_file(); + return Status::Error(500, "Internal Server Error: can't upload the file"); } return Status::OK(); } +void HttpReader::clean_temporary_file() { + string file_name = temp_file_name_; + close_temp_file(); + delete_temp_file(file_name); +} + void HttpReader::close_temp_file() { LOG(DEBUG) << "Close temporary file " << temp_file_name_; CHECK(!temp_file_.empty()); @@ -807,7 +870,7 @@ void HttpReader::delete_temp_file(CSlice file_name) { if (parent.size() >= prefix_length + 7 && parent.substr(parent.size() - prefix_length - 7, prefix_length) == TEMP_DIRECTORY_PREFIX) { LOG(DEBUG) << "Unlink temporary directory " << parent; - rmdir(Slice(parent.data(), parent.size() - 1).str()).ignore(); + rmdir(PSLICE() << Slice(parent.data(), parent.size() - 1)).ignore(); } } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.h b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.h index 74067d1291..3851e74671 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/HttpReader.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -27,7 +27,7 @@ class HttpReader { public: void init(ChainBufferReader *input, size_t max_post_size = std::numeric_limits<size_t>::max(), size_t max_files = 100); - Result<size_t> read_next(HttpQuery *query) TD_WARN_UNUSED_RESULT; // TODO move query to init + Result<size_t> read_next(HttpQuery *query, bool can_be_slow = true) TD_WARN_UNUSED_RESULT; // TODO move query to init HttpReader() = default; HttpReader(const HttpReader &other) = delete; @@ -36,57 +36,60 @@ class HttpReader { HttpReader &operator=(HttpReader &&other) = delete; ~HttpReader() { if (!temp_file_.empty()) { - temp_file_.close(); + clean_temporary_file(); } } static void delete_temp_file(CSlice file_name); private: - size_t max_post_size_; - size_t max_files_; + size_t max_post_size_ = 0; + size_t max_files_ = 0; - enum { ReadHeaders, ReadContent, ReadContentToFile, ReadArgs, ReadMultipartFormData } state_; - size_t headers_read_length_; - size_t content_length_; - ChainBufferReader *input_; + enum class State { ReadHeaders, ReadContent, ReadContentToFile, ReadArgs, ReadMultipartFormData }; + State state_ = State::ReadHeaders; + size_t headers_read_length_ = 0; + size_t content_length_ = 0; + ChainBufferReader *input_ = nullptr; ByteFlowSource flow_source_; HttpChunkedByteFlow chunked_flow_; GzipByteFlow gzip_flow_; HttpContentLengthByteFlow content_length_flow_; ByteFlowSink flow_sink_; - ChainBufferReader *content_; + ChainBufferReader *content_ = nullptr; - HttpQuery *query_; + HttpQuery *query_ = nullptr; Slice transfer_encoding_; Slice content_encoding_; Slice content_type_; string content_type_lowercased_; - size_t total_parameters_length_; - size_t total_headers_length_; + size_t total_parameters_length_ = 0; + size_t total_headers_length_ = 0; string boundary_; - size_t form_data_read_length_; - size_t form_data_skipped_length_; - enum { + size_t form_data_read_length_ = 0; + size_t form_data_skipped_length_ = 0; + enum class FormDataParseState : int32 { SkipPrologue, ReadPartHeaders, ReadPartValue, ReadFile, CheckForLastBoundary, SkipEpilogue - } form_data_parse_state_; + }; + FormDataParseState form_data_parse_state_ = FormDataParseState::SkipPrologue; MutableSlice field_name_; string file_field_name_; string field_content_type_; string file_name_; + bool has_file_name_ = false; FileFd temp_file_; string temp_file_name_; - int64 file_size_; + int64 file_size_ = 0; Result<size_t> split_header() TD_WARN_UNUSED_RESULT; void process_header(MutableSlice header_name, MutableSlice header_value); - Result<bool> parse_multipart_form_data() TD_WARN_UNUSED_RESULT; + Result<bool> parse_multipart_form_data(bool can_be_slow) TD_WARN_UNUSED_RESULT; Status parse_url(MutableSlice url) TD_WARN_UNUSED_RESULT; Status parse_parameters(MutableSlice parameters) TD_WARN_UNUSED_RESULT; Status parse_json_parameters(MutableSlice parameters) TD_WARN_UNUSED_RESULT; @@ -96,12 +99,13 @@ class HttpReader { Status try_open_temp_file(Slice directory_name, CSlice desired_file_name) TD_WARN_UNUSED_RESULT; Status save_file_part(BufferSlice &&file_part) TD_WARN_UNUSED_RESULT; void close_temp_file(); + void clean_temporary_file(); - static constexpr size_t MAX_CONTENT_SIZE = 150 << 20; // Some reasonable limit - static constexpr size_t MAX_TOTAL_PARAMETERS_LENGTH = 1 << 16; // Some reasonable limit - static constexpr size_t MAX_TOTAL_HEADERS_LENGTH = 1 << 18; // Some reasonable limit - static constexpr size_t MAX_BOUNDARY_LENGTH = 70; // As defined by RFC1341 - static constexpr int64 MAX_FILE_SIZE = 1500 << 20; // Telegram server file size limit + static constexpr size_t MAX_CONTENT_SIZE = std::numeric_limits<uint32>::max(); // Some reasonable limit + static constexpr size_t MAX_TOTAL_PARAMETERS_LENGTH = 1 << 20; // Some reasonable limit + static constexpr size_t MAX_TOTAL_HEADERS_LENGTH = 1 << 18; // Some reasonable limit + static constexpr size_t MAX_BOUNDARY_LENGTH = 70; // As defined by RFC1341 + static constexpr int64 MAX_FILE_SIZE = static_cast<int64>(4000) << 20; // Telegram server file size limit static constexpr const char TEMP_DIRECTORY_PREFIX[] = "tdlib-server-tmp"; }; diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/NetStats.h b/protocols/Telegram/tdlib/td/tdnet/td/net/NetStats.h index e67f9fbc93..325196b34c 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/NetStats.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/NetStats.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -10,7 +10,6 @@ #include "td/utils/common.h" #include "td/utils/format.h" -#include "td/utils/logging.h" #include "td/utils/StringBuilder.h" #include "td/utils/Time.h" @@ -87,12 +86,12 @@ class NetStats { } // do it before get_callback - void set_callback(std::unique_ptr<Callback> callback) { + void set_callback(unique_ptr<Callback> callback) { impl_->set_callback(std::move(callback)); } private: - class Impl : public NetStatsCallback { + class Impl final : public NetStatsCallback { public: NetStatsData get_stats() const { NetStatsData res; @@ -102,7 +101,7 @@ class NetStats { }); return res; } - void set_callback(std::unique_ptr<Callback> callback) { + void set_callback(unique_ptr<Callback> callback) { callback_ = std::move(callback); } @@ -114,7 +113,7 @@ class NetStats { std::atomic<uint64> write_size{0}; }; SchedulerLocalStorage<LocalNetStats> local_net_stats_; - std::unique_ptr<Callback> callback_; + unique_ptr<Callback> callback_; void on_read(uint64 size) final { auto &stats = local_net_stats_.get(); @@ -131,8 +130,8 @@ class NetStats { void on_change(LocalNetStats &stats, uint64 size) { stats.unsync_size += size; - auto now = Time::now_cached(); - if (stats.unsync_size > 10000 || now - stats.last_update > 5 * 60) { + auto now = Time::now(); + if (stats.unsync_size > 10000 || now - stats.last_update > 300) { stats.unsync_size = 0; stats.last_update = now; callback_->on_stats_updated(); diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.cpp index 02e1e067ea..04c52278e8 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.cpp @@ -1,67 +1,21 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/net/Socks5.h" -#include "td/utils/format.h" +#include "td/utils/common.h" #include "td/utils/logging.h" #include "td/utils/misc.h" -#include "td/utils/port/Fd.h" #include "td/utils/Slice.h" +#include "td/utils/SliceBuilder.h" namespace td { -static int VERBOSITY_NAME(socks5) = VERBOSITY_NAME(DEBUG); - -Socks5::Socks5(SocketFd socket_fd, IPAddress ip_address, string username, string password, - std::unique_ptr<Callback> callback, ActorShared<> parent) - : fd_(std::move(socket_fd)) - , ip_address_(std::move(ip_address)) - , username_(std::move(username)) - , password_(std::move(password)) - , callback_(std::move(callback)) - , parent_(std::move(parent)) { -} - -void Socks5::on_error(Status status) { - CHECK(status.is_error()); - VLOG(socks5) << "Receive " << status; - if (callback_) { - callback_->set_result(std::move(status)); - callback_.reset(); - } - stop(); -} - -void Socks5::tear_down() { - VLOG(socks5) << "Finish to connect to proxy"; - unsubscribe(fd_.get_fd()); - fd_.get_fd().set_observer(nullptr); - if (callback_) { - callback_->set_result(std::move(fd_)); - callback_.reset(); - } -} - -void Socks5::hangup() { - on_error(Status::Error("Cancelled")); -} - -void Socks5::start_up() { - VLOG(socks5) << "Begin to connect to proxy"; - fd_.get_fd().set_observer(this); - subscribe(fd_.get_fd()); - set_timeout_in(10); - if (can_write(fd_)) { - loop(); - } -} - void Socks5::send_greeting() { - VLOG(socks5) << "Send greeting to proxy"; + VLOG(proxy) << "Send greeting to proxy"; CHECK(state_ == State::SendGreeting); state_ = State::WaitGreetingResponse; @@ -80,18 +34,17 @@ void Socks5::send_greeting() { Status Socks5::wait_greeting_response() { auto &buf = fd_.input_buffer(); - VLOG(socks5) << "Receive greeting response of size " << buf.size(); + VLOG(proxy) << "Receive greeting response of size " << buf.size(); if (buf.size() < 2) { return Status::OK(); } auto buffer_slice = buf.read_as_buffer_slice(2); auto slice = buffer_slice.as_slice(); if (slice[0] != '\x05') { - return Status::Error(PSLICE() << "Unsupported socks protocol version " << int(slice[0])); + return Status::Error(PSLICE() << "Unsupported socks protocol version " << static_cast<int>(slice[0])); } auto authentication_method = slice[1]; if (authentication_method == '\0') { - state_ = State::SendIpAddress; send_ip_address(); return Status::OK(); } @@ -102,7 +55,7 @@ Status Socks5::wait_greeting_response() { } Status Socks5::send_username_password() { - VLOG(socks5) << "Send username and password"; + VLOG(proxy) << "Send username and password"; if (username_.size() >= 128) { return Status::Error("Username is too long"); } @@ -124,27 +77,26 @@ Status Socks5::send_username_password() { Status Socks5::wait_password_response() { auto &buf = fd_.input_buffer(); - VLOG(socks5) << "Receive password response of size " << buf.size(); + VLOG(proxy) << "Receive password response of size " << buf.size(); if (buf.size() < 2) { return Status::OK(); } auto buffer_slice = buf.read_as_buffer_slice(2); auto slice = buffer_slice.as_slice(); if (slice[0] != '\x01') { - return Status::Error(PSLICE() << "Unsupported socks subnegotiation protocol version " << int(slice[0])); + return Status::Error(PSLICE() << "Unsupported socks subnegotiation protocol version " + << static_cast<int>(slice[0])); } if (slice[1] != '\x00') { return Status::Error("Wrong username or password"); } - state_ = State::SendIpAddress; send_ip_address(); return Status::OK(); } void Socks5::send_ip_address() { - VLOG(socks5) << "Send IP address"; - CHECK(state_ == State::SendIpAddress); + VLOG(proxy) << "Send IP address"; callback_->on_connected(); string request; request += '\x05'; @@ -152,14 +104,14 @@ void Socks5::send_ip_address() { request += '\x00'; if (ip_address_.is_ipv4()) { request += '\x01'; - auto ipv4 = ip_address_.get_ipv4(); + auto ipv4 = ntohl(ip_address_.get_ipv4()); request += static_cast<char>(ipv4 & 255); request += static_cast<char>((ipv4 >> 8) & 255); request += static_cast<char>((ipv4 >> 16) & 255); request += static_cast<char>((ipv4 >> 24) & 255); } else { request += '\x04'; - request += ip_address_.get_ipv6().str(); + request += ip_address_.get_ipv6(); } auto port = ip_address_.get_port(); request += static_cast<char>((port >> 8) & 255); @@ -171,7 +123,7 @@ void Socks5::send_ip_address() { Status Socks5::wait_ip_address_response() { CHECK(state_ == State::WaitIpAddressResponse); auto it = fd_.input_buffer().clone(); - VLOG(socks5) << "Receive IP address response of size " << it.size(); + VLOG(proxy) << "Receive IP address response of size " << it.size(); if (it.size() < 4) { return Status::OK(); } @@ -183,23 +135,26 @@ Status Socks5::wait_ip_address_response() { } it.advance(1, c_slice); if (c != '\0') { - return Status::Error(PSLICE() << tag("code", c)); + return Status::Error(PSLICE() << "Receive error code " << static_cast<int32>(c) << " from server"); } it.advance(1, c_slice); if (c != '\0') { - return Status::Error("byte must be zero"); + return Status::Error("Byte must be zero"); } it.advance(1, c_slice); + size_t total_size = 6; if (c == '\x01') { if (it.size() < 4) { return Status::OK(); } it.advance(4); + total_size += 4; } else if (c == '\x04') { if (it.size() < 16) { return Status::OK(); } it.advance(16); + total_size += 16; } else { return Status::Error("Invalid response"); } @@ -207,43 +162,29 @@ Status Socks5::wait_ip_address_response() { return Status::OK(); } it.advance(2); + fd_.input_buffer().advance(total_size); stop(); return Status::OK(); } -void Socks5::loop() { - auto status = [&] { - TRY_STATUS(fd_.flush_read()); - switch (state_) { - case State::SendGreeting: - send_greeting(); - break; - case State::WaitGreetingResponse: - TRY_STATUS(wait_greeting_response()); - break; - case State::WaitPasswordResponse: - TRY_STATUS(wait_password_response()); - break; - case State::WaitIpAddressResponse: - TRY_STATUS(wait_ip_address_response()); - break; - case State::SendIpAddress: - case State::Stop: - UNREACHABLE(); - } - TRY_STATUS(fd_.flush_write()); - return Status::OK(); - }(); - if (status.is_error()) { - on_error(std::move(status)); - } - if (can_close(fd_)) { - on_error(Status::Error("Connection closed")); +Status Socks5::loop_impl() { + switch (state_) { + case State::SendGreeting: + send_greeting(); + break; + case State::WaitGreetingResponse: + TRY_STATUS(wait_greeting_response()); + break; + case State::WaitPasswordResponse: + TRY_STATUS(wait_password_response()); + break; + case State::WaitIpAddressResponse: + TRY_STATUS(wait_ip_address_response()); + break; + default: + UNREACHABLE(); } -} - -void Socks5::timeout_expired() { - on_error(Status::Error("Timeout expired")); + return Status::OK(); } } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.h b/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.h index b67a33c282..e438fb8d90 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/Socks5.h @@ -1,58 +1,27 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once -#include "td/actor/actor.h" -#include "td/actor/PromiseFuture.h" +#include "td/net/TransparentProxy.h" -#include "td/utils/BufferedFd.h" -#include "td/utils/common.h" -#include "td/utils/port/IPAddress.h" -#include "td/utils/port/SocketFd.h" #include "td/utils/Status.h" namespace td { -class Socks5 : public Actor { +class Socks5 final : public TransparentProxy { public: - class Callback { - public: - Callback() = default; - Callback(const Callback &) = delete; - Callback &operator=(const Callback &) = delete; - virtual ~Callback() = default; - - virtual void set_result(Result<SocketFd>) = 0; - virtual void on_connected() = 0; - }; - - Socks5(SocketFd socket_fd, IPAddress ip_address, string username, string password, std::unique_ptr<Callback> callback, - ActorShared<> parent); + using TransparentProxy::TransparentProxy; private: - BufferedFd<SocketFd> fd_; - IPAddress ip_address_; - string username_; - string password_; - std::unique_ptr<Callback> callback_; - ActorShared<> parent_; - - void on_error(Status status); - void tear_down() override; - void start_up() override; - void hangup() override; - enum class State { SendGreeting, WaitGreetingResponse, WaitPasswordResponse, - SendIpAddress, - WaitIpAddressResponse, - Stop + WaitIpAddressResponse } state_ = State::SendGreeting; void send_greeting(); @@ -64,8 +33,7 @@ class Socks5 : public Actor { void send_ip_address(); Status wait_ip_address_response(); - void loop() override; - void timeout_expired() override; + Status loop_impl() final; }; } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.cpp new file mode 100644 index 0000000000..4ec0412e66 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.cpp @@ -0,0 +1,312 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/net/SslCtx.h" + +#include "td/utils/common.h" +#include "td/utils/crypto.h" +#include "td/utils/FlatHashMap.h" +#include "td/utils/logging.h" +#include "td/utils/port/wstring_convert.h" +#include "td/utils/SliceBuilder.h" +#include "td/utils/Time.h" + +#if !TD_EMSCRIPTEN +#include <openssl/err.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> + +#include <cstring> +#include <memory> +#include <mutex> + +#if TD_PORT_WINDOWS +#include <wincrypt.h> +#endif + +namespace td { + +namespace detail { +namespace { +int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { + if (!preverify_ok) { + char buf[256]; + X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256); + + int err = X509_STORE_CTX_get_error(ctx); + auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err) + << ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << Slice(buf, std::strlen(buf)); + double now = Time::now(); + + static std::mutex warning_mutex; + { + std::lock_guard<std::mutex> lock(warning_mutex); + static FlatHashMap<string, double> next_warning_time; + double &next = next_warning_time[warning]; + if (next <= now) { + next = now + 300; // one warning per 5 minutes + LOG(WARNING) << warning; + } + } + } + + return preverify_ok; +} + +using SslCtxPtr = std::shared_ptr<SSL_CTX>; + +Result<SslCtxPtr> do_create_ssl_ctx(CSlice cert_file, SslCtx::VerifyPeer verify_peer) { + auto ssl_method = +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + TLS_client_method(); +#else + SSLv23_client_method(); +#endif + if (ssl_method == nullptr) { + return create_openssl_error(-6, "Failed to create an SSL client method"); + } + auto ssl_ctx = SSL_CTX_new(ssl_method); + if (!ssl_ctx) { + return create_openssl_error(-7, "Failed to create an SSL context"); + } + auto ssl_ctx_ptr = SslCtxPtr(ssl_ctx, SSL_CTX_free); + long options = 0; +#ifdef SSL_OP_NO_SSLv2 + options |= SSL_OP_NO_SSLv2; +#endif +#ifdef SSL_OP_NO_SSLv3 + options |= SSL_OP_NO_SSLv3; +#endif + SSL_CTX_set_options(ssl_ctx, options); +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_VERSION); +#endif + SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + + if (cert_file.empty()) { +#if TD_PORT_WINDOWS + LOG(DEBUG) << "Begin to load system store"; + auto flags = CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG | CERT_SYSTEM_STORE_CURRENT_USER; + HCERTSTORE system_store = + CertOpenStore(CERT_STORE_PROV_SYSTEM_W, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, HCRYPTPROV_LEGACY(), flags, + static_cast<const void *>(to_wstring("ROOT").ok().c_str())); + + if (system_store) { + X509_STORE *store = X509_STORE_new(); + + for (PCCERT_CONTEXT cert_context = CertEnumCertificatesInStore(system_store, nullptr); cert_context != nullptr; + cert_context = CertEnumCertificatesInStore(system_store, cert_context)) { + const unsigned char *in = cert_context->pbCertEncoded; + X509 *x509 = d2i_X509(nullptr, &in, static_cast<long>(cert_context->cbCertEncoded)); + if (x509 != nullptr) { + if (X509_STORE_add_cert(store, x509) != 1) { + auto error_code = ERR_peek_error(); + auto error = create_openssl_error(-20, "Failed to add certificate"); + if (ERR_GET_REASON(error_code) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + LOG(ERROR) << error; + } else { + LOG(INFO) << error; + } + } + + X509_free(x509); + } else { + LOG(ERROR) << create_openssl_error(-21, "Failed to load X509 certificate"); + } + } + + CertCloseStore(system_store, 0); + + SSL_CTX_set_cert_store(ssl_ctx, store); + LOG(DEBUG) << "End to load system store"; + } else { + LOG(ERROR) << create_openssl_error(-22, "Failed to open system certificate store"); + } +#else + if (SSL_CTX_set_default_verify_paths(ssl_ctx) == 0) { + auto error = create_openssl_error(-8, "Failed to load default verify paths"); + if (verify_peer == SslCtx::VerifyPeer::On) { + return std::move(error); + } else { + LOG(ERROR) << error; + } + } +#endif + } else { + if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) { + return create_openssl_error(-8, "Failed to set custom certificate file"); + } + } + + if (verify_peer == SslCtx::VerifyPeer::On) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback); + + constexpr int DEFAULT_VERIFY_DEPTH = 10; + SSL_CTX_set_verify_depth(ssl_ctx, DEFAULT_VERIFY_DEPTH); + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr); + } + + string cipher_list; + if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) { + return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"'); + } + + return std::move(ssl_ctx_ptr); +} + +Result<SslCtxPtr> get_default_ssl_ctx() { + static auto ctx = do_create_ssl_ctx(CSlice(), SslCtx::VerifyPeer::On); + if (ctx.is_error()) { + return ctx.error().clone(); + } + + return ctx.ok(); +} + +Result<SslCtxPtr> get_default_unverified_ssl_ctx() { + static auto ctx = do_create_ssl_ctx(CSlice(), SslCtx::VerifyPeer::Off); + if (ctx.is_error()) { + return ctx.error().clone(); + } + + return ctx.ok(); +} + +} // namespace + +class SslCtxImpl { + public: + Status init(CSlice cert_file, SslCtx::VerifyPeer verify_peer) { + SslCtx::init_openssl(); + + clear_openssl_errors("Before SslCtx::init"); + + if (cert_file.empty()) { + if (verify_peer == SslCtx::VerifyPeer::On) { + TRY_RESULT_ASSIGN(ssl_ctx_ptr_, get_default_ssl_ctx()); + } else { + TRY_RESULT_ASSIGN(ssl_ctx_ptr_, get_default_unverified_ssl_ctx()); + } + return Status::OK(); + } + + auto start_time = Time::now(); + auto r_ssl_ctx_ptr = do_create_ssl_ctx(cert_file, verify_peer); + auto elapsed_time = Time::now() - start_time; + if (elapsed_time >= 0.1) { + LOG(ERROR) << "SSL context creation took " << elapsed_time << " seconds"; + } + if (r_ssl_ctx_ptr.is_error()) { + return r_ssl_ctx_ptr.move_as_error(); + } + ssl_ctx_ptr_ = r_ssl_ctx_ptr.move_as_ok(); + return Status::OK(); + } + + void *get_openssl_ctx() const { + return static_cast<void *>(ssl_ctx_ptr_.get()); + } + + private: + SslCtxPtr ssl_ctx_ptr_; +}; + +} // namespace detail + +SslCtx::SslCtx() = default; + +SslCtx::SslCtx(const SslCtx &other) { + if (other.impl_) { + impl_ = make_unique<detail::SslCtxImpl>(*other.impl_); + } +} + +SslCtx &SslCtx::operator=(const SslCtx &other) { + if (other.impl_) { + impl_ = make_unique<detail::SslCtxImpl>(*other.impl_); + } else { + impl_ = nullptr; + } + return *this; +} + +SslCtx::SslCtx(SslCtx &&) noexcept = default; + +SslCtx &SslCtx::operator=(SslCtx &&) noexcept = default; + +SslCtx::~SslCtx() = default; + +void SslCtx::init_openssl() { + static bool is_inited = [] { +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + return OPENSSL_init_ssl(0, nullptr) != 0; +#else + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + return OpenSSL_add_ssl_algorithms() != 0; +#endif + }(); + CHECK(is_inited); +} + +Result<SslCtx> SslCtx::create(CSlice cert_file, VerifyPeer verify_peer) { + auto impl = make_unique<detail::SslCtxImpl>(); + TRY_STATUS(impl->init(cert_file, verify_peer)); + return SslCtx(std::move(impl)); +} + +void *SslCtx::get_openssl_ctx() const { + return impl_ == nullptr ? nullptr : impl_->get_openssl_ctx(); +} + +SslCtx::SslCtx(unique_ptr<detail::SslCtxImpl> impl) : impl_(std::move(impl)) { +} + +} // namespace td + +#else + +namespace td { + +namespace detail { +class SslCtxImpl {}; +} // namespace detail + +SslCtx::SslCtx() = default; + +SslCtx::SslCtx(const SslCtx &other) { + UNREACHABLE(); +} + +SslCtx &SslCtx::operator=(const SslCtx &other) { + UNREACHABLE(); + return *this; +} + +SslCtx::SslCtx(SslCtx &&) noexcept = default; + +SslCtx &SslCtx::operator=(SslCtx &&) noexcept = default; + +SslCtx::~SslCtx() = default; + +void SslCtx::init_openssl() { +} + +Result<SslCtx> SslCtx::create(CSlice cert_file, VerifyPeer verify_peer) { + return Status::Error("Not supported in Emscripten"); +} + +void *SslCtx::get_openssl_ctx() const { + return nullptr; +} + +SslCtx::SslCtx(unique_ptr<detail::SslCtxImpl> impl) : impl_(std::move(impl)) { +} + +} // namespace td + +#endif diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.h b/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.h new file mode 100644 index 0000000000..5cd55925e5 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/SslCtx.h @@ -0,0 +1,45 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/utils/Slice.h" +#include "td/utils/Status.h" + +namespace td { + +namespace detail { +class SslCtxImpl; +} // namespace detail + +class SslCtx { + public: + SslCtx(); + SslCtx(const SslCtx &other); + SslCtx &operator=(const SslCtx &other); + SslCtx(SslCtx &&) noexcept; + SslCtx &operator=(SslCtx &&) noexcept; + ~SslCtx(); + + static void init_openssl(); + + enum class VerifyPeer { On, Off }; + + static Result<SslCtx> create(CSlice cert_file, VerifyPeer verify_peer); + + void *get_openssl_ctx() const; + + explicit operator bool() const noexcept { + return static_cast<bool>(impl_); + } + + private: + unique_ptr<detail::SslCtxImpl> impl_; + + explicit SslCtx(unique_ptr<detail::SslCtxImpl> impl); +}; + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.cpp deleted file mode 100644 index f6f7557235..0000000000 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.cpp +++ /dev/null @@ -1,280 +0,0 @@ -// -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 -// -// Distributed under the Boost Software License, Version 1.0. (See accompanying -// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) -// -#include "td/net/SslFd.h" - -#include "td/utils/logging.h" -#include "td/utils/StackAllocator.h" -#include "td/utils/StringBuilder.h" -#include "td/utils/Time.h" - -#include <openssl/err.h> -#include <openssl/evp.h> -#include <openssl/ssl.h> -#include <openssl/x509v3.h> - -#include <map> -#include <mutex> - -namespace td { - -#if !TD_WINDOWS -static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { - if (!preverify_ok) { - char buf[256]; - X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256); - - int err = X509_STORE_CTX_get_error(ctx); - auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err) - << ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << buf; - double now = Time::now(); - - static std::mutex warning_mutex; - { - std::lock_guard<std::mutex> lock(warning_mutex); - static std::map<std::string, double> next_warning_time; - double &next = next_warning_time[warning]; - if (next <= now) { - next = now + 300; // one warning per 5 minutes - LOG(WARNING) << warning; - } - } - } - - return preverify_ok; -} -#endif - -namespace { - -Status create_openssl_error(int code, Slice message) { - const int buf_size = 1 << 12; - auto buf = StackAllocator::alloc(buf_size); - StringBuilder sb(buf.as_slice()); - - sb << message; - while (unsigned long error_code = ERR_get_error()) { - sb << "{" << error_code << ", " << ERR_error_string(error_code, nullptr) << "}"; - } - LOG_IF(ERROR, sb.is_error()) << "OPENSSL error buffer overflow"; - return Status::Error(code, sb.as_cslice()); -} - -void openssl_clear_errors(Slice from) { - if (ERR_peek_error() != 0) { - LOG(ERROR) << from << ": " << create_openssl_error(0, "Unprocessed OPENSSL_ERROR"); - } - errno = 0; -} - -void do_ssl_shutdown(SSL *ssl_handle) { - if (!SSL_is_init_finished(ssl_handle)) { - return; - } - openssl_clear_errors("Before SSL_shutdown"); - SSL_set_quiet_shutdown(ssl_handle, 1); - SSL_shutdown(ssl_handle); - openssl_clear_errors("After SSL_shutdown"); -} - -} // namespace - -SslFd::SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_) - : fd_(std::move(fd)), ssl_handle_(ssl_handle_), ssl_ctx_(ssl_ctx_) { -} - -Result<SslFd> SslFd::init(SocketFd fd, CSlice host, CSlice cert_file, VerifyPeer verify_peer) { -#if TD_WINDOWS - return Status::Error("TODO"); -#else - static bool init_openssl = [] { -#if OPENSSL_VERSION_NUMBER >= 0x10100000L - return OPENSSL_init_ssl(0, nullptr) != 0; -#else - OpenSSL_add_all_algorithms(); - SSL_load_error_strings(); - return OpenSSL_add_ssl_algorithms() != 0; -#endif - }(); - CHECK(init_openssl); - - openssl_clear_errors("Before SslFd::init"); - CHECK(!fd.empty()); - - auto ssl_method = -#if OPENSSL_VERSION_NUMBER >= 0x10100000L - TLS_client_method(); -#else - SSLv23_client_method(); -#endif - if (ssl_method == nullptr) { - return create_openssl_error(-6, "Failed to create an SSL client method"); - } - - auto ssl_ctx = SSL_CTX_new(ssl_method); - if (ssl_ctx == nullptr) { - return create_openssl_error(-7, "Failed to create an SSL context"); - } - auto ssl_ctx_guard = ScopeExit() + [&]() { SSL_CTX_free(ssl_ctx); }; - long options = 0; -#ifdef SSL_OP_NO_SSLv2 - options |= SSL_OP_NO_SSLv2; -#endif -#ifdef SSL_OP_NO_SSLv3 - options |= SSL_OP_NO_SSLv3; -#endif - SSL_CTX_set_options(ssl_ctx, options); - SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); - - if (cert_file.empty()) { - SSL_CTX_set_default_verify_paths(ssl_ctx); - } else { - if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) { - return create_openssl_error(-8, "Failed to set custom cert file"); - } - } - if (VERIFY_PEER && verify_peer == VerifyPeer::On) { - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback); - - if (VERIFY_DEPTH != -1) { - SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH); - } - } else { - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr); - } - - // TODO(now): cipher list - string cipher_list; - if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) { - return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"'); - } - - auto ssl_handle = SSL_new(ssl_ctx); - if (ssl_handle == nullptr) { - return create_openssl_error(-13, "Failed to create an SSL handle"); - } - auto ssl_handle_guard = ScopeExit() + [&]() { - do_ssl_shutdown(ssl_handle); - SSL_free(ssl_handle); - }; - -#if OPENSSL_VERSION_NUMBER >= 0x10002000L - X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle); - /* Enable automatic hostname checks */ - // TODO: X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS - X509_VERIFY_PARAM_set_hostflags(param, 0); - X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0); -#else -#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY -#endif - - if (!SSL_set_fd(ssl_handle, fd.get_fd().get_native_fd())) { - return create_openssl_error(-14, "Failed to set fd"); - } - -#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT) - auto host_str = host.str(); - SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin()); -#endif - SSL_set_connect_state(ssl_handle); - - ssl_ctx_guard.dismiss(); - ssl_handle_guard.dismiss(); - return SslFd(std::move(fd), ssl_handle, ssl_ctx); -#endif -} - -Result<size_t> SslFd::process_ssl_error(int ret, int *mask) { -#if TD_WINDOWS - return Status::Error("TODO"); -#else - auto openssl_errno = errno; - int error = SSL_get_error(ssl_handle_, ret); - LOG(INFO) << "SSL ERROR: " << ret << " " << error; - switch (error) { - case SSL_ERROR_NONE: - LOG(ERROR) << "SSL_get_error returned no error"; - return 0; - case SSL_ERROR_ZERO_RETURN: - LOG(DEBUG) << "SSL_ERROR_ZERO_RETURN"; - fd_.get_fd().update_flags(Fd::Close); - write_mask_ |= Fd::Error; - *mask |= Fd::Error; - return 0; - case SSL_ERROR_WANT_READ: - LOG(DEBUG) << "SSL_ERROR_WANT_READ"; - fd_.get_fd().clear_flags(Fd::Read); - *mask |= Fd::Read; - return 0; - case SSL_ERROR_WANT_WRITE: - LOG(DEBUG) << "SSL_ERROR_WANT_WRITE"; - fd_.get_fd().clear_flags(Fd::Write); - *mask |= Fd::Write; - return 0; - case SSL_ERROR_WANT_CONNECT: - case SSL_ERROR_WANT_ACCEPT: - case SSL_ERROR_WANT_X509_LOOKUP: - LOG(DEBUG) << "SSL_ERROR: CONNECT ACCEPT LOOKUP"; - fd_.get_fd().clear_flags(Fd::Write); - *mask |= Fd::Write; - return 0; - case SSL_ERROR_SYSCALL: - LOG(DEBUG) << "SSL_ERROR_SYSCALL"; - if (ERR_peek_error() == 0) { - if (openssl_errno != 0) { - CHECK(openssl_errno != EAGAIN); - return Status::PosixError(openssl_errno, "SSL_ERROR_SYSCALL"); - } else { - // Socket was closed from the other side, probably. Not an error - fd_.get_fd().update_flags(Fd::Close); - write_mask_ |= Fd::Error; - *mask |= Fd::Error; - return 0; - } - } - /* fall through */ - default: - LOG(DEBUG) << "SSL_ERROR Default"; - fd_.get_fd().update_flags(Fd::Close); - write_mask_ |= Fd::Error; - read_mask_ |= Fd::Error; - return create_openssl_error(1, "SSL error "); - } -#endif -} - -Result<size_t> SslFd::write(Slice slice) { - openssl_clear_errors("Before SslFd::write"); - auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size())); - if (size <= 0) { - return process_ssl_error(size, &write_mask_); - } - return size; -} -Result<size_t> SslFd::read(MutableSlice slice) { - openssl_clear_errors("Before SslFd::read"); - auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size())); - if (size <= 0) { - return process_ssl_error(size, &read_mask_); - } - return size; -} - -void SslFd::close() { - if (fd_.empty()) { - CHECK(!ssl_handle_ && !ssl_ctx_); - return; - } - CHECK(ssl_handle_ && ssl_ctx_); - do_ssl_shutdown(ssl_handle_); - SSL_free(ssl_handle_); - ssl_handle_ = nullptr; - SSL_CTX_free(ssl_ctx_); - ssl_ctx_ = nullptr; - fd_.close(); -} - -} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.h b/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.h deleted file mode 100644 index c197b9c318..0000000000 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/SslFd.h +++ /dev/null @@ -1,109 +0,0 @@ -// -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 -// -// Distributed under the Boost Software License, Version 1.0. (See accompanying -// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) -// -#pragma once - -#include "td/utils/port/Fd.h" -#include "td/utils/port/SocketFd.h" -#include "td/utils/Slice.h" -#include "td/utils/Status.h" - -#include <openssl/ssl.h> // TODO can we remove it from header and make target_link_libraries dependence PRIVATE? - -namespace td { - -class SslFd { - public: - enum class VerifyPeer { On, Off }; - static Result<SslFd> init(SocketFd fd, CSlice host, CSlice cert_file = CSlice(), - VerifyPeer verify_peer = VerifyPeer::On) TD_WARN_UNUSED_RESULT; - - SslFd(const SslFd &other) = delete; - SslFd &operator=(const SslFd &other) = delete; - SslFd(SslFd &&other) - : fd_(std::move(other.fd_)) - , write_mask_(other.write_mask_) - , read_mask_(other.read_mask_) - , ssl_handle_(other.ssl_handle_) - , ssl_ctx_(other.ssl_ctx_) { - other.ssl_handle_ = nullptr; - other.ssl_ctx_ = nullptr; - } - SslFd &operator=(SslFd &&other) { - close(); - - fd_ = std::move(other.fd_); - write_mask_ = other.write_mask_; - read_mask_ = other.read_mask_; - ssl_handle_ = other.ssl_handle_; - ssl_ctx_ = other.ssl_ctx_; - - other.ssl_handle_ = nullptr; - other.ssl_ctx_ = nullptr; - return *this; - } - - const Fd &get_fd() const { - return fd_.get_fd(); - } - - Fd &get_fd() { - return fd_.get_fd(); - } - - Status get_pending_error() TD_WARN_UNUSED_RESULT { - return fd_.get_pending_error(); - } - - Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT; - Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT; - - void close(); - - int32 get_flags() const { - int32 res = 0; - int32 fd_flags = fd_.get_flags(); - fd_flags &= ~Fd::Error; - if (fd_flags & Fd::Close) { - res |= Fd::Close; - } - write_mask_ &= ~fd_flags; - read_mask_ &= ~fd_flags; - if (write_mask_ == 0) { - res |= Fd::Write; - } - if (read_mask_ == 0) { - res |= Fd::Read; - } - return res; - } - - bool empty() const { - return fd_.empty(); - } - - ~SslFd() { - close(); - } - - private: - static constexpr bool VERIFY_PEER = true; - static constexpr int VERIFY_DEPTH = 10; - - SocketFd fd_; - mutable int write_mask_ = 0; - mutable int read_mask_ = 0; - - // TODO unique_ptr - SSL *ssl_handle_ = nullptr; - SSL_CTX *ssl_ctx_ = nullptr; - - SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_); - - Result<size_t> process_ssl_error(int ret, int *mask) TD_WARN_UNUSED_RESULT; -}; - -} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.cpp new file mode 100644 index 0000000000..bede94b0f1 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.cpp @@ -0,0 +1,420 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/net/SslStream.h" + +#if !TD_EMSCRIPTEN +#include "td/utils/common.h" +#include "td/utils/crypto.h" +#include "td/utils/logging.h" +#include "td/utils/misc.h" +#include "td/utils/port/IPAddress.h" +#include "td/utils/Status.h" +#include "td/utils/Time.h" + +#include <openssl/bio.h> +#include <openssl/err.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> + +#include <cstring> +#include <memory> + +namespace td { + +namespace detail { +namespace { +#if OPENSSL_VERSION_NUMBER < 0x10100000L +void *BIO_get_data(BIO *b) { + return b->ptr; +} +void BIO_set_data(BIO *b, void *ptr) { + b->ptr = ptr; +} +void BIO_set_init(BIO *b, int init) { + b->init = init; +} + +int BIO_get_new_index() { + return 0; +} +BIO_METHOD *BIO_meth_new(int type, const char *name) { + auto res = new BIO_METHOD(); + std::memset(res, 0, sizeof(*res)); + return res; +} + +int BIO_meth_set_write(BIO_METHOD *biom, int (*bwrite)(BIO *, const char *, int)) { + biom->bwrite = bwrite; + return 1; +} +int BIO_meth_set_read(BIO_METHOD *biom, int (*bread)(BIO *, char *, int)) { + biom->bread = bread; + return 1; +} +int BIO_meth_set_ctrl(BIO_METHOD *biom, long (*ctrl)(BIO *, int, long, void *)) { + biom->ctrl = ctrl; + return 1; +} +int BIO_meth_set_create(BIO_METHOD *biom, int (*create)(BIO *)) { + biom->create = create; + return 1; +} +int BIO_meth_set_destroy(BIO_METHOD *biom, int (*destroy)(BIO *)) { + biom->destroy = destroy; + return 1; +} +#endif + +int strm_create(BIO *b) { + BIO_set_init(b, 1); + return 1; +} + +int strm_destroy(BIO *b) { + return 1; +} + +int strm_read(BIO *b, char *buf, int len); + +int strm_write(BIO *b, const char *buf, int len); + +long strm_ctrl(BIO *b, int cmd, long num, void *ptr) { + switch (cmd) { + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: + return 0; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: + return 0; +#endif + default: + LOG(FATAL) << b << " " << cmd << " " << num << " " << ptr; + } + return 1; +} + +BIO_METHOD *BIO_s_sslstream() { + static BIO_METHOD *result = [] { + BIO_METHOD *res = BIO_meth_new(BIO_get_new_index(), "td::SslStream helper bio"); + BIO_meth_set_write(res, strm_write); + BIO_meth_set_read(res, strm_read); + BIO_meth_set_create(res, strm_create); + BIO_meth_set_destroy(res, strm_destroy); + BIO_meth_set_ctrl(res, strm_ctrl); + return res; + }(); + return result; +} + +struct SslHandleDeleter { + void operator()(SSL *ssl_handle) { + auto start_time = Time::now(); + if (SSL_is_init_finished(ssl_handle)) { + clear_openssl_errors("Before SSL_shutdown"); + SSL_set_quiet_shutdown(ssl_handle, 1); + SSL_shutdown(ssl_handle); + clear_openssl_errors("After SSL_shutdown"); + } + SSL_free(ssl_handle); + auto elapsed_time = Time::now() - start_time; + if (elapsed_time >= 0.1) { + LOG(ERROR) << "SSL_free took " << elapsed_time << " seconds"; + } + } +}; + +using SslHandle = std::unique_ptr<SSL, SslHandleDeleter>; + +} // namespace + +class SslStreamImpl { + public: + Status init(CSlice host, SslCtx ssl_ctx, bool check_ip_address_as_host) { + if (!ssl_ctx) { + return Status::Error("Invalid SSL context provided"); + } + + clear_openssl_errors("Before SslFd::init"); + + auto ssl_handle = SslHandle(SSL_new(static_cast<SSL_CTX *>(ssl_ctx.get_openssl_ctx()))); + if (!ssl_handle) { + return create_openssl_error(-13, "Failed to create an SSL handle"); + } + + auto r_ip_address = IPAddress::get_ip_address(host); + +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle.get()); + X509_VERIFY_PARAM_set_hostflags(param, 0); + if (r_ip_address.is_ok() && !check_ip_address_as_host) { + LOG(DEBUG) << "Set verification IP address to " << r_ip_address.ok().get_ip_str(); + X509_VERIFY_PARAM_set1_ip_asc(param, r_ip_address.ok().get_ip_str().c_str()); + } else { + LOG(DEBUG) << "Set verification host to " << host; + X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0); + } +#else +#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY +#endif + + auto *bio = BIO_new(BIO_s_sslstream()); + BIO_set_data(bio, static_cast<void *>(this)); + SSL_set_bio(ssl_handle.get(), bio, bio); + +#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT) + if (r_ip_address.is_error()) { // IP address must not be send as SNI + LOG(DEBUG) << "Set SNI host name to " << host; + auto host_str = host.str(); + SSL_set_tlsext_host_name(ssl_handle.get(), MutableCSlice(host_str).begin()); + } +#endif + SSL_set_connect_state(ssl_handle.get()); + + ssl_handle_ = std::move(ssl_handle); + + return Status::OK(); + } + + ByteFlowInterface &read_byte_flow() { + return read_flow_; + } + ByteFlowInterface &write_byte_flow() { + return write_flow_; + } + size_t flow_read(MutableSlice slice) { + return read_flow_.read(slice); + } + size_t flow_write(Slice slice) { + return write_flow_.write(slice); + } + + private: + SslHandle ssl_handle_; + + friend class SslReadByteFlow; + friend class SslWriteByteFlow; + + Result<size_t> write(Slice slice) { + clear_openssl_errors("Before SslFd::write"); + auto start_time = Time::now(); + auto size = SSL_write(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size())); + auto elapsed_time = Time::now() - start_time; + if (elapsed_time >= 0.1) { + LOG(ERROR) << "SSL_write of size " << slice.size() << " took " << elapsed_time << " seconds and returned " << size + << ' ' << SSL_get_error(ssl_handle_.get(), size); + } + if (size <= 0) { + return process_ssl_error(size); + } + return size; + } + + Result<size_t> read(MutableSlice slice) { + clear_openssl_errors("Before SslFd::read"); + auto start_time = Time::now(); + auto size = SSL_read(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size())); + auto elapsed_time = Time::now() - start_time; + if (elapsed_time >= 0.1) { + LOG(ERROR) << "SSL_read took " << elapsed_time << " seconds and returned " << size << ' ' + << SSL_get_error(ssl_handle_.get(), size); + } + if (size <= 0) { + return process_ssl_error(size); + } + return size; + } + + class SslReadByteFlow final : public ByteFlowBase { + public: + explicit SslReadByteFlow(SslStreamImpl *stream) : stream_(stream) { + } + bool loop() final { + auto to_read = output_.prepare_append(); + auto r_size = stream_->read(to_read); + if (r_size.is_error()) { + finish(r_size.move_as_error()); + return false; + } + auto size = r_size.move_as_ok(); + if (size == 0) { + return false; + } + output_.confirm_append(size); + return true; + } + + size_t read(MutableSlice data) { + return input_->advance(min(data.size(), input_->size()), data); + } + + private: + SslStreamImpl *stream_; + }; + + class SslWriteByteFlow final : public ByteFlowBase { + public: + explicit SslWriteByteFlow(SslStreamImpl *stream) : stream_(stream) { + } + bool loop() final { + auto to_write = input_->prepare_read(); + auto r_size = stream_->write(to_write); + if (r_size.is_error()) { + finish(r_size.move_as_error()); + return false; + } + auto size = r_size.move_as_ok(); + if (size == 0) { + return false; + } + input_->confirm_read(size); + return true; + } + + size_t write(Slice data) { + output_.append(data); + return data.size(); + } + + private: + SslStreamImpl *stream_; + }; + + SslReadByteFlow read_flow_{this}; + SslWriteByteFlow write_flow_{this}; + + Result<size_t> process_ssl_error(int ret) { + auto os_error = OS_ERROR("SSL_ERROR_SYSCALL"); + int error = SSL_get_error(ssl_handle_.get(), ret); + switch (error) { + case SSL_ERROR_NONE: + LOG(ERROR) << "SSL_get_error returned no error"; + return 0; + case SSL_ERROR_ZERO_RETURN: + LOG(DEBUG) << "SSL_ZERO_RETURN"; + return 0; + case SSL_ERROR_WANT_READ: + LOG(DEBUG) << "SSL_WANT_READ"; + return 0; + case SSL_ERROR_WANT_WRITE: + LOG(DEBUG) << "SSL_WANT_WRITE"; + return 0; + case SSL_ERROR_WANT_CONNECT: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_X509_LOOKUP: + LOG(DEBUG) << "SSL: CONNECT ACCEPT LOOKUP"; + return 0; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error() == 0) { + if (os_error.code() != 0) { + LOG(DEBUG) << "SSL_ERROR_SYSCALL"; + return std::move(os_error); + } else { + LOG(DEBUG) << "SSL_SYSCALL"; + return 0; + } + } + /* fallthrough */ + default: + LOG(DEBUG) << "SSL_ERROR Default"; + return create_openssl_error(1, "SSL error "); + } + } +}; + +namespace { +int strm_read(BIO *b, char *buf, int len) { + auto *stream = static_cast<SslStreamImpl *>(BIO_get_data(b)); + CHECK(stream != nullptr); + BIO_clear_retry_flags(b); + CHECK(buf != nullptr); + auto res = narrow_cast<int>(stream->flow_read(MutableSlice(buf, len))); + if (res == 0) { + BIO_set_retry_read(b); + return -1; + } + return res; +} +int strm_write(BIO *b, const char *buf, int len) { + auto *stream = static_cast<SslStreamImpl *>(BIO_get_data(b)); + CHECK(stream != nullptr); + BIO_clear_retry_flags(b); + CHECK(buf != nullptr); + return narrow_cast<int>(stream->flow_write(Slice(buf, len))); +} +} // namespace + +} // namespace detail + +SslStream::SslStream() = default; +SslStream::SslStream(SslStream &&) noexcept = default; +SslStream &SslStream::operator=(SslStream &&) noexcept = default; +SslStream::~SslStream() = default; + +Result<SslStream> SslStream::create(CSlice host, SslCtx ssl_ctx, bool use_ip_address_as_host) { + auto impl = make_unique<detail::SslStreamImpl>(); + TRY_STATUS(impl->init(host, ssl_ctx, use_ip_address_as_host)); + return SslStream(std::move(impl)); +} +SslStream::SslStream(unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) { +} +ByteFlowInterface &SslStream::read_byte_flow() { + return impl_->read_byte_flow(); +} +ByteFlowInterface &SslStream::write_byte_flow() { + return impl_->write_byte_flow(); +} +size_t SslStream::flow_read(MutableSlice slice) { + return impl_->flow_read(slice); +} +size_t SslStream::flow_write(Slice slice) { + return impl_->flow_write(slice); +} + +} // namespace td + +#else + +namespace td { + +namespace detail { +class SslStreamImpl {}; +} // namespace detail + +SslStream::SslStream() = default; +SslStream::SslStream(SslStream &&) noexcept = default; +SslStream &SslStream::operator=(SslStream &&) noexcept = default; +SslStream::~SslStream() = default; + +Result<SslStream> SslStream::create(CSlice host, SslCtx ssl_ctx, bool check_ip_address_as_host) { + return Status::Error("Not supported in Emscripten"); +} + +SslStream::SslStream(unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) { +} + +ByteFlowInterface &SslStream::read_byte_flow() { + UNREACHABLE(); +} + +ByteFlowInterface &SslStream::write_byte_flow() { + UNREACHABLE(); +} + +size_t SslStream::flow_read(MutableSlice slice) { + UNREACHABLE(); +} + +size_t SslStream::flow_write(Slice slice) { + UNREACHABLE(); +} + +} // namespace td + +#endif diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.h b/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.h new file mode 100644 index 0000000000..286eb80be3 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/SslStream.h @@ -0,0 +1,46 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/net/SslCtx.h" + +#include "td/utils/ByteFlow.h" +#include "td/utils/Slice.h" +#include "td/utils/Status.h" + +namespace td { + +namespace detail { +class SslStreamImpl; +} // namespace detail + +class SslStream { + public: + SslStream(); + SslStream(SslStream &&) noexcept; + SslStream &operator=(SslStream &&) noexcept; + ~SslStream(); + + static Result<SslStream> create(CSlice host, SslCtx ssl_ctx, bool use_ip_address_as_host = false); + + ByteFlowInterface &read_byte_flow(); + ByteFlowInterface &write_byte_flow(); + + size_t flow_read(MutableSlice slice); + size_t flow_write(Slice slice); + + explicit operator bool() const noexcept { + return static_cast<bool>(impl_); + } + + private: + unique_ptr<detail::SslStreamImpl> impl_; + + explicit SslStream(unique_ptr<detail::SslStreamImpl> impl); +}; + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.cpp index 54531f9b60..7a8d280624 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -7,11 +7,12 @@ #include "td/net/TcpListener.h" #include "td/utils/logging.h" -#include "td/utils/port/Fd.h" +#include "td/utils/port/detail/PollableFd.h" namespace td { -// TcpListener implementation -TcpListener::TcpListener(int port, ActorShared<Callback> callback) : port_(port), callback_(std::move(callback)) { + +TcpListener::TcpListener(int port, ActorShared<Callback> callback, Slice server_address) + : port_(port), callback_(std::move(callback)), server_address_(server_address.str()) { } void TcpListener::hangup() { @@ -19,21 +20,19 @@ void TcpListener::hangup() { } void TcpListener::start_up() { - auto r_socket = ServerSocketFd::open(port_); + auto r_socket = ServerSocketFd::open(port_, server_address_); if (r_socket.is_error()) { LOG(ERROR) << "Can't open server socket: " << r_socket.error(); set_timeout_in(5); return; } server_fd_ = r_socket.move_as_ok(); - server_fd_.get_fd().set_observer(this); - subscribe(server_fd_.get_fd()); + Scheduler::subscribe(server_fd_.get_poll_info().extract_pollable_fd(this)); } void TcpListener::tear_down() { - LOG(ERROR) << "TcpListener closed"; if (!server_fd_.empty()) { - unsubscribe_before_close(server_fd_.get_fd()); + Scheduler::unsubscribe_before_close(server_fd_.get_poll_info().get_pollable_fd_ref()); server_fd_.close(); } } @@ -41,8 +40,12 @@ void TcpListener::tear_down() { void TcpListener::loop() { if (server_fd_.empty()) { start_up(); + if (server_fd_.empty()) { + return; + } } - while (can_read(server_fd_)) { + sync_with_poll(server_fd_); + while (can_read_local(server_fd_)) { auto r_socket_fd = server_fd_.accept(); if (r_socket_fd.is_error()) { if (r_socket_fd.error().code() != -1) { @@ -53,8 +56,7 @@ void TcpListener::loop() { send_closure(callback_, &Callback::accept, r_socket_fd.move_as_ok()); } - if (can_close(server_fd_)) { - LOG(ERROR) << "HELLO!"; + if (can_close_local(server_fd_)) { stop(); } } diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.h b/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.h index f2e61a2387..84cf3c6874 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/TcpListener.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -10,6 +10,7 @@ #include "td/utils/port/ServerSocketFd.h" #include "td/utils/port/SocketFd.h" +#include "td/utils/Slice.h" namespace td { @@ -20,16 +21,17 @@ class TcpListener final : public Actor { virtual void accept(SocketFd fd) = 0; }; - TcpListener(int port, ActorShared<Callback> callback); - void hangup() override; + TcpListener(int port, ActorShared<Callback> callback, Slice server_address = Slice("0.0.0.0")); + void hangup() final; private: int port_; ServerSocketFd server_fd_; ActorShared<Callback> callback_; - void start_up() override; - void tear_down() override; - void loop() override; + const string server_address_; + void start_up() final; + void tear_down() final; + void loop() final; }; } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.cpp new file mode 100644 index 0000000000..b5102a37b4 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.cpp @@ -0,0 +1,84 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/net/TransparentProxy.h" + +#include "td/utils/logging.h" +#include "td/utils/port/detail/PollableFd.h" + +namespace td { + +int VERBOSITY_NAME(proxy) = VERBOSITY_NAME(DEBUG); + +TransparentProxy::TransparentProxy(SocketFd socket_fd, IPAddress ip_address, string username, string password, + unique_ptr<Callback> callback, ActorShared<> parent) + : fd_(std::move(socket_fd)) + , ip_address_(std::move(ip_address)) + , username_(std::move(username)) + , password_(std::move(password)) + , callback_(std::move(callback)) + , parent_(std::move(parent)) { +} + +void TransparentProxy::on_error(Status status) { + CHECK(status.is_error()); + VLOG(proxy) << "Receive " << status; + if (callback_) { + callback_->set_result(std::move(status)); + callback_.reset(); + } + stop(); +} + +void TransparentProxy::tear_down() { + VLOG(proxy) << "Finish to connect to proxy"; + Scheduler::unsubscribe(fd_.get_poll_info().get_pollable_fd_ref()); + if (callback_) { + if (!fd_.input_buffer().empty()) { + LOG(ERROR) << "Have " << fd_.input_buffer().size() << " unread bytes"; + callback_->set_result(Status::Error("Proxy has sent too many data")); + } else { + callback_->set_result(std::move(fd_)); + } + callback_.reset(); + } +} + +void TransparentProxy::hangup() { + on_error(Status::Error("Canceled")); +} + +void TransparentProxy::start_up() { + VLOG(proxy) << "Begin to connect to proxy"; + Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this)); + set_timeout_in(10); + sync_with_poll(fd_); + if (can_write_local(fd_)) { + loop(); + } +} + +void TransparentProxy::loop() { + sync_with_poll(fd_); + auto status = [&] { + TRY_STATUS(fd_.flush_read()); + TRY_STATUS(loop_impl()); + TRY_STATUS(fd_.flush_write()); + return Status::OK(); + }(); + if (status.is_error()) { + on_error(std::move(status)); + } + if (can_close_local(fd_)) { + on_error(Status::Error("Connection closed")); + } +} + +void TransparentProxy::timeout_expired() { + on_error(Status::Error("Connection timeout expired")); +} + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.h b/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.h new file mode 100644 index 0000000000..66a3830589 --- /dev/null +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/TransparentProxy.h @@ -0,0 +1,57 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/actor/actor.h" + +#include "td/utils/BufferedFd.h" +#include "td/utils/common.h" +#include "td/utils/logging.h" +#include "td/utils/port/IPAddress.h" +#include "td/utils/port/SocketFd.h" +#include "td/utils/Status.h" + +namespace td { + +extern int VERBOSITY_NAME(proxy); + +class TransparentProxy : public Actor { + public: + class Callback { + public: + Callback() = default; + Callback(const Callback &) = delete; + Callback &operator=(const Callback &) = delete; + virtual ~Callback() = default; + + virtual void set_result(Result<BufferedFd<SocketFd>> r_buffered_socket_fd) = 0; + virtual void on_connected() = 0; + }; + + TransparentProxy(SocketFd socket_fd, IPAddress ip_address, string username, string password, + unique_ptr<Callback> callback, ActorShared<> parent); + + protected: + BufferedFd<SocketFd> fd_; + IPAddress ip_address_; + string username_; + string password_; + unique_ptr<Callback> callback_; + ActorShared<> parent_; + + void on_error(Status status); + void tear_down() override; + void start_up() override; + void hangup() override; + + void loop() override; + void timeout_expired() override; + + virtual Status loop_impl() = 0; +}; + +} // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.cpp b/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.cpp index b30128be32..f6a6c72eb8 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.cpp +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.cpp @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -8,62 +8,90 @@ #include "td/net/HttpHeaderCreator.h" #include "td/net/HttpOutboundConnection.h" -#include "td/net/SslFd.h" +#include "td/net/SslStream.h" #include "td/utils/buffer.h" +#include "td/utils/BufferedFd.h" #include "td/utils/HttpUrl.h" #include "td/utils/logging.h" +#include "td/utils/misc.h" #include "td/utils/port/IPAddress.h" #include "td/utils/port/SocketFd.h" #include "td/utils/Slice.h" +#include "td/utils/SliceBuilder.h" #include <limits> namespace td { -Wget::Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers, int32 timeout_in, - int32 ttl, SslFd::VerifyPeer verify_peer) + +Wget::Wget(Promise<unique_ptr<HttpQuery>> promise, string url, std::vector<std::pair<string, string>> headers, + int32 timeout_in, int32 ttl, bool prefer_ipv6, SslCtx::VerifyPeer verify_peer, string content, + string content_type) : promise_(std::move(promise)) , input_url_(std::move(url)) , headers_(std::move(headers)) , timeout_in_(timeout_in) , ttl_(ttl) - , verify_peer_(verify_peer) { + , prefer_ipv6_(prefer_ipv6) + , verify_peer_(verify_peer) + , content_(std::move(content)) + , content_type_(std::move(content_type)) { } Status Wget::try_init() { - string input_url = input_url_; - TRY_RESULT(url, parse_url(MutableSlice(input_url))); - - IPAddress addr; - TRY_STATUS(addr.init_host_port(url.host_, url.port_)); + TRY_RESULT(url, parse_url(input_url_)); + TRY_RESULT_ASSIGN(url.host_, idn_to_ascii(url.host_)); - TRY_RESULT(fd, SocketFd::open(addr)); - if (url.protocol_ == HttpUrl::Protocol::HTTP) { - connection_ = - create_actor<HttpOutboundConnection>("Connect", std::move(fd), std::numeric_limits<std::size_t>::max(), 0, 0, - ActorOwn<HttpOutboundConnection::Callback>(actor_id(this))); + HttpHeaderCreator hc; + if (content_.empty()) { + hc.init_get(url.query_); } else { - TRY_RESULT(ssl_fd, SslFd::init(std::move(fd), url.host_, CSlice() /* certificate */, verify_peer_)); - connection_ = - create_actor<HttpOutboundConnection>("Connect", std::move(ssl_fd), std::numeric_limits<std::size_t>::max(), 0, - 0, ActorOwn<HttpOutboundConnection::Callback>(actor_id(this))); + hc.init_post(url.query_); + hc.set_content_size(content_.size()); + if (!content_type_.empty()) { + hc.set_content_type(content_type_); + } } - - HttpHeaderCreator hc; - hc.init_get(url.query_); bool was_host = false; + bool was_accept_encoding = false; for (auto &header : headers_) { - if (header.first == "Host") { // TODO: lowercase + auto header_lower = to_lower(header.first); + if (header_lower == "host") { was_host = true; } + if (header_lower == "accept-encoding") { + was_accept_encoding = true; + } hc.add_header(header.first, header.second); } if (!was_host) { hc.add_header("Host", url.host_); } - hc.add_header("Accept-Encoding", "gzip, deflate"); + if (!was_accept_encoding) { + hc.add_header("Accept-Encoding", "gzip, deflate"); + } + TRY_RESULT(header, hc.finish(content_)); + + IPAddress addr; + TRY_STATUS(addr.init_host_port(url.host_, url.port_, prefer_ipv6_)); - send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(hc.finish().ok())); + TRY_RESULT(fd, SocketFd::open(addr)); + if (fd.empty()) { + return Status::Error("Sockets are not supported"); + } + if (url.protocol_ == HttpUrl::Protocol::Http) { + connection_ = create_actor<HttpOutboundConnection>("Connect", BufferedFd<SocketFd>(std::move(fd)), SslStream{}, + std::numeric_limits<std::size_t>::max(), 0, 0, + ActorOwn<HttpOutboundConnection::Callback>(actor_id(this))); + } else { + TRY_RESULT(ssl_ctx, SslCtx::create(CSlice() /* certificate */, verify_peer_)); + TRY_RESULT(ssl_stream, SslStream::create(url.host_, std::move(ssl_ctx))); + connection_ = create_actor<HttpOutboundConnection>( + "Connect", BufferedFd<SocketFd>(std::move(fd)), std::move(ssl_stream), std::numeric_limits<std::size_t>::max(), + 0, 0, ActorOwn<HttpOutboundConnection::Callback>(actor_id(this))); + } + + send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(header)); send_closure(connection_, &HttpOutboundConnection::write_ok); return Status::OK(); } @@ -77,7 +105,7 @@ void Wget::loop() { } } -void Wget::handle(HttpQueryPtr result) { +void Wget::handle(unique_ptr<HttpQuery> result) { on_ok(std::move(result)); } @@ -85,11 +113,14 @@ void Wget::on_connection_error(Status error) { on_error(std::move(error)); } -void Wget::on_ok(HttpQueryPtr http_query_ptr) { +void Wget::on_ok(unique_ptr<HttpQuery> http_query_ptr) { CHECK(promise_); - if (http_query_ptr->code_ == 302 && ttl_ > 0) { + CHECK(http_query_ptr); + if ((http_query_ptr->code_ == 301 || http_query_ptr->code_ == 302 || http_query_ptr->code_ == 307 || + http_query_ptr->code_ == 308) && + ttl_ > 0) { LOG(DEBUG) << *http_query_ptr; - input_url_ = http_query_ptr->header("location").str(); + input_url_ = http_query_ptr->get_header("location").str(); LOG(DEBUG) << input_url_; ttl_--; connection_.reset(); @@ -98,7 +129,7 @@ void Wget::on_ok(HttpQueryPtr http_query_ptr) { promise_.set_value(std::move(http_query_ptr)); stop(); } else { - on_error(Status::Error(PSLICE() << "http error: " << http_query_ptr->code_)); + on_error(Status::Error(PSLICE() << "HTTP error: " << http_query_ptr->code_)); } } @@ -115,12 +146,13 @@ void Wget::start_up() { } void Wget::timeout_expired() { - on_error(Status::Error("Timeout expired")); + on_error(Status::Error("Response timeout expired")); } void Wget::tear_down() { if (promise_) { - on_error(Status::Error("Cancelled")); + on_error(Status::Error("Canceled")); } } + } // namespace td diff --git a/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.h b/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.h index cecb113c94..f3d225982b 100644 --- a/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.h +++ b/protocols/Telegram/tdlib/td/tdnet/td/net/Wget.h @@ -1,5 +1,5 @@ // -// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -8,41 +8,46 @@ #include "td/net/HttpOutboundConnection.h" #include "td/net/HttpQuery.h" -#include "td/net/SslFd.h" +#include "td/net/SslCtx.h" -#include "td/actor/PromiseFuture.h" +#include "td/actor/actor.h" #include "td/utils/common.h" +#include "td/utils/Promise.h" #include "td/utils/Status.h" #include <utility> namespace td { -class Wget : public HttpOutboundConnection::Callback { +class Wget final : public HttpOutboundConnection::Callback { public: - explicit Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers = {}, - int32 timeout_in = 10, int32 ttl = 3, SslFd::VerifyPeer verify_peer = SslFd::VerifyPeer::On); + explicit Wget(Promise<unique_ptr<HttpQuery>> promise, string url, std::vector<std::pair<string, string>> headers = {}, + int32 timeout_in = 10, int32 ttl = 3, bool prefer_ipv6 = false, + SslCtx::VerifyPeer verify_peer = SslCtx::VerifyPeer::On, string content = {}, string content_type = {}); private: Status try_init(); - void loop() override; - void handle(HttpQueryPtr result) override; - void on_connection_error(Status error) override; - void on_ok(HttpQueryPtr http_query_ptr); + void loop() final; + void handle(unique_ptr<HttpQuery> result) final; + void on_connection_error(Status error) final; + void on_ok(unique_ptr<HttpQuery> http_query_ptr); void on_error(Status error); - void tear_down() override; - void start_up() override; - void timeout_expired() override; + void tear_down() final; + void start_up() final; + void timeout_expired() final; - Promise<HttpQueryPtr> promise_; + Promise<unique_ptr<HttpQuery>> promise_; ActorOwn<HttpOutboundConnection> connection_; string input_url_; std::vector<std::pair<string, string>> headers_; int32 timeout_in_; int32 ttl_; - SslFd::VerifyPeer verify_peer_; + bool prefer_ipv6_ = false; + SslCtx::VerifyPeer verify_peer_; + string content_; + string content_type_; }; } // namespace td |