diff --git a/a8/a8.h b/a8/a8.h index 9e5d599..373673d 100644 --- a/a8/a8.h +++ b/a8/a8.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include diff --git a/a8/asiotcpclient.cc b/a8/asiotcpclient.cc index ed5e253..8610fbd 100644 --- a/a8/asiotcpclient.cc +++ b/a8/asiotcpclient.cc @@ -8,24 +8,25 @@ #include #include -#ifdef USE_ASIO +#ifdef USE_BOOST const int MAX_RECV_BUFFERSIZE = 1024 * 64; namespace a8 { - AsioTcpClient::AsioTcpClient(asio::io_context& io_context, const std::string& remote_ip, int remote_port) + AsioTcpClient::AsioTcpClient(std::shared_ptr io_context, const std::string& remote_ip, int remote_port) { + io_context_ = io_context; remote_address_ = remote_ip; remote_port_ = remote_port; endpoint_ = std::make_shared ( - asio::ip::make_address(remote_address_), + asio::ip::address::from_string(remote_address_), remote_port_ ); send_buffer_mutex_ = std::make_shared(); - socket_ = std::make_shared(io_context); + socket_ = std::make_shared(*io_context); } AsioTcpClient::~AsioTcpClient() diff --git a/a8/asiotcpclient.h b/a8/asiotcpclient.h index a5b6b11..c86ab10 100644 --- a/a8/asiotcpclient.h +++ b/a8/asiotcpclient.h @@ -1,6 +1,6 @@ #pragma once -#ifdef USE_ASIO +#ifdef USE_BOOST #include @@ -16,7 +16,9 @@ namespace a8 std::function on_connect; std::function on_disconnect; std::function on_socketread; - AsioTcpClient(asio::io_context& io_context, const std::string& remote_ip, int remote_port); + AsioTcpClient(std::shared_ptr io_context, + const std::string& remote_ip, + int remote_port); virtual ~AsioTcpClient(); const std::string& GetRemoteAddress() { return remote_address_; } int GetRemotePort() { return remote_port_; } @@ -33,6 +35,7 @@ namespace a8 void DoSend(); private: + std::shared_ptr io_context_; std::string remote_address_; int remote_port_ = 0; diff --git a/a8/awaiter.cc b/a8/awaiter.cc new file mode 100644 index 0000000..46a3013 --- /dev/null +++ b/a8/awaiter.cc @@ -0,0 +1,29 @@ +#include + +#include + +namespace a8 +{ + + void Awaiter::Await(std::shared_ptr notifyer) + { + notifyers_.push_back(notifyer); + DoAwait(); + } + + void Awaiter::DoDone() + { + done_ = true; + for (auto notifyer : notifyers_) { + if (!notifyer.expired()) { + notifyer.lock()->DoResume(); + } + } + } + + void Awaiter::SetResult(std::vector results) + { + results_ = std::make_shared(results); + } + +} diff --git a/a8/awaiter.h b/a8/awaiter.h new file mode 100644 index 0000000..47f2176 --- /dev/null +++ b/a8/awaiter.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace f8 +{ + class Coroutine; +} + +namespace a8 +{ + + class Awaiter : public std::enable_shared_from_this + { + public: + virtual ~Awaiter() {}; + + std::shared_ptr GetResult() { return results_; } + bool Done() const { return done_; } + virtual void DoResume() {}; + + protected: + bool done_ = false; + + std::list> notifyers_; + void Await(std::shared_ptr notifyer); + virtual void DoAwait() = 0; + void DoDone(); + void SetResult(std::vector results); + + private: + std::shared_ptr results_; + std::function cb_; + + friend class f8::Coroutine; + }; + +} diff --git a/a8/promise.cc b/a8/promise.cc new file mode 100644 index 0000000..d2a83e6 --- /dev/null +++ b/a8/promise.cc @@ -0,0 +1,7 @@ +#include +#include + +namespace a8 +{ + +} diff --git a/a8/promise.h b/a8/promise.h new file mode 100644 index 0000000..9c6f6b8 --- /dev/null +++ b/a8/promise.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace a8 +{ + + class Promise : public Awaiter + { + }; + +} diff --git a/a8/result.h b/a8/result.h new file mode 100644 index 0000000..950f41a --- /dev/null +++ b/a8/result.h @@ -0,0 +1,17 @@ +#pragma once + +namespace a8 +{ + class Results + { + public: + + Results(std::vector results):results_(std::move(results)) {}; + + template + T Get(size_t index) const { return std::any_cast(results_.at(index));}; + + private: + std::vector results_; + }; +} diff --git a/a8/websocketclient.cc b/a8/websocketclient.cc new file mode 100644 index 0000000..4819003 --- /dev/null +++ b/a8/websocketclient.cc @@ -0,0 +1,297 @@ +#include + +#include + +#ifdef USE_BOOST + +static const unsigned char FIN = 0x80; +static const unsigned char RSV1 = 0x40; +static const unsigned char RSV2 = 0x20; +static const unsigned char RSV3 = 0x10; +static const unsigned char RSV_MASK = RSV1 | RSV2 | RSV3; +static const unsigned char OPCODE_MASK = 0x0F; + +static const unsigned char TEXT_MODE = 0x01; +static const unsigned char BINARY_MODE = 0x02; + +static const unsigned char WEBSOCKET_OPCODE = 0x0F; +static const unsigned char WEBSOCKET_FRAME_CONTINUE = 0x0; +static const unsigned char WEBSOCKET_FRAME_TEXT = 0x1; +static const unsigned char WEBSOCKET_FRAME_BINARY = 0x2; +static const unsigned char WEBSOCKET_FRAME_CLOSE = 0x8; +static const unsigned char WEBSOCKET_FRAME_PING = 0x9; +static const unsigned char WEBSOCKET_FRAME_PONG = 0xA; + +static const unsigned char WEBSOCKET_MASK = 0x80; +static const unsigned char WEBSOCKET_PAYLOAD_LEN = 0x7F; +static const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT16 = 126; +static const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT64 = 127; + +static const char* WEB_SOCKET_KEY = "Sec-WebSocket-Key: "; +static const char* WEB_SOCKET_KEY2 = "Sec-Websocket-Key: "; + +static const int DEFAULT_MAX_PACKET_LEN = 1024 * 10; +static const int DEFAULT_MAX_RECV_BUFFERSIZE = 1024 * 64; + +namespace a8 +{ + + WebSocketClient::WebSocketClient(std::shared_ptr io_context, const std::string& remote_ip, int remote_port) + { + max_packet_len_ = DEFAULT_MAX_PACKET_LEN; + recv_buff_ = (char *)malloc(max_packet_len_ + 1); + recv_bufflen_ = 0; + + tcp_client_ = std::make_shared(io_context, remote_ip, remote_port); + decoded_buff_ = (char *)malloc(1024 * 64 + 1); + decoded_bufflen_ = 0; + tcp_client_->on_error = + [this] (a8::AsioTcpClient* socket, int err) + { + if (on_error) { + on_error(this, err); + } + }; + tcp_client_->on_connect = + [this] (a8::AsioTcpClient* socket) + { + std::string data = a8::Format("GET ws://%s:%d/\r\n", + {socket->GetRemoteAddress(), + socket->GetRemotePort()}); + data += "Upgrade: websocket\r\n"; + data += "Connection: Upgrade\r\n"; + data += "Sec-WebSocket-Version: 13\r\n"; + data += "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==*\r\n"; + data += "\r\n"; + socket->SendBuff(data.data(), data.size()); + }; + tcp_client_->on_disconnect = + [this] (a8::AsioTcpClient* socket) + { + if (on_disconnect) { + on_disconnect(this); + } + }; + tcp_client_->on_socketread = + [this] (a8::AsioTcpClient* socket, char* buf, unsigned int buflen) + { + unsigned int already_read_bytes = 0; + do { + if (already_read_bytes < buflen) { + int read_bytes = std::min(buflen - already_read_bytes, + (unsigned int)max_packet_len_ - recv_bufflen_); + if (read_bytes > 0) { + memmove(&recv_buff_[recv_bufflen_], buf + already_read_bytes, read_bytes); + recv_bufflen_ += read_bytes; + already_read_bytes += read_bytes; + } + } + + int offset = 0; + int prev_offset = 0; + do { + prev_offset = offset; + DecodePacket(recv_buff_, offset, recv_bufflen_); + } while (prev_offset < offset && offset < recv_bufflen_); + + if (offset > 0 && offset < recv_bufflen_){ + memmove(recv_buff_, recv_buff_ + offset, recv_bufflen_ - offset); + } + recv_bufflen_ -= offset; + if (recv_bufflen_ >= max_packet_len_) { + //收到超长包 + Close(); + return; + } + } while (already_read_bytes < buflen); + }; + } + + WebSocketClient::~WebSocketClient() + { + recv_bufflen_ = 0; + free(recv_buff_); + recv_buff_ = nullptr; + + free(decoded_buff_); + decoded_buff_ = nullptr; + decoded_bufflen_ = 0; + } + + void WebSocketClient::Open() + { + tcp_client_->Open(); + } + + void WebSocketClient::Close() + { + tcp_client_->Close(); + } + + bool WebSocketClient::IsActive() + { + return tcp_client_->IsActive(); + } + + bool WebSocketClient::Connected() + { + return tcp_client_->Connected(); + } + + void WebSocketClient::SendBuff(const char* buff, unsigned int bufflen) + { + if (!handshook_) { + abort(); + } + unsigned char szbuff [1024 * 65]; + szbuff[0] = FIN | BINARY_MODE | 0; + int mask_offset = 2; + int payloadlen = bufflen; + if (payloadlen < 126) { + szbuff[1] = payloadlen | 0x80; + mask_offset = 2; + } else if (payloadlen <= 0xFFFF) { + szbuff[1] = 126 | 0x80; + szbuff[2] = (payloadlen >> 8) & 0xFF; + szbuff[3] = payloadlen & 0xFF; + mask_offset = 4; + } else { + abort(); + } + *((int*)(szbuff + mask_offset)) = rand(); + for (unsigned i = 0; i < bufflen; ++i) { + szbuff[mask_offset + 4 + i] = + ((unsigned char)buff[i]) ^ szbuff[mask_offset + (i % 4)] ; + } + tcp_client_->SendBuff((char*)szbuff, bufflen + mask_offset + 4); + } + + void WebSocketClient::ProcessHandShake(char* buf, int& offset, unsigned int buflen) + { + char* pend = strstr(buf + offset, "\r\n\r\n"); + if (!pend) { + return; + } + ProcessWsHandShake(buf, offset, buflen); + } + + void WebSocketClient::ProcessWsHandShake(char* buf, int& offset, unsigned int buflen) + { + char* pend = strstr(buf + offset, "\r\n\r\n"); + if (!pend) { + return; + } + handshook_ = true; + offset += pend - buf - offset + strlen("\r\n\r\n"); + if (on_connect) { + on_connect(this); + } + } + + void WebSocketClient::ProcessUserPacket() + { + int offset = 0; + int prev_offset = 0; + do { + prev_offset = offset; + on_decode_userpacket(decoded_buff_, offset, decoded_bufflen_); + } while (prev_offset < offset && offset < decoded_bufflen_); + + if (offset > 0 && offset < decoded_bufflen_){ + memmove(decoded_buff_, decoded_buff_ + offset, decoded_bufflen_ - offset); + } + decoded_bufflen_ -= offset; + if (decoded_bufflen_ >= max_packet_len_) { + //收到超长包 + Close(); + return; + } + } + + void WebSocketClient::DecodeFrame(char* buf, int& offset, unsigned int buflen) + { + if (offset + 2 > (int)buflen) { + return; + } + char* real_buf = buf + offset; + unsigned int ava_len = buflen - offset; + unsigned char header = real_buf[0]; + unsigned char mask_payloadlen = real_buf[1]; + + bool is_final_frame = (header & FIN) == FIN; + #if 0 + bool reserved_bits = (header & FIN) == RSV_MASK; + #endif + unsigned char opcode = header & OPCODE_MASK; + #if 0 + bool opcode_is_control = opcode & 0x8; + #endif + + if (opcode == WEBSOCKET_FRAME_CLOSE) { + Close(); + return; + } + if (opcode != BINARY_MODE) { + if (opcode != WEBSOCKET_FRAME_PING) { + Close(); + return; + } + } + if (!is_final_frame) { + Close(); + return; + } + + bool is_masked = (mask_payloadlen & 0x80) == 0x80; + + unsigned char payloadlen = mask_payloadlen & 0x7F; + unsigned int framelen = 0; + int mask_offset = 0; + + if (payloadlen < 126) { + framelen = payloadlen; + mask_offset = 2; + } else if (payloadlen == 126 && ava_len >= 4) { + framelen = ntohs( *(u_short*) (real_buf + 2) ); + mask_offset = 4; + } else if (payloadlen == 127 && ava_len >= 8) { + //int32 or int64? + framelen = ntohl( *(u_long*) (real_buf + 2) ); + mask_offset = 8; + } else { + return; + } + unsigned int real_pkg_len = mask_offset + framelen + (is_masked ? 4 : 0); + if (ava_len < real_pkg_len) { + return; + } + + if (is_masked) { + unsigned char *frame_mask = (unsigned char*)(real_buf + mask_offset); + memcpy(&decoded_buff_[decoded_bufflen_], real_buf + mask_offset + 4, framelen); + for (unsigned int i = 0; i < framelen; i++) { + decoded_buff_[decoded_bufflen_ + i] = + (decoded_buff_[ decoded_bufflen_ + i] ^ frame_mask[i%4]); + } + } else { + memcpy(&decoded_buff_[decoded_bufflen_], real_buf + mask_offset, framelen); + } + decoded_bufflen_ += framelen; + + ProcessUserPacket(); + offset += real_pkg_len; + } + + void WebSocketClient::DecodePacket(char* buf, int& offset, unsigned int buflen) + { + if (!handshook_) { + buf[buflen] = '\0'; + ProcessHandShake(buf, offset, buflen); + } else { + DecodeFrame(buf, offset, buflen); + } + } + +} + +#endif diff --git a/a8/websocketclient.h b/a8/websocketclient.h new file mode 100644 index 0000000..6f7d007 --- /dev/null +++ b/a8/websocketclient.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +#ifdef USE_BOOST + +#include + +using asio::ip::tcp; + +namespace a8 +{ + + class WebSocketClient + { + public: + WebSocketClient(std::shared_ptr io_context, const std::string& remote_ip, int remote_port); + virtual ~WebSocketClient(); + + std::function on_error; + std::function on_connect; + std::function on_disconnect; + std::function on_decode_userpacket; + + void Open(); + void Close(); + bool IsActive(); + bool Connected(); + void SendBuff(const char* buff, unsigned int bufflen); + + private: + void DecodePacket(char* buf, int& offset, unsigned int buflen); + + void ProcessHandShake(char* buf, int& offset, unsigned int buflen); + void ProcessWsHandShake(char* buf, int& offset, unsigned int buflen); + void ProcessUserPacket(); + void DecodeFrame(char* buf, int& offset, unsigned int buflen); + + private: + std::shared_ptr tcp_client_; + char *decoded_buff_ = nullptr; + int decoded_bufflen_ = 0; + bool handshook_ = false; + char *recv_buff_ = nullptr; + int recv_bufflen_ = 0; + int max_packet_len_ = 0; + }; + +} + +#endif diff --git a/a8/websocketsession.cc b/a8/websocketsession.cc index 44303c7..385356b 100644 --- a/a8/websocketsession.cc +++ b/a8/websocketsession.cc @@ -9,31 +9,31 @@ #include #include -const unsigned char FIN = 0x80; -const unsigned char RSV1 = 0x40; -const unsigned char RSV2 = 0x20; -const unsigned char RSV3 = 0x10; -const unsigned char RSV_MASK = RSV1 | RSV2 | RSV3; -const unsigned char OPCODE_MASK = 0x0F; +static const unsigned char FIN = 0x80; +static const unsigned char RSV1 = 0x40; +static const unsigned char RSV2 = 0x20; +static const unsigned char RSV3 = 0x10; +static const unsigned char RSV_MASK = RSV1 | RSV2 | RSV3; +static const unsigned char OPCODE_MASK = 0x0F; -const unsigned char TEXT_MODE = 0x01; -const unsigned char BINARY_MODE = 0x02; +static const unsigned char TEXT_MODE = 0x01; +static const unsigned char BINARY_MODE = 0x02; -const unsigned char WEBSOCKET_OPCODE = 0x0F; -const unsigned char WEBSOCKET_FRAME_CONTINUE = 0x0; -const unsigned char WEBSOCKET_FRAME_TEXT = 0x1; -const unsigned char WEBSOCKET_FRAME_BINARY = 0x2; -const unsigned char WEBSOCKET_FRAME_CLOSE = 0x8; -const unsigned char WEBSOCKET_FRAME_PING = 0x9; -const unsigned char WEBSOCKET_FRAME_PONG = 0xA; +static const unsigned char WEBSOCKET_OPCODE = 0x0F; +static const unsigned char WEBSOCKET_FRAME_CONTINUE = 0x0; +static const unsigned char WEBSOCKET_FRAME_TEXT = 0x1; +static const unsigned char WEBSOCKET_FRAME_BINARY = 0x2; +static const unsigned char WEBSOCKET_FRAME_CLOSE = 0x8; +static const unsigned char WEBSOCKET_FRAME_PING = 0x9; +static const unsigned char WEBSOCKET_FRAME_PONG = 0xA; -const unsigned char WEBSOCKET_MASK = 0x80; -const unsigned char WEBSOCKET_PAYLOAD_LEN = 0x7F; -const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT16 = 126; -const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT64 = 127; +static const unsigned char WEBSOCKET_MASK = 0x80; +static const unsigned char WEBSOCKET_PAYLOAD_LEN = 0x7F; +static const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT16 = 126; +static const unsigned char WEBSOCKET_PAYLOAD_LEN_UINT64 = 127; -const char* WEB_SOCKET_KEY = "Sec-WebSocket-Key: "; -const char* WEB_SOCKET_KEY2 = "Sec-Websocket-Key: "; +static const char* WEB_SOCKET_KEY = "Sec-WebSocket-Key: "; +static const char* WEB_SOCKET_KEY2 = "Sec-Websocket-Key: "; namespace a8 { @@ -188,6 +188,11 @@ namespace a8 } } + /* + finbit|opcode|falgs + unsigned char header FIN|BINARY_MODE + unsigned char mask_payloadlen + */ void WebSocketSession::DecodeFrame(char* buf, int& offset, unsigned int buflen) { if (offset + 2 > (int)buflen) {