Crypto++  5.6.3
Free C++ class library of cryptographic schemes
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 // TODO: http://github.com/weidai11/cryptopp/issues/19
6 #define _WINSOCK_DEPRECATED_NO_WARNINGS
7 #include "socketft.h"
8 
9 #ifdef SOCKETS_AVAILABLE
10 
11 #include "wait.h"
12 
13 #ifdef USE_BERKELEY_STYLE_SOCKETS
14 #include <errno.h>
15 #include <netdb.h>
16 #include <unistd.h>
17 #include <arpa/inet.h>
18 #include <netinet/in.h>
19 #include <sys/ioctl.h>
20 #endif
21 
22 #if defined(CRYPTOPP_MSAN)
23 # include <sanitizer/msan_interface.h>
24 #endif
25 
26 #ifdef PREFER_WINDOWS_STYLE_SOCKETS
27 # pragma comment(lib, "ws2_32.lib")
28 #endif
29 
30 NAMESPACE_BEGIN(CryptoPP)
31 
// Common names for the platform-specific error codes used in this file.
#if defined(USE_WINDOWS_STYLE_SOCKETS)
// Windows configuration: Winsock codes, plus a socklen_t typedef supplied here.
const int SOCKET_EINVAL = WSAEINVAL;
const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
typedef int socklen_t;
#else
// Other configurations: plain errno values.
const int SOCKET_EINVAL = EINVAL;
const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#endif
40 
// Fallback definition: Solaris does not provide INADDR_NONE.
#if !defined(INADDR_NONE)
# define INADDR_NONE 0xffffffff
#endif /* INADDR_NONE */
45 
46 Socket::Err::Err(socket_t s, const std::string& operation, int error)
47  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
48  , m_s(s)
49 {
50 }
51 
52 Socket::~Socket()
53 {
54  if (m_own)
55  {
56  try
57  {
58  CloseSocket();
59  }
60  catch (const Exception&)
61  {
62  assert(0);
63  }
64  }
65 }
66 
67 void Socket::AttachSocket(socket_t s, bool own)
68 {
69  if (m_own)
70  CloseSocket();
71 
72  m_s = s;
73  m_own = own;
74  SocketChanged();
75 }
76 
77 socket_t Socket::DetachSocket()
78 {
79  socket_t s = m_s;
80  m_s = INVALID_SOCKET;
81  SocketChanged();
82  return s;
83 }
84 
85 void Socket::Create(int nType)
86 {
87  assert(m_s == INVALID_SOCKET);
88  m_s = socket(AF_INET, nType, 0);
89  CheckAndHandleError("socket", m_s);
90  m_own = true;
91  SocketChanged();
92 }
93 
94 void Socket::CloseSocket()
95 {
96  if (m_s != INVALID_SOCKET)
97  {
98 #ifdef USE_WINDOWS_STYLE_SOCKETS
99  CancelIo((HANDLE) m_s);
100  CheckAndHandleError_int("closesocket", closesocket(m_s));
101 #else
102  CheckAndHandleError_int("close", close(m_s));
103 #endif
104  m_s = INVALID_SOCKET;
105  SocketChanged();
106  }
107 }
108 
109 void Socket::Bind(unsigned int port, const char *addr)
110 {
111  sockaddr_in sa;
112  memset(&sa, 0, sizeof(sa));
113  sa.sin_family = AF_INET;
114 
115  if (addr == NULL)
116  sa.sin_addr.s_addr = htonl(INADDR_ANY);
117  else
118  {
119  unsigned long result = inet_addr(addr);
120  if (result == INADDR_NONE)
121  {
122  SetLastError(SOCKET_EINVAL);
123  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
124  }
125  sa.sin_addr.s_addr = result;
126  }
127 
128  sa.sin_port = htons((u_short)port);
129 
130  Bind((sockaddr *)&sa, sizeof(sa));
131 }
132 
133 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
134 {
135  assert(m_s != INVALID_SOCKET);
136  // cygwin workaround: needs const_cast
137  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
138 }
139 
140 void Socket::Listen(int backlog)
141 {
142  assert(m_s != INVALID_SOCKET);
143  CheckAndHandleError_int("listen", listen(m_s, backlog));
144 }
145 
146 bool Socket::Connect(const char *addr, unsigned int port)
147 {
148  assert(addr != NULL);
149 
150  sockaddr_in sa;
151  memset(&sa, 0, sizeof(sa));
152  sa.sin_family = AF_INET;
153  sa.sin_addr.s_addr = inet_addr(addr);
154 
155  if (sa.sin_addr.s_addr == INADDR_NONE)
156  {
157  hostent *lphost = gethostbyname(addr);
158  if (lphost == NULL)
159  {
160  SetLastError(SOCKET_EINVAL);
161  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
162  }
163  else
164  {
165  assert(IsAlignedOn(lphost->h_addr,GetAlignmentOf<in_addr>()));
166  sa.sin_addr.s_addr = ((in_addr *)(void *)lphost->h_addr)->s_addr;
167  }
168  }
169 
170  sa.sin_port = htons((u_short)port);
171 
172  return Connect((const sockaddr *)&sa, sizeof(sa));
173 }
174 
175 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
176 {
177  assert(m_s != INVALID_SOCKET);
178  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
179  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
180  return false;
181  CheckAndHandleError_int("connect", result);
182  return true;
183 }
184 
185 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
186 {
187  assert(m_s != INVALID_SOCKET);
188  socket_t s = accept(m_s, psa, psaLen);
189  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
190  return false;
191  CheckAndHandleError("accept", s);
192  target.AttachSocket(s, true);
193  return true;
194 }
195 
196 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
197 {
198  assert(m_s != INVALID_SOCKET);
199  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
200 }
201 
202 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
203 {
204  assert(m_s != INVALID_SOCKET);
205  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
206 }
207 
208 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
209 {
210  assert(m_s != INVALID_SOCKET);
211  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
212  CheckAndHandleError_int("send", result);
213  return result;
214 }
215 
216 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
217 {
218  assert(m_s != INVALID_SOCKET);
219  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
220  CheckAndHandleError_int("recv", result);
221  return result;
222 }
223 
224 void Socket::ShutDown(int how)
225 {
226  assert(m_s != INVALID_SOCKET);
227  int result = shutdown(m_s, how);
228  CheckAndHandleError_int("shutdown", result);
229 }
230 
231 void Socket::IOCtl(long cmd, unsigned long *argp)
232 {
233  assert(m_s != INVALID_SOCKET);
234 #ifdef USE_WINDOWS_STYLE_SOCKETS
235  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
236 #else
237  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
238 #endif
239 }
240 
241 bool Socket::SendReady(const timeval *timeout)
242 {
243  fd_set fds;
244  FD_ZERO(&fds);
245  FD_SET(m_s, &fds);
246 #ifdef CRYPTOPP_MSAN
247  __msan_unpoison(&fds, sizeof(fds));
248 #endif
249 
250  int ready;
251  if (timeout == NULL)
252  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
253  else
254  {
255  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
256  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
257  }
258  CheckAndHandleError_int("select", ready);
259  return ready > 0;
260 }
261 
262 bool Socket::ReceiveReady(const timeval *timeout)
263 {
264  fd_set fds;
265  FD_ZERO(&fds);
266  FD_SET(m_s, &fds);
267 #ifdef CRYPTOPP_MSAN
268  __msan_unpoison(&fds, sizeof(fds));
269 #endif
270 
271  int ready;
272  if (timeout == NULL)
273  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
274  else
275  {
276  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
277  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
278  }
279  CheckAndHandleError_int("select", ready);
280  return ready > 0;
281 }
282 
283 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
284 {
285  int port = atoi(name);
286  if (IntToString(port) == name)
287  return port;
288 
289  servent *se = getservbyname(name, protocol);
290  if (!se)
291  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
292  return ntohs(se->s_port);
293 }
294 
295 void Socket::StartSockets()
296 {
297 #ifdef USE_WINDOWS_STYLE_SOCKETS
298  WSADATA wsd;
299  int result = WSAStartup(0x0202, &wsd);
300  if (result != 0)
301  throw Err(INVALID_SOCKET, "WSAStartup", result);
302 #endif
303 }
304 
305 void Socket::ShutdownSockets()
306 {
307 #ifdef USE_WINDOWS_STYLE_SOCKETS
308  int result = WSACleanup();
309  if (result != 0)
310  throw Err(INVALID_SOCKET, "WSACleanup", result);
311 #endif
312 }
313 
314 int Socket::GetLastError()
315 {
316 #ifdef USE_WINDOWS_STYLE_SOCKETS
317  return WSAGetLastError();
318 #else
319  return errno;
320 #endif
321 }
322 
323 void Socket::SetLastError(int errorCode)
324 {
325 #ifdef USE_WINDOWS_STYLE_SOCKETS
326  WSASetLastError(errorCode);
327 #else
328  errno = errorCode;
329 #endif
330 }
331 
332 void Socket::HandleError(const char *operation) const
333 {
334  int err = GetLastError();
335  throw Err(m_s, operation, err);
336 }
337 
338 #ifdef USE_WINDOWS_STYLE_SOCKETS
339 
340 SocketReceiver::SocketReceiver(Socket &s)
341  : m_s(s), m_eofReceived(false), m_resultPending(false)
342 {
343  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
344  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
345  memset(&m_overlapped, 0, sizeof(m_overlapped));
346  m_overlapped.hEvent = m_event;
347 }
348 
349 SocketReceiver::~SocketReceiver()
350 {
351 #ifdef USE_WINDOWS_STYLE_SOCKETS
352  CancelIo((HANDLE) m_s.GetSocket());
353 #endif
354 }
355 
356 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
357 {
358  assert(!m_resultPending && !m_eofReceived);
359 
360  DWORD flags = 0;
361  // don't queue too much at once, or we might use up non-paged memory
362  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
363  if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
364  {
365  if (m_lastResult == 0)
366  m_eofReceived = true;
367  }
368  else
369  {
370  switch (WSAGetLastError())
371  {
372  default:
373  m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
374  case WSAEDISCON:
375  m_lastResult = 0;
376  m_eofReceived = true;
377  break;
378  case WSA_IO_PENDING:
379  m_resultPending = true;
380  }
381  }
382  return !m_resultPending;
383 }
384 
385 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
386 {
387  if (m_resultPending)
388  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
389  else if (!m_eofReceived)
390  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
391 }
392 
393 unsigned int SocketReceiver::GetReceiveResult()
394 {
395  if (m_resultPending)
396  {
397  DWORD flags = 0;
398  if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
399  {
400  if (m_lastResult == 0)
401  m_eofReceived = true;
402  }
403  else
404  {
405  switch (WSAGetLastError())
406  {
407  default:
408  m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
409  case WSAEDISCON:
410  m_lastResult = 0;
411  m_eofReceived = true;
412  }
413  }
414  m_resultPending = false;
415  }
416  return m_lastResult;
417 }
418 
419 // *************************************************************
420 
421 SocketSender::SocketSender(Socket &s)
422  : m_s(s), m_resultPending(false), m_lastResult(0)
423 {
424  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
425  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
426  memset(&m_overlapped, 0, sizeof(m_overlapped));
427  m_overlapped.hEvent = m_event;
428 }
429 
430 
431 SocketSender::~SocketSender()
432 {
433 #ifdef USE_WINDOWS_STYLE_SOCKETS
434  CancelIo((HANDLE) m_s.GetSocket());
435 #endif
436 }
437 
438 void SocketSender::Send(const byte* buf, size_t bufLen)
439 {
440  assert(!m_resultPending);
441  DWORD written = 0;
442  // don't queue too much at once, or we might use up non-paged memory
443  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
444  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
445  {
446  m_resultPending = false;
447  m_lastResult = written;
448  }
449  else
450  {
451  if (WSAGetLastError() != WSA_IO_PENDING)
452  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
453 
454  m_resultPending = true;
455  }
456 }
457 
458 void SocketSender::SendEof()
459 {
460  assert(!m_resultPending);
461  m_s.ShutDown(SD_SEND);
462  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
463  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
464  m_resultPending = true;
465 }
466 
467 bool SocketSender::EofSent()
468 {
469  if (m_resultPending)
470  {
471  WSANETWORKEVENTS events;
472  m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
473  if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
474  throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
475  if (events.iErrorCode[FD_CLOSE_BIT] != 0)
476  throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
477  m_resultPending = false;
478  }
479  return m_lastResult != 0;
480 }
481 
482 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
483 {
484  if (m_resultPending)
485  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
486  else
487  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
488 }
489 
490 unsigned int SocketSender::GetSendResult()
491 {
492  if (m_resultPending)
493  {
494  DWORD flags = 0;
495  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
496  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
497  m_resultPending = false;
498  }
499  return m_lastResult;
500 }
501 
502 #endif
503 
504 #ifdef USE_BERKELEY_STYLE_SOCKETS
505 
506 SocketReceiver::SocketReceiver(Socket &s)
507  : m_s(s), m_eofReceived(false), m_lastResult(0)
508 {
509 }
510 
511 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
512 {
513  if (!m_eofReceived)
514  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
515 }
516 
517 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
518 {
519  m_lastResult = m_s.Receive(buf, bufLen);
520  if (bufLen > 0 && m_lastResult == 0)
521  m_eofReceived = true;
522  return true;
523 }
524 
525 unsigned int SocketReceiver::GetReceiveResult()
526 {
527  return m_lastResult;
528 }
529 
530 SocketSender::SocketSender(Socket &s)
531  : m_s(s), m_lastResult(0)
532 {
533 }
534 
535 void SocketSender::Send(const byte* buf, size_t bufLen)
536 {
537  m_lastResult = m_s.Send(buf, bufLen);
538 }
539 
540 void SocketSender::SendEof()
541 {
542  m_s.ShutDown(SD_SEND);
543 }
544 
545 unsigned int SocketSender::GetSendResult()
546 {
547  return m_lastResult;
548 }
549 
550 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
551 {
552  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
553 }
554 
555 #endif
556 
557 NAMESPACE_END
558 
559 #endif // #ifdef SOCKETS_AVAILABLE
Base class for all exceptions thrown by the library.
Definition: cryptlib.h:139
container of wait objects
Definition: wait.h:154
The operating system reported an error.
Definition: cryptlib.h:217
bool IsAlignedOn(const void *ptr, unsigned int alignment)
Determines whether ptr is aligned to a minimum value.
Definition: misc.h:811
const T1 UnsignedMin(const T1 &a, const T2 &b)
Safe comparison of values that could be negative and incorrectly promoted.
Definition: misc.h:433
std::string IntToString(T value, unsigned int base=10)
Converts a value to a string.
Definition: misc.h:460
Crypto++ library namespace.