/* * libjingle * Copyright 2004--2005, Google Inc. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * 3. The name of the author may not be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include #include #include "talk/base/natsocketfactory.h" namespace talk_base { class NATSocket : public AsyncSocket { public: NATSocket(Socket* socket, const SocketAddress& server_addr) : async_(false), connected_(false), server_addr_(server_addr), socket_(socket), buf_(0), size_(0) { } NATSocket(AsyncSocket* socket, const SocketAddress& server_addr) : async_(true), connected_(false), server_addr_(server_addr), socket_(socket), buf_(0), size_(0) { socket->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); socket->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); } virtual ~NATSocket() { delete socket_; delete buf_; } SocketAddress GetLocalAddress() const { return socket_->GetLocalAddress(); } SocketAddress GetRemoteAddress() const { return remote_addr_; // will be ANY if not connected } int Bind(const SocketAddress& addr) { return socket_->Bind(addr); } int Connect(const SocketAddress& addr) { connected_ = true; remote_addr_ = addr; return 0; } int Send(const void *pv, size_t cb) { assert(connected_); return SendInternal(pv, cb, remote_addr_); } int SendTo(const void *pv, size_t cb, const SocketAddress& addr) { assert(!connected_); return SendInternal(pv, cb, addr); } int SendInternal(const void *pv, size_t cb, const SocketAddress& addr) { size_t size = cb + addr.Size_(); char* buf = new char[size]; Encode(static_cast(pv), cb, buf, size, addr); int result = socket_->SendTo(buf, size, server_addr_); delete buf; if (result < 0) { return result; } else { assert(result == static_cast(size)); // TODO: This isn't fair. return (int)((size_t)result - addr.Size_()); } } int Recv(void *pv, size_t cb) { SocketAddress addr; return RecvFrom(pv, cb, &addr); } int RecvFrom(void *pv, size_t cb, SocketAddress *paddr) { // Make sure we have enough room to read the requested amount plus the // header address. SocketAddress remote_addr; Grow(cb + remote_addr.Size_()); // Read the packet from the socket. int result = socket_->RecvFrom(buf_, size_, &remote_addr); if (result < 0) return result; assert(remote_addr == server_addr_); // TODO: we need better framing so that we know how many bytes we can // return before we need to read the next address. For UDP, this will be // fine as long as the reader always reads everything in the packet. assert((size_t)result < size_); // Decode the wire packet into the actual results. SocketAddress real_remote_addr; size_t real_size = cb; Decode(buf_, result, pv, &real_size, &real_remote_addr); // Make sure this packet should be delivered before returning it. if (!connected_ || (real_remote_addr == remote_addr_)) { if (paddr) *paddr = real_remote_addr; return (int)real_size; } else { std::cerr << "Dropping packet from unknown remote address: " << real_remote_addr.ToString() << std::endl; return 0; // Tell the caller we didn't read anything } } int Close() { connected_ = false; remote_addr_ = SocketAddress(); return socket_->Close(); } int Listen(int backlog) { assert(false); // not yet implemented return 0; } Socket* Accept(SocketAddress *paddr) { assert(false); // not yet implemented return 0; } AsyncSocket* asyncsocket() { assert(async_); return static_cast(socket_); } int GetError() const { return socket_->GetError(); } void SetError(int error) { socket_->SetError(error); } ConnState GetState() const { return connected_ ? CS_CONNECTED : CS_CLOSED; } virtual int EstimateMTU(uint16* mtu) { return socket_->EstimateMTU(mtu); } virtual int SetOption(Option opt, int value) { return socket_->SetOption(opt, value); } void OnReadEvent(AsyncSocket* socket) { assert(socket == socket_); SignalReadEvent(this); } void OnWriteEvent(AsyncSocket* socket) { assert(socket == socket_); SignalWriteEvent(this); } private: // Makes sure the buffer is at least the given size. void Grow(size_t new_size) { if (size_ < new_size) { delete buf_; size_ = new_size; buf_ = new char[size_]; } } // Encodes the given data and intended remote address into a packet to send // to the NAT server. void Encode(const char* data, size_t data_size, char* buf, size_t buf_size, const SocketAddress& remote_addr) { assert(buf_size == data_size + remote_addr.Size_()); remote_addr.Write_(buf, (int)buf_size); std::memcpy(buf + remote_addr.Size_(), data, data_size); } // Decodes the given packet from the NAT server into the actual remote // address and data. void Decode(const char* data, size_t data_size, void* buf, size_t* buf_size, SocketAddress* remote_addr) { assert(data_size >= remote_addr->Size_()); assert(data_size <= *buf_size + remote_addr->Size_()); remote_addr->Read_(data, (int)data_size); *buf_size = data_size - remote_addr->Size_(); std::memcpy(buf, data + remote_addr->Size_(), *buf_size); } bool async_; bool connected_; SocketAddress remote_addr_; SocketAddress server_addr_; // address of the NAT server Socket* socket_; char* buf_; size_t size_; }; NATSocketFactory::NATSocketFactory( SocketFactory* factory, const SocketAddress& nat_addr) : factory_(factory), nat_addr_(nat_addr) { } Socket* NATSocketFactory::CreateSocket(int type) { assert(type == SOCK_DGRAM); // TCP is not yet suported return new NATSocket(factory_->CreateSocket(type), nat_addr_); } AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { assert(type == SOCK_DGRAM); // TCP is not yet suported return new NATSocket(factory_->CreateAsyncSocket(type), nat_addr_); } } // namespace talk_base