00001 #include "factory.h"
00002 #include "integer.h"
00003 #include "filters.h"
00004 #include "hex.h"
00005 #include "randpool.h"
00006 #include "files.h"
00007 #include "trunhash.h"
00008 #include "queue.h"
00009 #include "validate.h"
00010 #include <iostream>
00011 #include <memory>
00012
00013 USING_NAMESPACE(CryptoPP)
00014 USING_NAMESPACE(std)
00015
00016 typedef std::map<std::string, std::string> TestData;
00017
00018 class TestFailure : public Exception
00019 {
00020 public:
00021 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00022 };
00023
00024 static const TestData *s_currentTestData = NULL;
00025
00026 static void OutputTestData(const TestData &v)
00027 {
00028 for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00029 {
00030 cerr << i->first << ": " << i->second << endl;
00031 }
00032 }
00033
00034 static void SignalTestFailure()
00035 {
00036 OutputTestData(*s_currentTestData);
00037 throw TestFailure();
00038 }
00039
00040 static void SignalTestError()
00041 {
00042 OutputTestData(*s_currentTestData);
00043 throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00044 }
00045
00046 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00047 {
00048 TestData::const_iterator i = data.find(name);
00049 if (i == data.end())
00050 SignalTestError();
00051 return i->second;
00052 }
00053
00054 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00055 {
00056 std::string s1 = GetRequiredDatum(data, name), s2;
00057
00058 while (!s1.empty())
00059 {
00060 while (s1[0] == ' ')
00061 s1 = s1.substr(1);
00062
00063 int repeat = 1;
00064 if (s1[0] == 'r')
00065 {
00066 repeat = atoi(s1.c_str()+1);
00067 s1 = s1.substr(s1.find(' ')+1);
00068 }
00069
00070 s2 = "";
00071
00072 if (s1[0] == '\"')
00073 {
00074 s2 = s1.substr(1, s1.find('\"', 1)-1);
00075 s1 = s1.substr(s2.length() + 2);
00076 }
00077 else if (s1.substr(0, 2) == "0x")
00078 {
00079 StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
00080 s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
00081 }
00082 else
00083 {
00084 StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
00085 s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
00086 }
00087
00088 ByteQueue q;
00089 while (repeat--)
00090 {
00091 q.Put((const byte *)s2.data(), s2.size());
00092 if (q.MaxRetrievable() > 4*1024 || repeat == 0)
00093 q.TransferTo(target);
00094 }
00095 }
00096 }
00097
00098 std::string GetDecodedDatum(const TestData &data, const char *name)
00099 {
00100 std::string s;
00101 PutDecodedDatumInto(data, name, StringSink(s).Ref());
00102 return s;
00103 }
00104
00105 class TestDataNameValuePairs : public NameValuePairs
00106 {
00107 public:
00108 TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00109
00110 virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00111 {
00112 TestData::const_iterator i = m_data.find(name);
00113 if (i == m_data.end())
00114 return false;
00115
00116 const std::string &value = i->second;
00117
00118 if (valueType == typeid(int))
00119 *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00120 else if (valueType == typeid(Integer))
00121 *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00122 else if (valueType == typeid(ConstByteArrayParameter))
00123 {
00124 m_temp.resize(0);
00125 PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
00126 reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true);
00127 }
00128 else if (valueType == typeid(const byte *))
00129 {
00130 m_temp.resize(0);
00131 PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
00132 *reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data();
00133 }
00134 else
00135 throw ValueTypeMismatch(name, typeid(std::string), valueType);
00136
00137 return true;
00138 }
00139
00140 private:
00141 const TestData &m_data;
00142 mutable std::string m_temp;
00143 };
00144
00145 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00146 {
00147 if (!pub.Validate(GlobalRNG(), 3))
00148 SignalTestFailure();
00149 if (!priv.Validate(GlobalRNG(), 3))
00150 SignalTestFailure();
00151
00152
00153
00154
00155
00156
00157
00158
00159 }
00160
00161 void TestSignatureScheme(TestData &v)
00162 {
00163 std::string name = GetRequiredDatum(v, "Name");
00164 std::string test = GetRequiredDatum(v, "Test");
00165
00166 std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00167 std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00168
00169 TestDataNameValuePairs pairs(v);
00170 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00171
00172 if (keyFormat == "DER")
00173 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00174 else if (keyFormat == "Component")
00175 verifier->AccessMaterial().AssignFrom(pairs);
00176
00177 if (test == "Verify" || test == "NotVerify")
00178 {
00179 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00180 PutDecodedDatumInto(v, "Signature", verifierFilter);
00181 PutDecodedDatumInto(v, "Message", verifierFilter);
00182 verifierFilter.MessageEnd();
00183 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00184 SignalTestFailure();
00185 }
00186 else if (test == "PublicKeyValid")
00187 {
00188 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00189 SignalTestFailure();
00190 }
00191 else
00192 goto privateKeyTests;
00193
00194 return;
00195
00196 privateKeyTests:
00197 if (keyFormat == "DER")
00198 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00199 else if (keyFormat == "Component")
00200 signer->AccessMaterial().AssignFrom(pairs);
00201
00202 if (test == "KeyPairValidAndConsistent")
00203 {
00204 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00205 }
00206 else if (test == "Sign")
00207 {
00208 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00209 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00210 SignalTestFailure();
00211 }
00212 else if (test == "DeterministicSign")
00213 {
00214 SignalTestError();
00215 assert(false);
00216 }
00217 else if (test == "RandomSign")
00218 {
00219 SignalTestError();
00220 assert(false);
00221 }
00222 else if (test == "GenerateKey")
00223 {
00224 SignalTestError();
00225 assert(false);
00226 }
00227 else
00228 {
00229 SignalTestError();
00230 assert(false);
00231 }
00232 }
00233
00234 void TestAsymmetricCipher(TestData &v)
00235 {
00236 std::string name = GetRequiredDatum(v, "Name");
00237 std::string test = GetRequiredDatum(v, "Test");
00238
00239 std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00240 std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00241
00242 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00243
00244 if (keyFormat == "DER")
00245 {
00246 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00247 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00248 }
00249 else if (keyFormat == "Component")
00250 {
00251 TestDataNameValuePairs pairs(v);
00252 decryptor->AccessMaterial().AssignFrom(pairs);
00253 encryptor->AccessMaterial().AssignFrom(pairs);
00254 }
00255
00256 if (test == "DecryptMatch")
00257 {
00258 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00259 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00260 if (decrypted != expected)
00261 SignalTestFailure();
00262 }
00263 else if (test == "KeyPairValidAndConsistent")
00264 {
00265 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00266 }
00267 else
00268 {
00269 SignalTestError();
00270 assert(false);
00271 }
00272 }
00273
00274 void TestSymmetricCipher(TestData &v)
00275 {
00276 std::string name = GetRequiredDatum(v, "Name");
00277 std::string test = GetRequiredDatum(v, "Test");
00278
00279 std::string key = GetDecodedDatum(v, "Key");
00280 std::string plaintext = GetDecodedDatum(v, "Plaintext");
00281
00282 TestDataNameValuePairs pairs(v);
00283
00284 if (test == "Encrypt" || test == "EncryptXorDigest")
00285 {
00286 std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
00287 std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
00288 ConstByteArrayParameter iv;
00289 if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
00290 SignalTestFailure();
00291 encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00292 decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00293 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00294 if (seek)
00295 {
00296 encryptor->Seek(seek);
00297 decryptor->Seek(seek);
00298 }
00299 std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
00300 StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING));
00301 ss.Pump(plaintext.size()/2 + 1);
00302 ss.PumpAll();
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312 if (test == "Encrypt")
00313 ciphertext = GetDecodedDatum(v, "Ciphertext");
00314 else
00315 {
00316 ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest");
00317 xorDigest.append(encrypted, 0, 64);
00318 for (size_t i=64; i<encrypted.size(); i++)
00319 xorDigest[i%64] ^= encrypted[i];
00320 }
00321 if (test == "Encrypt" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
00322 {
00323 std::cout << "incorrectly encrypted: ";
00324 StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
00325 xx.Pump(256); xx.Flush(false);
00326 std::cout << "\n";
00327 SignalTestFailure();
00328 }
00329 std::string decrypted;
00330 StringSource dd(encrypted, false, new StreamTransformationFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING));
00331 dd.Pump(plaintext.size()/2 + 1);
00332 dd.PumpAll();
00333 if (decrypted != plaintext)
00334 {
00335 std::cout << "incorrectly decrypted: ";
00336 StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
00337 xx.Pump(256); xx.Flush(false);
00338 std::cout << "\n";
00339 SignalTestFailure();
00340 }
00341 }
00342 else if (test == "Decrypt")
00343 {
00344 }
00345 else
00346 {
00347 SignalTestError();
00348 assert(false);
00349 }
00350 }
00351
00352 void TestDigestOrMAC(TestData &v, bool testDigest)
00353 {
00354 std::string name = GetRequiredDatum(v, "Name");
00355 std::string test = GetRequiredDatum(v, "Test");
00356
00357 member_ptr<MessageAuthenticationCode> mac;
00358 member_ptr<HashTransformation> hash;
00359 HashTransformation *pHash = NULL;
00360
00361 TestDataNameValuePairs pairs(v);
00362
00363 if (testDigest)
00364 {
00365 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00366 pHash = hash.get();
00367 }
00368 else
00369 {
00370 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00371 pHash = mac.get();
00372 ConstByteArrayParameter iv;
00373 if (pairs.GetValue(Name::IV(), iv) && iv.size() != mac->IVSize())
00374 SignalTestFailure();
00375 std::string key = GetDecodedDatum(v, "Key");
00376 mac->SetKey((const byte *)key.c_str(), key.size(), pairs);
00377 }
00378
00379 if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00380 {
00381 int digestSize = pHash->DigestSize();
00382 if (test == "VerifyTruncated")
00383 digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00384 TruncatedHashModule thash(*pHash, digestSize);
00385 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00386 PutDecodedDatumInto(v, "Digest", verifierFilter);
00387 PutDecodedDatumInto(v, "Message", verifierFilter);
00388 verifierFilter.MessageEnd();
00389 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00390 SignalTestFailure();
00391 }
00392 else
00393 {
00394 SignalTestError();
00395 assert(false);
00396 }
00397 }
00398
00399 bool GetField(std::istream &is, std::string &name, std::string &value)
00400 {
00401 name.resize(0);
00402 is >> name;
00403 if (name.empty())
00404 return false;
00405
00406 if (name[name.size()-1] != ':')
00407 SignalTestError();
00408 name.erase(name.size()-1);
00409
00410 while (is.peek() == ' ')
00411 is.ignore(1);
00412
00413
00414 char buffer[128];
00415 value.resize(0);
00416 bool continueLine;
00417
00418 do
00419 {
00420 do
00421 {
00422 is.get(buffer, sizeof(buffer));
00423 value += buffer;
00424 }
00425 while (buffer[0] != 0);
00426 is.clear();
00427 is.ignore();
00428
00429 if (!value.empty() && value[value.size()-1] == '\r')
00430 value.resize(value.size()-1);
00431
00432 if (!value.empty() && value[value.size()-1] == '\\')
00433 {
00434 value.resize(value.size()-1);
00435 continueLine = true;
00436 }
00437 else
00438 continueLine = false;
00439
00440 std::string::size_type i = value.find('#');
00441 if (i != std::string::npos)
00442 value.erase(i);
00443 }
00444 while (continueLine);
00445
00446 return true;
00447 }
00448
00449 void OutputPair(const NameValuePairs &v, const char *name)
00450 {
00451 Integer x;
00452 bool b = v.GetValue(name, x);
00453 assert(b);
00454 cout << name << ": \\\n ";
00455 x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize());
00456 cout << endl;
00457 }
00458
00459 void OutputNameValuePairs(const NameValuePairs &v)
00460 {
00461 std::string names = v.GetValueNames();
00462 string::size_type i = 0;
00463 while (i < names.size())
00464 {
00465 string::size_type j = names.find_first_of (';', i);
00466
00467 if (j == string::npos)
00468 return;
00469 else
00470 {
00471 std::string name = names.substr(i, j-i);
00472 if (name.find(':') == string::npos)
00473 OutputPair(v, name.c_str());
00474 }
00475
00476 i = j + 1;
00477 }
00478 }
00479
00480 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00481 {
00482 std::ifstream file(filename.c_str());
00483 if (!file.good())
00484 throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
00485 TestData v;
00486 s_currentTestData = &v;
00487 std::string name, value, lastAlgName;
00488
00489 while (file)
00490 {
00491 while (file.peek() == '#')
00492 file.ignore(INT_MAX, '\n');
00493
00494 if (file.peek() == '\n')
00495 v.clear();
00496
00497 if (!GetField(file, name, value))
00498 break;
00499 v[name] = value;
00500
00501 if (name == "Test")
00502 {
00503 bool failed = true;
00504 std::string algType = GetRequiredDatum(v, "AlgorithmType");
00505
00506 if (lastAlgName != GetRequiredDatum(v, "Name"))
00507 {
00508 lastAlgName = GetRequiredDatum(v, "Name");
00509 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00510 }
00511
00512 try
00513 {
00514 if (algType == "Signature")
00515 TestSignatureScheme(v);
00516 else if (algType == "SymmetricCipher")
00517 TestSymmetricCipher(v);
00518 else if (algType == "AsymmetricCipher")
00519 TestAsymmetricCipher(v);
00520 else if (algType == "MessageDigest")
00521 TestDigestOrMAC(v, true);
00522 else if (algType == "MAC")
00523 TestDigestOrMAC(v, false);
00524 else if (algType == "FileList")
00525 TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00526 else
00527 SignalTestError();
00528 failed = false;
00529 }
00530 catch (TestFailure &)
00531 {
00532 cout << "\nTest failed.\n";
00533 }
00534 catch (CryptoPP::Exception &e)
00535 {
00536 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00537 }
00538 catch (std::exception &e)
00539 {
00540 cout << "\nstd::exception caught: " << e.what() << endl;
00541 }
00542
00543 if (failed)
00544 {
00545 cout << "Skipping to next test.\n";
00546 failedTests++;
00547 }
00548 else
00549 cout << "." << flush;
00550
00551 totalTests++;
00552 }
00553 }
00554 }
00555
00556 bool RunTestDataFile(const char *filename)
00557 {
00558 unsigned int totalTests = 0, failedTests = 0;
00559 TestDataFile(filename, totalTests, failedTests);
00560 cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00561 if (failedTests != 0)
00562 cout << "SOME TESTS FAILED!\n";
00563 return failedTests == 0;
00564 }