Crypto++  5.6.3
Free C++ class library of cryptographic schemes
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 // TODO: http://github.com/weidai11/cryptopp/issues/19
6 #define _WINSOCK_DEPRECATED_NO_WARNINGS
7 #include "socketft.h"
8 
9 #ifdef SOCKETS_AVAILABLE
10 
11 #include "wait.h"
12 
13 #ifdef USE_BERKELEY_STYLE_SOCKETS
14 #include <errno.h>
15 #include <netdb.h>
16 #include <unistd.h>
17 #include <arpa/inet.h>
18 #include <netinet/in.h>
19 #include <sys/ioctl.h>
20 #endif
21 
22 #ifdef PREFER_WINDOWS_STYLE_SOCKETS
23 # pragma comment(lib, "ws2_32.lib")
24 #endif
25 
26 NAMESPACE_BEGIN(CryptoPP)
27 
// Give the platform's "invalid argument" and "would block" error codes a
// common name so the rest of this file can stay platform neutral.
#if defined(USE_WINDOWS_STYLE_SOCKETS)
const int SOCKET_EINVAL = WSAEINVAL;
const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
// The Windows-style branch provides its own socklen_t.
typedef int socklen_t;
#else
const int SOCKET_EINVAL = EINVAL;
const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#endif

// Solaris doesn't have INADDR_NONE
#if !defined(INADDR_NONE)
# define INADDR_NONE 0xffffffff
#endif /* INADDR_NONE */
41 
42 Socket::Err::Err(socket_t s, const std::string& operation, int error)
43  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
44  , m_s(s)
45 {
46 }
47 
48 Socket::~Socket()
49 {
50  if (m_own)
51  {
52  try
53  {
54  CloseSocket();
55  }
56  catch (const Exception&)
57  {
58  assert(0);
59  }
60  }
61 }
62 
63 void Socket::AttachSocket(socket_t s, bool own)
64 {
65  if (m_own)
66  CloseSocket();
67 
68  m_s = s;
69  m_own = own;
70  SocketChanged();
71 }
72 
73 socket_t Socket::DetachSocket()
74 {
75  socket_t s = m_s;
76  m_s = INVALID_SOCKET;
77  SocketChanged();
78  return s;
79 }
80 
81 void Socket::Create(int nType)
82 {
83  assert(m_s == INVALID_SOCKET);
84  m_s = socket(AF_INET, nType, 0);
85  CheckAndHandleError("socket", m_s);
86  m_own = true;
87  SocketChanged();
88 }
89 
90 void Socket::CloseSocket()
91 {
92  if (m_s != INVALID_SOCKET)
93  {
94 #ifdef USE_WINDOWS_STYLE_SOCKETS
95  CancelIo((HANDLE) m_s);
96  CheckAndHandleError_int("closesocket", closesocket(m_s));
97 #else
98  CheckAndHandleError_int("close", close(m_s));
99 #endif
100  m_s = INVALID_SOCKET;
101  SocketChanged();
102  }
103 }
104 
105 void Socket::Bind(unsigned int port, const char *addr)
106 {
107  sockaddr_in sa;
108  memset(&sa, 0, sizeof(sa));
109  sa.sin_family = AF_INET;
110 
111  if (addr == NULL)
112  sa.sin_addr.s_addr = htonl(INADDR_ANY);
113  else
114  {
115  unsigned long result = inet_addr(addr);
116  if (result == INADDR_NONE)
117  {
118  SetLastError(SOCKET_EINVAL);
119  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
120  }
121  sa.sin_addr.s_addr = result;
122  }
123 
124  sa.sin_port = htons((u_short)port);
125 
126  Bind((sockaddr *)&sa, sizeof(sa));
127 }
128 
129 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
130 {
131  assert(m_s != INVALID_SOCKET);
132  // cygwin workaround: needs const_cast
133  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
134 }
135 
136 void Socket::Listen(int backlog)
137 {
138  assert(m_s != INVALID_SOCKET);
139  CheckAndHandleError_int("listen", listen(m_s, backlog));
140 }
141 
142 bool Socket::Connect(const char *addr, unsigned int port)
143 {
144  assert(addr != NULL);
145 
146  sockaddr_in sa;
147  memset(&sa, 0, sizeof(sa));
148  sa.sin_family = AF_INET;
149  sa.sin_addr.s_addr = inet_addr(addr);
150 
151  if (sa.sin_addr.s_addr == INADDR_NONE)
152  {
153  hostent *lphost = gethostbyname(addr);
154  if (lphost == NULL)
155  {
156  SetLastError(SOCKET_EINVAL);
157  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
158  }
159  else
160  {
161  assert(IsAlignedOn(lphost->h_addr,GetAlignmentOf<in_addr>()));
162  sa.sin_addr.s_addr = ((in_addr *)(void *)lphost->h_addr)->s_addr;
163  }
164  }
165 
166  sa.sin_port = htons((u_short)port);
167 
168  return Connect((const sockaddr *)&sa, sizeof(sa));
169 }
170 
171 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
172 {
173  assert(m_s != INVALID_SOCKET);
174  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
175  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
176  return false;
177  CheckAndHandleError_int("connect", result);
178  return true;
179 }
180 
181 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
182 {
183  assert(m_s != INVALID_SOCKET);
184  socket_t s = accept(m_s, psa, psaLen);
185  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
186  return false;
187  CheckAndHandleError("accept", s);
188  target.AttachSocket(s, true);
189  return true;
190 }
191 
192 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
193 {
194  assert(m_s != INVALID_SOCKET);
195  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
196 }
197 
198 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
199 {
200  assert(m_s != INVALID_SOCKET);
201  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
202 }
203 
204 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
205 {
206  assert(m_s != INVALID_SOCKET);
207  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
208  CheckAndHandleError_int("send", result);
209  return result;
210 }
211 
212 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
213 {
214  assert(m_s != INVALID_SOCKET);
215  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
216  CheckAndHandleError_int("recv", result);
217  return result;
218 }
219 
220 void Socket::ShutDown(int how)
221 {
222  assert(m_s != INVALID_SOCKET);
223  int result = shutdown(m_s, how);
224  CheckAndHandleError_int("shutdown", result);
225 }
226 
227 void Socket::IOCtl(long cmd, unsigned long *argp)
228 {
229  assert(m_s != INVALID_SOCKET);
230 #ifdef USE_WINDOWS_STYLE_SOCKETS
231  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
232 #else
233  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
234 #endif
235 }
236 
237 bool Socket::SendReady(const timeval *timeout)
238 {
239  fd_set fds;
240  FD_ZERO(&fds);
241  FD_SET(m_s, &fds);
242  int ready;
243  if (timeout == NULL)
244  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
245  else
246  {
247  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
248  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
249  }
250  CheckAndHandleError_int("select", ready);
251  return ready > 0;
252 }
253 
254 bool Socket::ReceiveReady(const timeval *timeout)
255 {
256  fd_set fds;
257  FD_ZERO(&fds);
258  FD_SET(m_s, &fds);
259  int ready;
260  if (timeout == NULL)
261  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
262  else
263  {
264  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
265  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
266  }
267  CheckAndHandleError_int("select", ready);
268  return ready > 0;
269 }
270 
271 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
272 {
273  int port = atoi(name);
274  if (IntToString(port) == name)
275  return port;
276 
277  servent *se = getservbyname(name, protocol);
278  if (!se)
279  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
280  return ntohs(se->s_port);
281 }
282 
283 void Socket::StartSockets()
284 {
285 #ifdef USE_WINDOWS_STYLE_SOCKETS
286  WSADATA wsd;
287  int result = WSAStartup(0x0202, &wsd);
288  if (result != 0)
289  throw Err(INVALID_SOCKET, "WSAStartup", result);
290 #endif
291 }
292 
293 void Socket::ShutdownSockets()
294 {
295 #ifdef USE_WINDOWS_STYLE_SOCKETS
296  int result = WSACleanup();
297  if (result != 0)
298  throw Err(INVALID_SOCKET, "WSACleanup", result);
299 #endif
300 }
301 
302 int Socket::GetLastError()
303 {
304 #ifdef USE_WINDOWS_STYLE_SOCKETS
305  return WSAGetLastError();
306 #else
307  return errno;
308 #endif
309 }
310 
311 void Socket::SetLastError(int errorCode)
312 {
313 #ifdef USE_WINDOWS_STYLE_SOCKETS
314  WSASetLastError(errorCode);
315 #else
316  errno = errorCode;
317 #endif
318 }
319 
320 void Socket::HandleError(const char *operation) const
321 {
322  int err = GetLastError();
323  throw Err(m_s, operation, err);
324 }
325 
326 #ifdef USE_WINDOWS_STYLE_SOCKETS
327 
328 SocketReceiver::SocketReceiver(Socket &s)
329  : m_s(s), m_eofReceived(false), m_resultPending(false)
330 {
331  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
332  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
333  memset(&m_overlapped, 0, sizeof(m_overlapped));
334  m_overlapped.hEvent = m_event;
335 }
336 
337 SocketReceiver::~SocketReceiver()
338 {
339 #ifdef USE_WINDOWS_STYLE_SOCKETS
340  CancelIo((HANDLE) m_s.GetSocket());
341 #endif
342 }
343 
344 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
345 {
346  assert(!m_resultPending && !m_eofReceived);
347 
348  DWORD flags = 0;
349  // don't queue too much at once, or we might use up non-paged memory
350  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
351  if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
352  {
353  if (m_lastResult == 0)
354  m_eofReceived = true;
355  }
356  else
357  {
358  switch (WSAGetLastError())
359  {
360  default:
361  m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
362  case WSAEDISCON:
363  m_lastResult = 0;
364  m_eofReceived = true;
365  break;
366  case WSA_IO_PENDING:
367  m_resultPending = true;
368  }
369  }
370  return !m_resultPending;
371 }
372 
373 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
374 {
375  if (m_resultPending)
376  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
377  else if (!m_eofReceived)
378  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
379 }
380 
381 unsigned int SocketReceiver::GetReceiveResult()
382 {
383  if (m_resultPending)
384  {
385  DWORD flags = 0;
386  if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
387  {
388  if (m_lastResult == 0)
389  m_eofReceived = true;
390  }
391  else
392  {
393  switch (WSAGetLastError())
394  {
395  default:
396  m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
397  case WSAEDISCON:
398  m_lastResult = 0;
399  m_eofReceived = true;
400  }
401  }
402  m_resultPending = false;
403  }
404  return m_lastResult;
405 }
406 
407 // *************************************************************
408 
409 SocketSender::SocketSender(Socket &s)
410  : m_s(s), m_resultPending(false), m_lastResult(0)
411 {
412  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
413  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
414  memset(&m_overlapped, 0, sizeof(m_overlapped));
415  m_overlapped.hEvent = m_event;
416 }
417 
418 
419 SocketSender::~SocketSender()
420 {
421 #ifdef USE_WINDOWS_STYLE_SOCKETS
422  CancelIo((HANDLE) m_s.GetSocket());
423 #endif
424 }
425 
426 void SocketSender::Send(const byte* buf, size_t bufLen)
427 {
428  assert(!m_resultPending);
429  DWORD written = 0;
430  // don't queue too much at once, or we might use up non-paged memory
431  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
432  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
433  {
434  m_resultPending = false;
435  m_lastResult = written;
436  }
437  else
438  {
439  if (WSAGetLastError() != WSA_IO_PENDING)
440  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
441 
442  m_resultPending = true;
443  }
444 }
445 
446 void SocketSender::SendEof()
447 {
448  assert(!m_resultPending);
449  m_s.ShutDown(SD_SEND);
450  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
451  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
452  m_resultPending = true;
453 }
454 
455 bool SocketSender::EofSent()
456 {
457  if (m_resultPending)
458  {
459  WSANETWORKEVENTS events;
460  m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
461  if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
462  throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
463  if (events.iErrorCode[FD_CLOSE_BIT] != 0)
464  throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
465  m_resultPending = false;
466  }
467  return m_lastResult != 0;
468 }
469 
470 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
471 {
472  if (m_resultPending)
473  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
474  else
475  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
476 }
477 
478 unsigned int SocketSender::GetSendResult()
479 {
480  if (m_resultPending)
481  {
482  DWORD flags = 0;
483  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
484  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
485  m_resultPending = false;
486  }
487  return m_lastResult;
488 }
489 
490 #endif
491 
492 #ifdef USE_BERKELEY_STYLE_SOCKETS
493 
494 SocketReceiver::SocketReceiver(Socket &s)
495  : m_s(s), m_eofReceived(false), m_lastResult(0)
496 {
497 }
498 
499 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
500 {
501  if (!m_eofReceived)
502  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
503 }
504 
505 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
506 {
507  m_lastResult = m_s.Receive(buf, bufLen);
508  if (bufLen > 0 && m_lastResult == 0)
509  m_eofReceived = true;
510  return true;
511 }
512 
513 unsigned int SocketReceiver::GetReceiveResult()
514 {
515  return m_lastResult;
516 }
517 
518 SocketSender::SocketSender(Socket &s)
519  : m_s(s), m_lastResult(0)
520 {
521 }
522 
523 void SocketSender::Send(const byte* buf, size_t bufLen)
524 {
525  m_lastResult = m_s.Send(buf, bufLen);
526 }
527 
528 void SocketSender::SendEof()
529 {
530  m_s.ShutDown(SD_SEND);
531 }
532 
533 unsigned int SocketSender::GetSendResult()
534 {
535  return m_lastResult;
536 }
537 
538 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
539 {
540  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
541 }
542 
543 #endif
544 
545 NAMESPACE_END
546 
547 #endif // #ifdef SOCKETS_AVAILABLE
Base class for all exceptions thrown by the library.
Definition: cryptlib.h:139
container of wait objects
Definition: wait.h:151
The operating system reported an error.
Definition: cryptlib.h:217
bool IsAlignedOn(const void *ptr, unsigned int alignment)
Determines whether ptr is aligned to a minimum value.
Definition: misc.h:811
const T1 UnsignedMin(const T1 &a, const T2 &b)
Safe comparison of values that could be negative and incorrectly promoted.
Definition: misc.h:433
std::string IntToString(T value, unsigned int base=10)
Converts a value to a string.
Definition: misc.h:460
Crypto++ library namespace.