Crypto++  5.6.3
Free C++ class library of cryptographic schemes
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "config.h"
5 
6 #if !defined(NO_OS_DEPENDENCE) && defined(SOCKETS_AVAILABLE)
7 
8 // TODO: http://github.com/weidai11/cryptopp/issues/19
9 #define _WINSOCK_DEPRECATED_NO_WARNINGS
10 #include "socketft.h"
11 #include "wait.h"
12 
13 // Windows 8, Windows Server 2012, and Windows Phone 8.1 need <synchapi.h> and <ioapiset.h>
14 #if defined(CRYPTOPP_WIN32_AVAILABLE)
15 # if ((WINVER >= 0x0602 /*_WIN32_WINNT_WIN8*/) || (_WIN32_WINNT >= 0x0602 /*_WIN32_WINNT_WIN8*/))
16 # include <synchapi.h>
17 # include <ioapiset.h>
18 # define USE_WINDOWS8_API
19 # endif
20 #endif
21 
22 #ifdef USE_BERKELEY_STYLE_SOCKETS
23 #include <errno.h>
24 #include <netdb.h>
25 #include <unistd.h>
26 #include <arpa/inet.h>
27 #include <netinet/in.h>
28 #include <sys/ioctl.h>
29 #endif
30 
31 #if defined(CRYPTOPP_MSAN)
32 # include <sanitizer/msan_interface.h>
33 #endif
34 
35 #ifdef PREFER_WINDOWS_STYLE_SOCKETS
36 # pragma comment(lib, "ws2_32.lib")
37 #endif
38 
39 NAMESPACE_BEGIN(CryptoPP)
40 
41 #ifdef USE_WINDOWS_STYLE_SOCKETS
42 const int SOCKET_EINVAL = WSAEINVAL;
43 const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
44 typedef int socklen_t;
45 #else
46 const int SOCKET_EINVAL = EINVAL;
47 const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
48 #endif
49 
50 // Solaris doesn't have INADDR_NONE
51 #ifndef INADDR_NONE
52 # define INADDR_NONE 0xffffffff
53 #endif /* INADDR_NONE */
54 
55 Socket::Err::Err(socket_t s, const std::string& operation, int error)
56  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
57  , m_s(s)
58 {
59 }
60 
61 Socket::~Socket()
62 {
63  if (m_own)
64  {
65  try
66  {
67  CloseSocket();
68  }
69  catch (const Exception&)
70  {
71  assert(0);
72  }
73  }
74 }
75 
76 void Socket::AttachSocket(socket_t s, bool own)
77 {
78  if (m_own)
79  CloseSocket();
80 
81  m_s = s;
82  m_own = own;
83  SocketChanged();
84 }
85 
86 socket_t Socket::DetachSocket()
87 {
88  socket_t s = m_s;
89  m_s = INVALID_SOCKET;
90  SocketChanged();
91  return s;
92 }
93 
94 void Socket::Create(int nType)
95 {
96  assert(m_s == INVALID_SOCKET);
97  m_s = socket(AF_INET, nType, 0);
98  CheckAndHandleError("socket", m_s);
99  m_own = true;
100  SocketChanged();
101 }
102 
103 void Socket::CloseSocket()
104 {
105  if (m_s != INVALID_SOCKET)
106  {
107 #ifdef USE_WINDOWS_STYLE_SOCKETS
108 # if defined(USE_WINDOWS8_API)
109  BOOL result = CancelIoEx((HANDLE) m_s, NULL);
110  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
111  CheckAndHandleError_int("closesocket", closesocket(m_s));
112  CRYPTOPP_UNUSED(result); // Used by assert in debug builds
113 # else
114  BOOL result = CancelIo((HANDLE) m_s);
115  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
116  CheckAndHandleError_int("closesocket", closesocket(m_s));
117  CRYPTOPP_UNUSED(result);
118 # endif
119 #else
120  CheckAndHandleError_int("close", close(m_s));
121 #endif
122  m_s = INVALID_SOCKET;
123  SocketChanged();
124  }
125 }
126 
127 void Socket::Bind(unsigned int port, const char *addr)
128 {
129  sockaddr_in sa;
130  memset(&sa, 0, sizeof(sa));
131  sa.sin_family = AF_INET;
132 
133  if (addr == NULL)
134  sa.sin_addr.s_addr = htonl(INADDR_ANY);
135  else
136  {
137  unsigned long result = inet_addr(addr);
138  if (result == INADDR_NONE)
139  {
140  SetLastError(SOCKET_EINVAL);
141  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
142  }
143  sa.sin_addr.s_addr = result;
144  }
145 
146  sa.sin_port = htons((unsigned short)port);
147 
148  Bind((sockaddr *)&sa, sizeof(sa));
149 }
150 
151 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
152 {
153  assert(m_s != INVALID_SOCKET);
154  // cygwin workaround: needs const_cast
155  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
156 }
157 
158 void Socket::Listen(int backlog)
159 {
160  assert(m_s != INVALID_SOCKET);
161  CheckAndHandleError_int("listen", listen(m_s, backlog));
162 }
163 
164 bool Socket::Connect(const char *addr, unsigned int port)
165 {
166  assert(addr != NULL);
167 
168  sockaddr_in sa;
169  memset(&sa, 0, sizeof(sa));
170  sa.sin_family = AF_INET;
171  sa.sin_addr.s_addr = inet_addr(addr);
172 
173  if (sa.sin_addr.s_addr == INADDR_NONE)
174  {
175  hostent *lphost = gethostbyname(addr);
176  if (lphost == NULL)
177  {
178  SetLastError(SOCKET_EINVAL);
179  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
180  }
181  else
182  {
183  assert(IsAlignedOn(lphost->h_addr,GetAlignmentOf<in_addr>()));
184  sa.sin_addr.s_addr = ((in_addr *)(void *)lphost->h_addr)->s_addr;
185  }
186  }
187 
188  sa.sin_port = htons((unsigned short)port);
189 
190  return Connect((const sockaddr *)&sa, sizeof(sa));
191 }
192 
193 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
194 {
195  assert(m_s != INVALID_SOCKET);
196  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
197  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
198  return false;
199  CheckAndHandleError_int("connect", result);
200  return true;
201 }
202 
203 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
204 {
205  assert(m_s != INVALID_SOCKET);
206  socket_t s = accept(m_s, psa, psaLen);
207  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
208  return false;
209  CheckAndHandleError("accept", s);
210  target.AttachSocket(s, true);
211  return true;
212 }
213 
214 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
215 {
216  assert(m_s != INVALID_SOCKET);
217  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
218 }
219 
220 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
221 {
222  assert(m_s != INVALID_SOCKET);
223  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
224 }
225 
226 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
227 {
228  assert(m_s != INVALID_SOCKET);
229  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
230  CheckAndHandleError_int("send", result);
231  return result;
232 }
233 
234 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
235 {
236  assert(m_s != INVALID_SOCKET);
237  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
238  CheckAndHandleError_int("recv", result);
239  return result;
240 }
241 
242 void Socket::ShutDown(int how)
243 {
244  assert(m_s != INVALID_SOCKET);
245  int result = shutdown(m_s, how);
246  CheckAndHandleError_int("shutdown", result);
247 }
248 
249 void Socket::IOCtl(long cmd, unsigned long *argp)
250 {
251  assert(m_s != INVALID_SOCKET);
252 #ifdef USE_WINDOWS_STYLE_SOCKETS
253  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
254 #else
255  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
256 #endif
257 }
258 
259 bool Socket::SendReady(const timeval *timeout)
260 {
261  fd_set fds;
262  FD_ZERO(&fds);
263  FD_SET(m_s, &fds);
264 #ifdef CRYPTOPP_MSAN
265  __msan_unpoison(&fds, sizeof(fds));
266 #endif
267 
268  int ready;
269  if (timeout == NULL)
270  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
271  else
272  {
273  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
274  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
275  }
276  CheckAndHandleError_int("select", ready);
277  return ready > 0;
278 }
279 
280 bool Socket::ReceiveReady(const timeval *timeout)
281 {
282  fd_set fds;
283  FD_ZERO(&fds);
284  FD_SET(m_s, &fds);
285 #ifdef CRYPTOPP_MSAN
286  __msan_unpoison(&fds, sizeof(fds));
287 #endif
288 
289  int ready;
290  if (timeout == NULL)
291  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
292  else
293  {
294  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
295  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
296  }
297  CheckAndHandleError_int("select", ready);
298  return ready > 0;
299 }
300 
301 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
302 {
303  int port = atoi(name);
304  if (IntToString(port) == name)
305  return port;
306 
307  servent *se = getservbyname(name, protocol);
308  if (!se)
309  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
310  return ntohs(se->s_port);
311 }
312 
314 {
315 #ifdef USE_WINDOWS_STYLE_SOCKETS
316  WSADATA wsd;
317  int result = WSAStartup(0x0202, &wsd);
318  if (result != 0)
319  throw Err(INVALID_SOCKET, "WSAStartup", result);
320 #endif
321 }
322 
324 {
325 #ifdef USE_WINDOWS_STYLE_SOCKETS
326  int result = WSACleanup();
327  if (result != 0)
328  throw Err(INVALID_SOCKET, "WSACleanup", result);
329 #endif
330 }
331 
333 {
334 #ifdef USE_WINDOWS_STYLE_SOCKETS
335  return WSAGetLastError();
336 #else
337  return errno;
338 #endif
339 }
340 
341 void Socket::SetLastError(int errorCode)
342 {
343 #ifdef USE_WINDOWS_STYLE_SOCKETS
344  WSASetLastError(errorCode);
345 #else
346  errno = errorCode;
347 #endif
348 }
349 
350 void Socket::HandleError(const char *operation) const
351 {
352  int err = GetLastError();
353  throw Err(m_s, operation, err);
354 }
355 
356 #ifdef USE_WINDOWS_STYLE_SOCKETS
357 
358 SocketReceiver::SocketReceiver(Socket &s)
359  : m_s(s), m_eofReceived(false), m_resultPending(false)
360 {
361  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
362  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
363  memset(&m_overlapped, 0, sizeof(m_overlapped));
364  m_overlapped.hEvent = m_event;
365 }
366 
367 SocketReceiver::~SocketReceiver()
368 {
369 #ifdef USE_WINDOWS_STYLE_SOCKETS
370 # if defined(USE_WINDOWS8_API)
371  BOOL result = CancelIoEx((HANDLE) m_s.GetSocket(), NULL);
372  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
373  CRYPTOPP_UNUSED(result); // Used by assert in debug builds
374 # else
375  BOOL result = CancelIo((HANDLE) m_s.GetSocket());
376  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
377  CRYPTOPP_UNUSED(result);
378 # endif
379 #endif
380 }
381 
382 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
383 {
384  assert(!m_resultPending && !m_eofReceived);
385 
386  DWORD flags = 0;
387  // don't queue too much at once, or we might use up non-paged memory
388  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
389  if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
390  {
391  if (m_lastResult == 0)
392  m_eofReceived = true;
393  }
394  else
395  {
396  switch (WSAGetLastError())
397  {
398  default:
399  m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
400  case WSAEDISCON:
401  m_lastResult = 0;
402  m_eofReceived = true;
403  break;
404  case WSA_IO_PENDING:
405  m_resultPending = true;
406  }
407  }
408  return !m_resultPending;
409 }
410 
412 {
413  if (m_resultPending)
414  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
415  else if (!m_eofReceived)
416  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
417 }
418 
419 unsigned int SocketReceiver::GetReceiveResult()
420 {
421  if (m_resultPending)
422  {
423  DWORD flags = 0;
424  if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
425  {
426  if (m_lastResult == 0)
427  m_eofReceived = true;
428  }
429  else
430  {
431  switch (WSAGetLastError())
432  {
433  default:
434  m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
435  case WSAEDISCON:
436  m_lastResult = 0;
437  m_eofReceived = true;
438  }
439  }
440  m_resultPending = false;
441  }
442  return m_lastResult;
443 }
444 
445 // *************************************************************
446 
447 SocketSender::SocketSender(Socket &s)
448  : m_s(s), m_resultPending(false), m_lastResult(0)
449 {
450  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
451  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
452  memset(&m_overlapped, 0, sizeof(m_overlapped));
453  m_overlapped.hEvent = m_event;
454 }
455 
456 
457 SocketSender::~SocketSender()
458 {
459 #ifdef USE_WINDOWS_STYLE_SOCKETS
460 # if defined(USE_WINDOWS8_API)
461  BOOL result = CancelIoEx((HANDLE) m_s.GetSocket(), NULL);
462  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
463  CRYPTOPP_UNUSED(result); // Used by assert in debug builds
464 # else
465  BOOL result = CancelIo((HANDLE) m_s.GetSocket());
466  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
467  CRYPTOPP_UNUSED(result);
468 # endif
469 #endif
470 }
471 
472 void SocketSender::Send(const byte* buf, size_t bufLen)
473 {
474  assert(!m_resultPending);
475  DWORD written = 0;
476  // don't queue too much at once, or we might use up non-paged memory
477  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
478  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
479  {
480  m_resultPending = false;
481  m_lastResult = written;
482  }
483  else
484  {
485  if (WSAGetLastError() != WSA_IO_PENDING)
486  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
487 
488  m_resultPending = true;
489  }
490 }
491 
492 void SocketSender::SendEof()
493 {
494  assert(!m_resultPending);
495  m_s.ShutDown(SD_SEND);
496  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
497  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
498  m_resultPending = true;
499 }
500 
501 bool SocketSender::EofSent()
502 {
503  if (m_resultPending)
504  {
505  WSANETWORKEVENTS events;
506  m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
507  if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
508  throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
509  if (events.iErrorCode[FD_CLOSE_BIT] != 0)
510  throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
511  m_resultPending = false;
512  }
513  return m_lastResult != 0;
514 }
515 
517 {
518  if (m_resultPending)
519  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
520  else
521  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
522 }
523 
524 unsigned int SocketSender::GetSendResult()
525 {
526  if (m_resultPending)
527  {
528  DWORD flags = 0;
529  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
530  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
531  m_resultPending = false;
532  }
533  return m_lastResult;
534 }
535 
536 #endif
537 
538 #ifdef USE_BERKELEY_STYLE_SOCKETS
539 
540 SocketReceiver::SocketReceiver(Socket &s)
541  : m_s(s), m_eofReceived(false), m_lastResult(0)
542 {
543 }
544 
545 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
546 {
547  if (!m_eofReceived)
548  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
549 }
550 
551 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
552 {
553  m_lastResult = m_s.Receive(buf, bufLen);
554  if (bufLen > 0 && m_lastResult == 0)
555  m_eofReceived = true;
556  return true;
557 }
558 
559 unsigned int SocketReceiver::GetReceiveResult()
560 {
561  return m_lastResult;
562 }
563 
564 SocketSender::SocketSender(Socket &s)
565  : m_s(s), m_lastResult(0)
566 {
567 }
568 
569 void SocketSender::Send(const byte* buf, size_t bufLen)
570 {
571  m_lastResult = m_s.Send(buf, bufLen);
572 }
573 
574 void SocketSender::SendEof()
575 {
576  m_s.ShutDown(SD_SEND);
577 }
578 
579 unsigned int SocketSender::GetSendResult()
580 {
581  return m_lastResult;
582 }
583 
584 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
585 {
586  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
587 }
588 
589 #endif // USE_BERKELEY_STYLE_SOCKETS
590 
591 NAMESPACE_END
592 
593 #endif // SOCKETS_AVAILABLE
Base class for all exceptions thrown by the library.
Definition: cryptlib.h:139
static unsigned int PortNameToNumber(const char *name, const char *protocol="tcp")
look up the port number given its name, returns 0 if not found
Definition: socketft.cpp:301
container of wait objects
Definition: wait.h:169
The operating system reported an error.
Definition: cryptlib.h:217
static void ShutdownSockets()
calls WSACleanup for Windows Sockets
Definition: socketft.cpp:323
Library configuration file.
exception thrown by Socket class
Definition: socketft.h:48
bool IsAlignedOn(const void *ptr, unsigned int alignment)
Determines whether ptr is aligned to a minimum value.
Definition: misc.h:907
static void StartSockets()
start Windows Sockets 2
Definition: socketft.cpp:313
const T1 UnsignedMin(const T1 &a, const T2 &b)
Safe comparison of values that could be neagtive and incorrectly promoted.
Definition: misc.h:503
static void SetLastError(int errorCode)
sets errno or calls WSASetLastError
Definition: socketft.cpp:341
wrapper for Windows or Berkeley Sockets
Definition: socketft.h:44
bool Receive(byte *buf, size_t bufLen)
receive data from network source, returns whether result is immediately available ...
Definition: socketft.cpp:382
std::string IntToString(T value, unsigned int base=10)
Converts a value to a string.
Definition: misc.h:530
Crypto++ library namespace.
void GetWaitObjects(WaitObjectContainer &container, CallStack const &callStack)
Retrieves waitable objects.
Definition: socketft.cpp:411
static int GetLastError()
returns errno or WSAGetLastError
Definition: socketft.cpp:332
void GetWaitObjects(WaitObjectContainer &container, CallStack const &callStack)
Retrieves waitable objects.
Definition: socketft.cpp:516