#include "factory.h"
#include "integer.h"
#include "filters.h"
#include "hex.h"
#include "randpool.h"
#include "files.h"
#include "trunhash.h"
#include <iostream>
#include <memory>
#include <string>
00010
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013
00014 RandomPool & GlobalRNG();
00015 void RegisterFactories();
00016
// One group of "Name: value" fields read from a test-vector file, keyed by field name.
typedef std::map<std::string, std::string, std::less<std::string> > TestData;
00018
00019 class TestFailure : public Exception
00020 {
00021 public:
00022 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00023 };
00024
00025 static const TestData *s_currentTestData = NULL;
00026
00027 static void OutputTestData(const TestData &v)
00028 {
00029 for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030 {
00031 cerr << i->first << ": " << i->second << endl;
00032 }
00033 }
00034
00035 static void SignalTestFailure()
00036 {
00037 OutputTestData(*s_currentTestData);
00038 throw TestFailure();
00039 }
00040
00041 static void SignalTestError()
00042 {
00043 OutputTestData(*s_currentTestData);
00044 throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00045 }
00046
00047 class TestDataNameValuePairs : public NameValuePairs
00048 {
00049 public:
00050 TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00051
00052 virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00053 {
00054 TestData::const_iterator i = m_data.find(name);
00055 if (i == m_data.end())
00056 return false;
00057
00058 const std::string &value = i->second;
00059
00060 if (valueType == typeid(int))
00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062 else if (valueType == typeid(Integer))
00063 *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00064 else if (valueType == typeid(ConstByteArrayParameter))
00065 {
00066 m_temp.resize(0);
00067 StringSource(value, true, new HexDecoder(new StringSink(m_temp)));
00068 reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true);
00069 }
00070 else if (valueType == typeid(const byte *))
00071 {
00072 m_temp.resize(0);
00073 StringSource(value, true, new HexDecoder(new StringSink(m_temp)));
00074 *reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data();
00075 }
00076 else
00077 throw ValueTypeMismatch(name, typeid(std::string), valueType);
00078
00079 return true;
00080 }
00081
00082 private:
00083 const TestData &m_data;
00084 mutable std::string m_temp;
00085 };
00086
00087 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00088 {
00089 TestData::const_iterator i = data.find(name);
00090 if (i == data.end())
00091 SignalTestError();
00092 return i->second;
00093 }
00094
00095 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00096 {
00097 std::string s1 = GetRequiredDatum(data, name), s2;
00098
00099 int repeat = 1;
00100 if (s1[0] == 'r')
00101 {
00102 repeat = atoi(s1.c_str()+1);
00103 s1 = s1.substr(s1.find(' ')+1);
00104 }
00105
00106 if (s1[0] == '\"')
00107 s2 = s1.substr(1, s1.find('\"', 1)-1);
00108 else if (s1.substr(0, 2) == "0x")
00109 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2)));
00110 else
00111 StringSource(s1, true, new HexDecoder(new StringSink(s2)));
00112
00113 while (repeat--)
00114 target.Put((const byte *)s2.data(), s2.size());
00115 }
00116
00117 std::string GetDecodedDatum(const TestData &data, const char *name)
00118 {
00119 std::string s;
00120 PutDecodedDatumInto(data, name, StringSink(s).Ref());
00121 return s;
00122 }
00123
00124 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00125 {
00126 if (!pub.Validate(GlobalRNG(), 3))
00127 SignalTestFailure();
00128 if (!priv.Validate(GlobalRNG(), 3))
00129 SignalTestFailure();
00130
00131
00132
00133
00134
00135
00136
00137
00138 }
00139
00140 void TestSignatureScheme(TestData &v)
00141 {
00142 std::string name = GetRequiredDatum(v, "Name");
00143 std::string test = GetRequiredDatum(v, "Test");
00144
00145 std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00146 std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00147
00148 TestDataNameValuePairs pairs(v);
00149 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00150
00151 if (keyFormat == "DER")
00152 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00153 else if (keyFormat == "Component")
00154 verifier->AccessMaterial().AssignFrom(pairs);
00155
00156 if (test == "Verify" || test == "NotVerify")
00157 {
00158 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00159 PutDecodedDatumInto(v, "Signature", verifierFilter);
00160 PutDecodedDatumInto(v, "Message", verifierFilter);
00161 verifierFilter.MessageEnd();
00162 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00163 SignalTestFailure();
00164 }
00165 else if (test == "PublicKeyValid")
00166 {
00167 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00168 SignalTestFailure();
00169 }
00170 else
00171 goto privateKeyTests;
00172
00173 return;
00174
00175 privateKeyTests:
00176 if (keyFormat == "DER")
00177 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00178 else if (keyFormat == "Component")
00179 signer->AccessMaterial().AssignFrom(pairs);
00180
00181 if (test == "KeyPairValidAndConsistent")
00182 {
00183 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00184 }
00185 else if (test == "Sign")
00186 {
00187 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00188 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00189 SignalTestFailure();
00190 }
00191 else if (test == "DeterministicSign")
00192 {
00193 SignalTestError();
00194 assert(false);
00195 }
00196 else if (test == "RandomSign")
00197 {
00198 SignalTestError();
00199 assert(false);
00200 }
00201 else if (test == "GenerateKey")
00202 {
00203 SignalTestError();
00204 assert(false);
00205 }
00206 else
00207 {
00208 SignalTestError();
00209 assert(false);
00210 }
00211 }
00212
00213 void TestAsymmetricCipher(TestData &v)
00214 {
00215 std::string name = GetRequiredDatum(v, "Name");
00216 std::string test = GetRequiredDatum(v, "Test");
00217
00218 std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00219 std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00220
00221 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00222
00223 if (keyFormat == "DER")
00224 {
00225 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00226 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00227 }
00228 else if (keyFormat == "Component")
00229 {
00230 TestDataNameValuePairs pairs(v);
00231 decryptor->AccessMaterial().AssignFrom(pairs);
00232 encryptor->AccessMaterial().AssignFrom(pairs);
00233 }
00234
00235 if (test == "DecryptMatch")
00236 {
00237 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00238 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00239 if (decrypted != expected)
00240 SignalTestFailure();
00241 }
00242 else if (test == "KeyPairValidAndConsistent")
00243 {
00244 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00245 }
00246 else
00247 {
00248 SignalTestError();
00249 assert(false);
00250 }
00251 }
00252
00253 void TestSymmetricCipher(TestData &v)
00254 {
00255 std::string name = GetRequiredDatum(v, "Name");
00256 std::string test = GetRequiredDatum(v, "Test");
00257
00258 std::string key = GetDecodedDatum(v, "Key");
00259 std::string ciphertext = GetDecodedDatum(v, "Ciphertext");
00260 std::string plaintext = GetDecodedDatum(v, "Plaintext");
00261
00262 TestDataNameValuePairs pairs(v);
00263
00264 if (test == "Encrypt")
00265 {
00266 std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
00267 ConstByteArrayParameter iv;
00268 if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
00269 SignalTestFailure();
00270 encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00271 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00272 if (seek)
00273 encryptor->Seek(seek);
00274 std::string encrypted;
00275 StringSource ss(plaintext, true, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING));
00276 if (encrypted != ciphertext)
00277 SignalTestFailure();
00278 }
00279 else if (test == "Decrypt")
00280 {
00281 std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
00282 ConstByteArrayParameter iv;
00283 if (pairs.GetValue(Name::IV(), iv) && iv.size() != decryptor->IVSize())
00284 SignalTestFailure();
00285 decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00286 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00287 if (seek)
00288 decryptor->Seek(seek);
00289 std::string decrypted;
00290 StringSource ss(ciphertext, true, new StreamTransformationFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING));
00291 if (decrypted != plaintext)
00292 SignalTestFailure();
00293 }
00294 else
00295 {
00296 SignalTestError();
00297 assert(false);
00298 }
00299 }
00300
00301 void TestDigestOrMAC(TestData &v, bool testDigest)
00302 {
00303 std::string name = GetRequiredDatum(v, "Name");
00304 std::string test = GetRequiredDatum(v, "Test");
00305
00306 member_ptr<MessageAuthenticationCode> mac;
00307 member_ptr<HashTransformation> hash;
00308 HashTransformation *pHash = NULL;
00309
00310 if (testDigest)
00311 {
00312 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00313 pHash = hash.get();
00314 }
00315 else
00316 {
00317 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00318 pHash = mac.get();
00319 std::string key = GetDecodedDatum(v, "Key");
00320 mac->SetKey((const byte *)key.c_str(), key.size());
00321 }
00322
00323 if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00324 {
00325 int digestSize = pHash->DigestSize();
00326 if (test == "VerifyTruncated")
00327 digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00328 TruncatedHashModule thash(*pHash, digestSize);
00329 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00330 PutDecodedDatumInto(v, "Digest", verifierFilter);
00331 PutDecodedDatumInto(v, "Message", verifierFilter);
00332 verifierFilter.MessageEnd();
00333 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00334 SignalTestFailure();
00335 }
00336 else
00337 {
00338 SignalTestError();
00339 assert(false);
00340 }
00341 }
00342
00343 bool GetField(std::istream &is, std::string &name, std::string &value)
00344 {
00345 name.resize(0);
00346 is >> name;
00347 if (name.empty())
00348 return false;
00349
00350 if (name[name.size()-1] != ':')
00351 SignalTestError();
00352 name.erase(name.size()-1);
00353
00354 while (is.peek() == ' ')
00355 is.ignore(1);
00356
00357
00358 char buffer[128];
00359 value.resize(0);
00360 bool continueLine;
00361
00362 do
00363 {
00364 do
00365 {
00366 is.get(buffer, sizeof(buffer));
00367 value += buffer;
00368 }
00369 while (buffer[0] != 0);
00370 is.clear();
00371 is.ignore();
00372
00373 if (!value.empty() && value[value.size()-1] == '\r')
00374 value.resize(value.size()-1);
00375
00376 if (!value.empty() && value[value.size()-1] == '\\')
00377 {
00378 value.resize(value.size()-1);
00379 continueLine = true;
00380 }
00381 else
00382 continueLine = false;
00383
00384 std::string::size_type i = value.find('#');
00385 if (i != std::string::npos)
00386 value.erase(i);
00387 }
00388 while (continueLine);
00389
00390 return true;
00391 }
00392
00393 void OutputPair(const NameValuePairs &v, const char *name)
00394 {
00395 Integer x;
00396 bool b = v.GetValue(name, x);
00397 assert(b);
00398 cout << name << ": \\\n ";
00399 x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize());
00400 cout << endl;
00401 }
00402
00403 void OutputNameValuePairs(const NameValuePairs &v)
00404 {
00405 std::string names = v.GetValueNames();
00406 string::size_type i = 0;
00407 while (i < names.size())
00408 {
00409 string::size_type j = names.find_first_of (';', i);
00410
00411 if (j == string::npos)
00412 return;
00413 else
00414 {
00415 std::string name = names.substr(i, j-i);
00416 if (name.find(':') == string::npos)
00417 OutputPair(v, name.c_str());
00418 }
00419
00420 i = j + 1;
00421 }
00422 }
00423
00424 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00425 {
00426 std::ifstream file(filename.c_str());
00427 if (!file.good())
00428 throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
00429 TestData v;
00430 s_currentTestData = &v;
00431 std::string name, value, lastAlgName;
00432
00433 while (file)
00434 {
00435 while (file.peek() == '#')
00436 file.ignore(INT_MAX, '\n');
00437
00438 if (file.peek() == '\n')
00439 v.clear();
00440
00441 if (!GetField(file, name, value))
00442 break;
00443 v[name] = value;
00444
00445 if (name == "Test")
00446 {
00447 bool failed = true;
00448 std::string algType = GetRequiredDatum(v, "AlgorithmType");
00449
00450 if (lastAlgName != GetRequiredDatum(v, "Name"))
00451 {
00452 lastAlgName = GetRequiredDatum(v, "Name");
00453 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00454 }
00455
00456 try
00457 {
00458 if (algType == "Signature")
00459 TestSignatureScheme(v);
00460 else if (algType == "SymmetricCipher")
00461 TestSymmetricCipher(v);
00462 else if (algType == "AsymmetricCipher")
00463 TestAsymmetricCipher(v);
00464 else if (algType == "MessageDigest")
00465 TestDigestOrMAC(v, true);
00466 else if (algType == "MAC")
00467 TestDigestOrMAC(v, false);
00468 else if (algType == "FileList")
00469 TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00470 else
00471 SignalTestError();
00472 failed = false;
00473 }
00474 catch (TestFailure &)
00475 {
00476 cout << "\nTest failed.\n";
00477 }
00478 catch (CryptoPP::Exception &e)
00479 {
00480 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00481 }
00482 catch (std::exception &e)
00483 {
00484 cout << "\nstd::exception caught: " << e.what() << endl;
00485 }
00486
00487 if (failed)
00488 {
00489 cout << "Skipping to next test.\n";
00490 failedTests++;
00491 }
00492 else
00493 cout << "." << flush;
00494
00495 totalTests++;
00496 }
00497 }
00498 }
00499
00500 bool RunTestDataFile(const char *filename)
00501 {
00502 RegisterFactories();
00503 unsigned int totalTests = 0, failedTests = 0;
00504 TestDataFile(filename, totalTests, failedTests);
00505 cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00506 if (failedTests != 0)
00507 cout << "SOME TESTS FAILED!\n";
00508 return failedTests == 0;
00509 }