Crypto++  5.6.3
Free C++ class library of cryptographic schemes
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "config.h"
5 
6 #if !defined(NO_OS_DEPENDENCE) && defined(SOCKETS_AVAILABLE)
7 
8 // TODO: http://github.com/weidai11/cryptopp/issues/19
9 #define _WINSOCK_DEPRECATED_NO_WARNINGS
10 #include "socketft.h"
11 #include "wait.h"
12 
13 // Windows 8, Windows Server 2012, and Windows Phone 8.1 need <synchapi.h> and <ioapiset.h>
14 #if defined(CRYPTOPP_WIN32_AVAILABLE)
15 # if ((WINVER >= 0x0602 /*_WIN32_WINNT_WIN8*/) || (_WIN32_WINNT >= 0x0602 /*_WIN32_WINNT_WIN8*/))
16 # include <synchapi.h>
17 # include <ioapiset.h>
18 # define USE_WINDOWS8_API
19 # endif
20 #endif
21 
22 #ifdef USE_BERKELEY_STYLE_SOCKETS
23 #include <errno.h>
24 #include <netdb.h>
25 #include <unistd.h>
26 #include <arpa/inet.h>
27 #include <netinet/in.h>
28 #include <sys/ioctl.h>
29 #endif
30 
31 #if defined(CRYPTOPP_MSAN)
32 # include <sanitizer/msan_interface.h>
33 #endif
34 
35 #ifdef PREFER_WINDOWS_STYLE_SOCKETS
36 # pragma comment(lib, "ws2_32.lib")
37 #endif
38 
39 NAMESPACE_BEGIN(CryptoPP)
40 
// Portable names for the socket error codes this file tests against and sets.
#if defined(USE_WINDOWS_STYLE_SOCKETS)
static const int SOCKET_EINVAL = WSAEINVAL;
static const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
// Winsock builds use int for socket address lengths.
typedef int socklen_t;
#else
static const int SOCKET_EINVAL = EINVAL;
static const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#endif

