Crypto++ 8.8
Free C++ class library of cryptographic schemes
oaep.cpp
1 // oaep.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 #ifndef CRYPTOPP_IMPORTS
6 
7 #include "oaep.h"
8 #include "stdcpp.h"
9 #include "smartptr.h"
10 
11 NAMESPACE_BEGIN(CryptoPP)
12 
13 // ********************************************************
14 
15 size_t OAEP_Base::MaxUnpaddedLength(size_t paddedLength) const
16 {
17  return SaturatingSubtract(paddedLength/8, 1+2*DigestSize());
18 }
19 
20 void OAEP_Base::Pad(RandomNumberGenerator &rng, const byte *input, size_t inputLength, byte *oaepBlock, size_t oaepBlockLen, const NameValuePairs &parameters) const
21 {
22  CRYPTOPP_ASSERT (inputLength <= MaxUnpaddedLength(oaepBlockLen));
23 
24  // convert from bit length to byte length
25  if (oaepBlockLen % 8 != 0)
26  {
27  oaepBlock[0] = 0;
28  oaepBlock++;
29  }
30  oaepBlockLen /= 8;
31 
32  member_ptr<HashTransformation> pHash(NewHash());
33  const size_t hLen = pHash->DigestSize();
34  const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
35  byte *const maskedSeed = oaepBlock;
36  byte *const maskedDB = oaepBlock+seedLen;
37 
38  ConstByteArrayParameter encodingParameters;
39  parameters.GetValue(Name::EncodingParameters(), encodingParameters);
40 
41  // DB = pHash || 00 ... || 01 || M
42  pHash->CalculateDigest(maskedDB, encodingParameters.begin(), encodingParameters.size());
43  std::memset(maskedDB+hLen, 0, dbLen-hLen-inputLength-1);
44  maskedDB[dbLen-inputLength-1] = 0x01;
45  std::memcpy(maskedDB+dbLen-inputLength, input, inputLength);
46 
47  rng.GenerateBlock(maskedSeed, seedLen);
49  pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
50  pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
51 }
52 
53 DecodingResult OAEP_Base::Unpad(const byte *oaepBlock, size_t oaepBlockLen, byte *output, const NameValuePairs &parameters) const
54 {
55  bool invalid = false;
56 
57  // convert from bit length to byte length
58  if (oaepBlockLen % 8 != 0)
59  {
60  invalid = (oaepBlock[0] != 0) || invalid;
61  oaepBlock++;
62  }
63  oaepBlockLen /= 8;
64 
65  member_ptr<HashTransformation> pHash(NewHash());
66  const size_t hLen = pHash->DigestSize();
67  const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
68 
69  invalid = (oaepBlockLen < 2*hLen+1) || invalid;
70 
71  SecByteBlock t(oaepBlock, oaepBlockLen);
72  byte *const maskedSeed = t;
73  byte *const maskedDB = t+seedLen;
74 
76  pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
77  pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
78 
79  ConstByteArrayParameter encodingParameters;
80  parameters.GetValue(Name::EncodingParameters(), encodingParameters);
81 
82  // DB = pHash' || 00 ... || 01 || M
83  byte *M = std::find(maskedDB+hLen, maskedDB+dbLen, 0x01);
84  invalid = (M == maskedDB+dbLen) || invalid;
85  invalid = (FindIfNot(maskedDB+hLen, M, byte(0)) != M) || invalid;
86  invalid = !pHash->VerifyDigest(maskedDB, encodingParameters.begin(), encodingParameters.size()) || invalid;
87 
88  if (invalid)
89  return DecodingResult();
90 
91  M++;
92  std::memcpy(output, M, maskedDB+dbLen-M);
93  return DecodingResult(maskedDB+dbLen-M);
94 }
95 
96 NAMESPACE_END
97 
98 #endif
Used to pass byte array input as part of a NameValuePairs object.
Definition: algparam.h:25
size_t size() const
Length of the memory block.
Definition: algparam.h:88
const byte * begin() const
Pointer to the first byte in the memory block.
Definition: algparam.h:84
Interface for retrieving values given their names.
Definition: cryptlib.h:327
bool GetValue(const char *name, T &value) const
Get a named value.
Definition: cryptlib.h:384
size_t MaxUnpaddedLength(size_t paddedLength) const
max size of unpadded message in bytes, given max size of padded message in bits (1 less than size of ...
Interface for random number generators.
Definition: cryptlib.h:1440
virtual void GenerateBlock(byte *output, size_t size)
Generate random array of bytes.
SecBlock<byte> typedef.
Definition: secblock.h:1226
Pointer that overloads operator ->
Definition: smartptr.h:38
T1 SaturatingSubtract(const T1 &a, const T2 &b)
Performs a saturating subtract clamped at 0.
Definition: misc.h:1302
InputIt FindIfNot(InputIt first, InputIt last, const T &value)
Finds the first element in a range that is not equal to a given value.
Definition: misc.h:3190
Crypto++ library namespace.
const char * EncodingParameters()
ConstByteArrayParameter.
Definition: argnames.h:66
Classes for optimal asymmetric encryption padding.
Precompiled header file.
Classes for automatic resource management.
Common C++ header files.
Returns a decoding result.
Definition: cryptlib.h:283
#define CRYPTOPP_ASSERT(exp)
Debugging and diagnostic assertion.
Definition: trap.h:68