// INADDR_NONE is what inet_addr() returns for an unparsable address; some
// platforms (e.g. Solaris) do not define it.
#if !defined(INADDR_NONE)
# define INADDR_NONE 0xffffffff
#endif
54 
55 Socket::Err::Err(socket_t s, const std::string& operation, int error)
56  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
57  , m_s(s)
58 {
59 }
60 
61 Socket::~Socket()
62 {
63  if (m_own)
64  {
65  try
66  {
67  CloseSocket();
68  }
69  catch (const Exception&)
70  {
71  assert(0);
72  }
73  }
74 }
75 
76 void Socket::AttachSocket(socket_t s, bool own)
77 {
78  if (m_own)
79  CloseSocket();
80 
81  m_s = s;
82  m_own = own;
83  SocketChanged();
84 }
85 
86 socket_t Socket::DetachSocket()
87 {
88  socket_t s = m_s;
89  m_s = INVALID_SOCKET;
90  SocketChanged();
91  return s;
92 }
93 
94 void Socket::Create(int nType)
95 {
96  assert(m_s == INVALID_SOCKET);
97  m_s = socket(AF_INET, nType, 0);
98  CheckAndHandleError("socket", m_s);
99  m_own = true;
100  SocketChanged();
101 }
102 
103 void Socket::CloseSocket()
104 {
105  if (m_s != INVALID_SOCKET)
106  {
107 #ifdef USE_WINDOWS_STYLE_SOCKETS
108 # if defined(USE_WINDOWS8_API)
109  BOOL result = CancelIoEx((HANDLE) m_s, NULL);
110  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
111  CheckAndHandleError_int("closesocket", closesocket(m_s));
112 # else
113  BOOL result = CancelIo((HANDLE) m_s);
114  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
115  CheckAndHandleError_int("closesocket", closesocket(m_s));
116 # endif
117 #else
118  CheckAndHandleError_int("close", close(m_s));
119 #endif
120  m_s = INVALID_SOCKET;
121  SocketChanged();
122  }
123 }
124 
125 void Socket::Bind(unsigned int port, const char *addr)
126 {
127  sockaddr_in sa;
128  memset(&sa, 0, sizeof(sa));
129  sa.sin_family = AF_INET;
130 
131  if (addr == NULL)
132  sa.sin_addr.s_addr = htonl(INADDR_ANY);
133  else
134  {
135  unsigned long result = inet_addr(addr);
136  if (result == INADDR_NONE)
137  {
138  SetLastError(SOCKET_EINVAL);
139  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
140  }
141  sa.sin_addr.s_addr = result;
142  }
143 
144  sa.sin_port = htons((u_short)port);
145 
146  Bind((sockaddr *)&sa, sizeof(sa));
147 }
148 
149 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
150 {
151  assert(m_s != INVALID_SOCKET);
152  // cygwin workaround: needs const_cast
153  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
154 }
155 
156 void Socket::Listen(int backlog)
157 {
158  assert(m_s != INVALID_SOCKET);
159  CheckAndHandleError_int("listen", listen(m_s, backlog));
160 }
161 
162 bool Socket::Connect(const char *addr, unsigned int port)
163 {
164  assert(addr != NULL);
165 
166  sockaddr_in sa;
167  memset(&sa, 0, sizeof(sa));
168  sa.sin_family = AF_INET;
169  sa.sin_addr.s_addr = inet_addr(addr);
170 
171  if (sa.sin_addr.s_addr == INADDR_NONE)
172  {
173  hostent *lphost = gethostbyname(addr);
174  if (lphost == NULL)
175  {
176  SetLastError(SOCKET_EINVAL);
177  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
178  }
179  else
180  {
181  assert(IsAlignedOn(lphost->h_addr,GetAlignmentOf<in_addr>()));
182  sa.sin_addr.s_addr = ((in_addr *)(void *)lphost->h_addr)->s_addr;
183  }
184  }
185 
186  sa.sin_port = htons((u_short)port);
187 
188  return Connect((const sockaddr *)&sa, sizeof(sa));
189 }
190 
191 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
192 {
193  assert(m_s != INVALID_SOCKET);
194  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
195  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
196  return false;
197  CheckAndHandleError_int("connect", result);
198  return true;
199 }
200 
201 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
202 {
203  assert(m_s != INVALID_SOCKET);
204  socket_t s = accept(m_s, psa, psaLen);
205  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
206  return false;
207  CheckAndHandleError("accept", s);
208  target.AttachSocket(s, true);
209  return true;
210 }
211 
212 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
213 {
214  assert(m_s != INVALID_SOCKET);
215  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
216 }
217 
218 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
219 {
220  assert(m_s != INVALID_SOCKET);
221  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
222 }
223 
224 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
225 {
226  assert(m_s != INVALID_SOCKET);
227  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
228  CheckAndHandleError_int("send", result);
229  return result;
230 }
231 
232 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
233 {
234  assert(m_s != INVALID_SOCKET);
235  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
236  CheckAndHandleError_int("recv", result);
237  return result;
238 }
239 
240 void Socket::ShutDown(int how)
241 {
242  assert(m_s != INVALID_SOCKET);
243  int result = shutdown(m_s, how);
244  CheckAndHandleError_int("shutdown", result);
245 }
246 
247 void Socket::IOCtl(long cmd, unsigned long *argp)
248 {
249  assert(m_s != INVALID_SOCKET);
250 #ifdef USE_WINDOWS_STYLE_SOCKETS
251  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
252 #else
253  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
254 #endif
255 }
256 
257 bool Socket::SendReady(const timeval *timeout)
258 {
259  fd_set fds;
260  FD_ZERO(&fds);
261  FD_SET(m_s, &fds);
262 #ifdef CRYPTOPP_MSAN
263  __msan_unpoison(&fds, sizeof(fds));
264 #endif
265 
266  int ready;
267  if (timeout == NULL)
268  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
269  else
270  {
271  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
272  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
273  }
274  CheckAndHandleError_int("select", ready);
275  return ready > 0;
276 }
277 
278 bool Socket::ReceiveReady(const timeval *timeout)
279 {
280  fd_set fds;
281  FD_ZERO(&fds);
282  FD_SET(m_s, &fds);
283 #ifdef CRYPTOPP_MSAN
284  __msan_unpoison(&fds, sizeof(fds));
285 #endif
286 
287  int ready;
288  if (timeout == NULL)
289  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
290  else
291  {
292  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
293  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
294  }
295  CheckAndHandleError_int("select", ready);
296  return ready > 0;
297 }
298 
299 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
300 {
301  int port = atoi(name);
302  if (IntToString(port) == name)
303  return port;
304 
305  servent *se = getservbyname(name, protocol);
306  if (!se)
307  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
308  return ntohs(se->s_port);
309 }
310 
312 {
313 #ifdef USE_WINDOWS_STYLE_SOCKETS
314  WSADATA wsd;
315  int result = WSAStartup(0x0202, &wsd);
316  if (result != 0)
317  throw Err(INVALID_SOCKET, "WSAStartup", result);
318 #endif
319 }
320 
322 {
323 #ifdef USE_WINDOWS_STYLE_SOCKETS
324  int result = WSACleanup();
325  if (result != 0)
326  throw Err(INVALID_SOCKET, "WSACleanup", result);
327 #endif
328 }
329 
331 {
332 #ifdef USE_WINDOWS_STYLE_SOCKETS
333  return WSAGetLastError();
334 #else
335  return errno;
336 #endif
337 }
338 
339 void Socket::SetLastError(int errorCode)
340 {
341 #ifdef USE_WINDOWS_STYLE_SOCKETS
342  WSASetLastError(errorCode);
343 #else
344  errno = errorCode;
345 #endif
346 }
347 
348 void Socket::HandleError(const char *operation) const
349 {
350  int err = GetLastError();
351  throw Err(m_s, operation, err);
352 }
353 
354 #ifdef USE_WINDOWS_STYLE_SOCKETS
355 
356 SocketReceiver::SocketReceiver(Socket &s)
357  : m_s(s), m_eofReceived(false), m_resultPending(false)
358 {
359  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
360  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
361  memset(&m_overlapped, 0, sizeof(m_overlapped));
362  m_overlapped.hEvent = m_event;
363 }
364 
365 SocketReceiver::~SocketReceiver()
366 {
367 #ifdef USE_WINDOWS_STYLE_SOCKETS
368 # if defined(USE_WINDOWS8_API)
369  BOOL result = CancelIoEx((HANDLE) m_s.GetSocket(), NULL);
370  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
371 # else
372  BOOL result = CancelIo((HANDLE) m_s.GetSocket());
373  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
374 # endif
375 #endif
376 }
377 
378 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
379 {
380  assert(!m_resultPending && !m_eofReceived);
381 
382  DWORD flags = 0;
383  // don't queue too much at once, or we might use up non-paged memory
384  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
385  if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
386  {
387  if (m_lastResult == 0)
388  m_eofReceived = true;
389  }
390  else
391  {
392  switch (WSAGetLastError())
393  {
394  default:
395  m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
396  case WSAEDISCON:
397  m_lastResult = 0;
398  m_eofReceived = true;
399  break;
400  case WSA_IO_PENDING:
401  m_resultPending = true;
402  }
403  }
404  return !m_resultPending;
405 }
406 
408 {
409  if (m_resultPending)
410  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
411  else if (!m_eofReceived)
412  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
413 }
414 
415 unsigned int SocketReceiver::GetReceiveResult()
416 {
417  if (m_resultPending)
418  {
419  DWORD flags = 0;
420  if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
421  {
422  if (m_lastResult == 0)
423  m_eofReceived = true;
424  }
425  else
426  {
427  switch (WSAGetLastError())
428  {
429  default:
430  m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
431  case WSAEDISCON:
432  m_lastResult = 0;
433  m_eofReceived = true;
434  }
435  }
436  m_resultPending = false;
437  }
438  return m_lastResult;
439 }
440 
441 // *************************************************************
442 
443 SocketSender::SocketSender(Socket &s)
444  : m_s(s), m_resultPending(false), m_lastResult(0)
445 {
446  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
447  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
448  memset(&m_overlapped, 0, sizeof(m_overlapped));
449  m_overlapped.hEvent = m_event;
450 }
451 
452 
453 SocketSender::~SocketSender()
454 {
455 #ifdef USE_WINDOWS_STYLE_SOCKETS
456 # if defined(USE_WINDOWS8_API)
457  BOOL result = CancelIoEx((HANDLE) m_s.GetSocket(), NULL);
458  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
459 # else
460  BOOL result = CancelIo((HANDLE) m_s.GetSocket());
461  assert(result || (!result && GetLastError() == ERROR_NOT_FOUND));
462 # endif
463 #endif
464 }
465 
466 void SocketSender::Send(const byte* buf, size_t bufLen)
467 {
468  assert(!m_resultPending);
469  DWORD written = 0;
470  // don't queue too much at once, or we might use up non-paged memory
471  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
472  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
473  {
474  m_resultPending = false;
475  m_lastResult = written;
476  }
477  else
478  {
479  if (WSAGetLastError() != WSA_IO_PENDING)
480  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
481 
482  m_resultPending = true;
483  }
484 }
485 
486 void SocketSender::SendEof()
487 {
488  assert(!m_resultPending);
489  m_s.ShutDown(SD_SEND);
490  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
491  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
492  m_resultPending = true;
493 }
494 
495 bool SocketSender::EofSent()
496 {
497  if (m_resultPending)
498  {
499  WSANETWORKEVENTS events;
500  m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
501  if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
502  throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
503  if (events.iErrorCode[FD_CLOSE_BIT] != 0)
504  throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
505  m_resultPending = false;
506  }
507  return m_lastResult != 0;
508 }
509 
511 {
512  if (m_resultPending)
513  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
514  else
515  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
516 }
517 
518 unsigned int SocketSender::GetSendResult()
519 {
520  if (m_resultPending)
521  {
522  DWORD flags = 0;
523  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
524  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
525  m_resultPending = false;
526  }
527  return m_lastResult;
528 }
529 
530 #endif
531 
532 #ifdef USE_BERKELEY_STYLE_SOCKETS
533 
534 SocketReceiver::SocketReceiver(Socket &s)
535  : m_s(s), m_eofReceived(false), m_lastResult(0)
536 {
537 }
538 
539 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
540 {
541  if (!m_eofReceived)
542  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
543 }
544 
545 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
546 {
547  m_lastResult = m_s.Receive(buf, bufLen);
548  if (bufLen > 0 && m_lastResult == 0)
549  m_eofReceived = true;
550  return true;
551 }
552 
553 unsigned int SocketReceiver::GetReceiveResult()
554 {
555  return m_lastResult;
556 }
557 
558 SocketSender::SocketSender(Socket &s)
559  : m_s(s), m_lastResult(0)
560 {
561 }
562 
563 void SocketSender::Send(const byte* buf, size_t bufLen)
564 {
565  m_lastResult = m_s.Send(buf, bufLen);
566 }
567 
568 void SocketSender::SendEof()
569 {
570  m_s.ShutDown(SD_SEND);
571 }
572 
573 unsigned int SocketSender::GetSendResult()
574 {
575  return m_lastResult;
576 }
577 
578 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
579 {
580  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
581 }
582 
583 #endif // USE_BERKELEY_STYLE_SOCKETS
584 
585 NAMESPACE_END
586 
587 #endif // SOCKETS_AVAILABLE
Base class for all exceptions thrown by the library.
Definition: cryptlib.h:139
static unsigned int PortNameToNumber(const char *name, const char *protocol="tcp")
look up the port number given its name; throws Socket::Err if the name cannot be resolved
Definition: socketft.cpp:299
container of wait objects
Definition: wait.h:162
The operating system reported an error.
Definition: cryptlib.h:217
static void ShutdownSockets()
calls WSACleanup for Windows Sockets
Definition: socketft.cpp:321
Library configuration file.
exception thrown by Socket class
Definition: socketft.h:48
bool IsAlignedOn(const void *ptr, unsigned int alignment)
Determines whether ptr is aligned to a minimum value.
Definition: misc.h:869
static void StartSockets()
start Windows Sockets 2
Definition: socketft.cpp:311
const T1 UnsignedMin(const T1 &a, const T2 &b)
Safe comparison of values that could be negative and incorrectly promoted.
Definition: misc.h:467
static void SetLastError(int errorCode)
sets errno or calls WSASetLastError
Definition: socketft.cpp:339
wrapper for Windows or Berkeley Sockets
Definition: socketft.h:44
bool Receive(byte *buf, size_t bufLen)
receive data from network source, returns whether result is immediately available ...
Definition: socketft.cpp:378
std::string IntToString(T value, unsigned int base=10)
Converts a value to a string.
Definition: misc.h:494
Crypto++ library namespace.
void GetWaitObjects(WaitObjectContainer &container, CallStack const &callStack)
Retrieves waitable objects.
Definition: socketft.cpp:407
static int GetLastError()
returns errno or WSAGetLastError
Definition: socketft.cpp:330
void GetWaitObjects(WaitObjectContainer &container, CallStack const &callStack)
Retrieves waitable objects.
Definition: socketft.cpp:510