datatest.cpp

00001 #include "factory.h"
00002 #include "integer.h"
00003 #include "filters.h"
00004 #include "hex.h"
00005 #include "randpool.h"
00006 #include "files.h"
00007 #include "trunhash.h"
00008 #include "queue.h"
00009 #include "validate.h"
00010 #include <iostream>
00011 #include <memory>
00012 
00013 USING_NAMESPACE(CryptoPP)
00014 USING_NAMESPACE(std)
00015 
// One test vector: each field name read from a test data file (without its trailing ':')
// mapped to the field's raw, not-yet-decoded value text.
typedef std::map<std::string, std::string, std::less<std::string> > TestData;
00017 
00018 class TestFailure : public Exception
00019 {
00020 public:
00021         TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00022 };
00023 
// The test vector TestDataFile is currently processing. SignalTestFailure and SignalTestError
// dump its fields to cerr before throwing. NULL until a test data file is being read.
static const TestData *s_currentTestData = NULL;
00025 
00026 static void OutputTestData(const TestData &v)
00027 {
00028         for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00029         {
00030                 cerr << i->first << ": " << i->second << endl;
00031         }
00032 }
00033 
00034 static void SignalTestFailure()
00035 {
00036         OutputTestData(*s_currentTestData);
00037         throw TestFailure();
00038 }
00039 
00040 static void SignalTestError()
00041 {
00042         OutputTestData(*s_currentTestData);
00043         throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00044 }
00045 
00046 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00047 {
00048         TestData::const_iterator i = data.find(name);
00049         if (i == data.end())
00050                 SignalTestError();
00051         return i->second;
00052 }
00053 
00054 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00055 {
00056         std::string s1 = GetRequiredDatum(data, name), s2;
00057 
00058         while (!s1.empty())
00059         {
00060                 while (s1[0] == ' ')
00061                         s1 = s1.substr(1);
00062 
00063                 int repeat = 1;
00064                 if (s1[0] == 'r')
00065                 {
00066                         repeat = atoi(s1.c_str()+1);
00067                         s1 = s1.substr(s1.find(' ')+1);
00068                 }
00069                 
00070                 s2 = ""; // MSVC 6 doesn't have clear();
00071 
00072                 if (s1[0] == '\"')
00073                 {
00074                         s2 = s1.substr(1, s1.find('\"', 1)-1);
00075                         s1 = s1.substr(s2.length() + 2);
00076                 }
00077                 else if (s1.substr(0, 2) == "0x")
00078                 {
00079                         StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
00080                         s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
00081                 }
00082                 else
00083                 {
00084                         StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
00085                         s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
00086                 }
00087 
00088                 ByteQueue q;
00089                 while (repeat--)
00090                 {
00091                         q.Put((const byte *)s2.data(), s2.size());
00092                         if (q.MaxRetrievable() > 4*1024 || repeat == 0)
00093                                 q.TransferTo(target);
00094                 }
00095         }
00096 }
00097 
00098 std::string GetDecodedDatum(const TestData &data, const char *name)
00099 {
00100         std::string s;
00101         PutDecodedDatumInto(data, name, StringSink(s).Ref());
00102         return s;
00103 }
00104 
00105 class TestDataNameValuePairs : public NameValuePairs
00106 {
00107 public:
00108         TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00109 
00110         virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00111         {
00112                 TestData::const_iterator i = m_data.find(name);
00113                 if (i == m_data.end())
00114                         return false;
00115                 
00116                 const std::string &value = i->second;
00117                 
00118                 if (valueType == typeid(int))
00119                         *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00120                 else if (valueType == typeid(Integer))
00121                         *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00122                 else if (valueType == typeid(ConstByteArrayParameter))
00123                 {
00124                         m_temp.resize(0);
00125                         PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
00126                         reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true);
00127                 }
00128                 else if (valueType == typeid(const byte *))
00129                 {
00130                         m_temp.resize(0);
00131                         PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
00132                         *reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data();
00133                 }
00134                 else
00135                         throw ValueTypeMismatch(name, typeid(std::string), valueType);
00136 
00137                 return true;
00138         }
00139 
00140 private:
00141         const TestData &m_data;
00142         mutable std::string m_temp;
00143 };
00144 
00145 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00146 {
00147         if (!pub.Validate(GlobalRNG(), 3))
00148                 SignalTestFailure();
00149         if (!priv.Validate(GlobalRNG(), 3))
00150                 SignalTestFailure();
00151 
00152 /*      EqualityComparisonFilter comparison;
00153         pub.Save(ChannelSwitch(comparison, "0"));
00154         pub.AssignFrom(priv);
00155         pub.Save(ChannelSwitch(comparison, "1"));
00156         comparison.ChannelMessageSeriesEnd("0");
00157         comparison.ChannelMessageSeriesEnd("1");
00158 */
00159 }
00160 
00161 void TestSignatureScheme(TestData &v)
00162 {
00163         std::string name = GetRequiredDatum(v, "Name");
00164         std::string test = GetRequiredDatum(v, "Test");
00165 
00166         std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00167         std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00168 
00169         TestDataNameValuePairs pairs(v);
00170         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00171 
00172         if (keyFormat == "DER")
00173                 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00174         else if (keyFormat == "Component")
00175                 verifier->AccessMaterial().AssignFrom(pairs);
00176 
00177         if (test == "Verify" || test == "NotVerify")
00178         {
00179                 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00180                 PutDecodedDatumInto(v, "Signature", verifierFilter);
00181                 PutDecodedDatumInto(v, "Message", verifierFilter);
00182                 verifierFilter.MessageEnd();
00183                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00184                         SignalTestFailure();
00185         }
00186         else if (test == "PublicKeyValid")
00187         {
00188                 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00189                         SignalTestFailure();
00190         }
00191         else
00192                 goto privateKeyTests;
00193 
00194         return;
00195 
00196 privateKeyTests:
00197         if (keyFormat == "DER")
00198                 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00199         else if (keyFormat == "Component")
00200                 signer->AccessMaterial().AssignFrom(pairs);
00201         
00202         if (test == "KeyPairValidAndConsistent")
00203         {
00204                 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00205         }
00206         else if (test == "Sign")
00207         {
00208                 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00209                 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00210                 SignalTestFailure();
00211         }
00212         else if (test == "DeterministicSign")
00213         {
00214                 SignalTestError();
00215                 assert(false);  // TODO: implement
00216         }
00217         else if (test == "RandomSign")
00218         {
00219                 SignalTestError();
00220                 assert(false);  // TODO: implement
00221         }
00222         else if (test == "GenerateKey")
00223         {
00224                 SignalTestError();
00225                 assert(false);
00226         }
00227         else
00228         {
00229                 SignalTestError();
00230                 assert(false);
00231         }
00232 }
00233 
00234 void TestAsymmetricCipher(TestData &v)
00235 {
00236         std::string name = GetRequiredDatum(v, "Name");
00237         std::string test = GetRequiredDatum(v, "Test");
00238 
00239         std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00240         std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00241 
00242         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00243 
00244         if (keyFormat == "DER")
00245         {
00246                 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00247                 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00248         }
00249         else if (keyFormat == "Component")
00250         {
00251                 TestDataNameValuePairs pairs(v);
00252                 decryptor->AccessMaterial().AssignFrom(pairs);
00253                 encryptor->AccessMaterial().AssignFrom(pairs);
00254         }
00255 
00256         if (test == "DecryptMatch")
00257         {
00258                 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00259                 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00260                 if (decrypted != expected)
00261                         SignalTestFailure();
00262         }
00263         else if (test == "KeyPairValidAndConsistent")
00264         {
00265                 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00266         }
00267         else
00268         {
00269                 SignalTestError();
00270                 assert(false);
00271         }
00272 }
00273 
00274 void TestSymmetricCipher(TestData &v)
00275 {
00276         std::string name = GetRequiredDatum(v, "Name");
00277         std::string test = GetRequiredDatum(v, "Test");
00278 
00279         std::string key = GetDecodedDatum(v, "Key");
00280         std::string plaintext = GetDecodedDatum(v, "Plaintext");
00281 
00282         TestDataNameValuePairs pairs(v);
00283 
00284         if (test == "Encrypt" || test == "EncryptXorDigest")
00285         {
00286                 std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
00287                 std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
00288                 ConstByteArrayParameter iv;
00289                 if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
00290                         SignalTestFailure();
00291                 encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00292                 decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00293                 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00294                 if (seek)
00295                 {
00296                         encryptor->Seek(seek);
00297                         decryptor->Seek(seek);
00298                 }
00299                 std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
00300                 StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING));
00301                 ss.Pump(plaintext.size()/2 + 1);
00302                 ss.PumpAll();
00303                 /*{
00304                         std::string z;
00305                         encryptor->Seek(seek);
00306                         StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(z), StreamTransformationFilter::NO_PADDING));
00307                         while (ss.Pump(64)) {}
00308                         ss.PumpAll();
00309                         for (int i=0; i<z.length(); i++)
00310                                 assert(encrypted[i] == z[i]);
00311                 }*/
00312                 if (test == "Encrypt")
00313                         ciphertext = GetDecodedDatum(v, "Ciphertext");
00314                 else
00315                 {
00316                         ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest");
00317                         xorDigest.append(encrypted, 0, 64);
00318                         for (size_t i=64; i<encrypted.size(); i++)
00319                                 xorDigest[i%64] ^= encrypted[i];
00320                 }
00321                 if (test == "Encrypt" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
00322                 {
00323                         std::cout << "incorrectly encrypted: ";
00324                         StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
00325                         xx.Pump(256); xx.Flush(false);
00326                         std::cout << "\n";
00327                         SignalTestFailure();
00328                 }
00329                 std::string decrypted;
00330                 StringSource dd(encrypted, false, new StreamTransformationFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING));
00331                 dd.Pump(plaintext.size()/2 + 1);
00332                 dd.PumpAll();
00333                 if (decrypted != plaintext)
00334                 {
00335                         std::cout << "incorrectly decrypted: ";
00336                         StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
00337                         xx.Pump(256); xx.Flush(false);
00338                         std::cout << "\n";
00339                         SignalTestFailure();
00340                 }
00341         }
00342         else if (test == "Decrypt")
00343         {
00344         }
00345         else
00346         {
00347                 SignalTestError();
00348                 assert(false);
00349         }
00350 }
00351 
00352 void TestDigestOrMAC(TestData &v, bool testDigest)
00353 {
00354         std::string name = GetRequiredDatum(v, "Name");
00355         std::string test = GetRequiredDatum(v, "Test");
00356 
00357         member_ptr<MessageAuthenticationCode> mac;
00358         member_ptr<HashTransformation> hash;
00359         HashTransformation *pHash = NULL;
00360 
00361         TestDataNameValuePairs pairs(v);
00362 
00363         if (testDigest)
00364         {
00365                 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00366                 pHash = hash.get();
00367         }
00368         else
00369         {
00370                 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00371                 pHash = mac.get();
00372                 ConstByteArrayParameter iv;
00373                 if (pairs.GetValue(Name::IV(), iv) && iv.size() != mac->IVSize())
00374                         SignalTestFailure();
00375                 std::string key = GetDecodedDatum(v, "Key");
00376                 mac->SetKey((const byte *)key.c_str(), key.size(), pairs);
00377         }
00378 
00379         if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00380         {
00381                 int digestSize = pHash->DigestSize();
00382                 if (test == "VerifyTruncated")
00383                         digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00384                 TruncatedHashModule thash(*pHash, digestSize);
00385                 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00386                 PutDecodedDatumInto(v, "Digest", verifierFilter);
00387                 PutDecodedDatumInto(v, "Message", verifierFilter);
00388                 verifierFilter.MessageEnd();
00389                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00390                         SignalTestFailure();
00391         }
00392         else
00393         {
00394                 SignalTestError();
00395                 assert(false);
00396         }
00397 }
00398 
// Reads one "Name: value" field from a test data stream.
// On success `name` holds the field name without its trailing ':' and `value` holds the value
// text. A value may span several physical lines: a line ending in '\' continues on the next
// line (the '\' itself is removed). On each line a trailing '\r' is dropped, and everything from
// the first '#' onward is discarded as a comment.
// Returns false when no further field name can be read (e.g. end of input). A name that does not
// end in ':' is reported via SignalTestError, which throws.
bool GetField(std::istream &is, std::string &name, std::string &value)
{
	name.resize(0);		// GCC workaround: 2.95.3 doesn't have clear()
	is >> name;
	if (name.empty())
		return false;

	if (name[name.size()-1] != ':')
		SignalTestError();
	name.erase(name.size()-1);

	// Skip the spaces between the ':' and the start of the value.
	while (is.peek() == ' ')
		is.ignore(1);

	// VC60 workaround: getline bug
	// (so each line is read in pieces of up to 127 characters with get() instead)
	char buffer[128];
	value.resize(0);	// GCC workaround: 2.95.3 doesn't have clear()
	bool continueLine;

	do
	{
		// Append the rest of the current line. get() stops before '\n'; a call that extracts
		// nothing stores an empty string, which ends this inner loop.
		do
		{
			is.get(buffer, sizeof(buffer));
			value += buffer;
		}
		while (buffer[0] != 0);
		// The final, empty get() set failbit: clear the stream state, then consume the '\n'.
		is.clear();
		is.ignore();

		// Tolerate CRLF line endings.
		if (!value.empty() && value[value.size()-1] == '\r')
			value.resize(value.size()-1);

		// A trailing backslash joins the next line onto this value.
		if (!value.empty() && value[value.size()-1] == '\\')
		{
			value.resize(value.size()-1);
			continueLine = true;
		}
		else
			continueLine = false;

		// Strip a '#' comment. This runs after the continuation check, so a '\' followed by a
		// comment on the same line does not continue the value.
		std::string::size_type i = value.find('#');
		if (i != std::string::npos)
			value.erase(i);
	}
	while (continueLine);

	return true;
}
00448 
00449 void OutputPair(const NameValuePairs &v, const char *name)
00450 {
00451         Integer x;
00452         bool b = v.GetValue(name, x);
00453         assert(b);
00454         cout << name << ": \\\n    ";
00455         x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n    ").Ref(), x.MinEncodedSize());
00456         cout << endl;
00457 }
00458 
00459 void OutputNameValuePairs(const NameValuePairs &v)
00460 {
00461         std::string names = v.GetValueNames();
00462         string::size_type i = 0;
00463         while (i < names.size())
00464         {
00465                 string::size_type j = names.find_first_of (';', i);
00466 
00467                 if (j == string::npos)
00468                         return;
00469                 else
00470                 {
00471                         std::string name = names.substr(i, j-i);
00472                         if (name.find(':') == string::npos)
00473                                 OutputPair(v, name.c_str());
00474                 }
00475 
00476                 i = j + 1;
00477         }
00478 }
00479 
00480 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00481 {
00482         std::ifstream file(filename.c_str());
00483         if (!file.good())
00484                 throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
00485         TestData v;
00486         s_currentTestData = &v;
00487         std::string name, value, lastAlgName;
00488 
00489         while (file)
00490         {
00491                 while (file.peek() == '#')
00492                         file.ignore(INT_MAX, '\n');
00493 
00494                 if (file.peek() == '\n')
00495                         v.clear();
00496 
00497                 if (!GetField(file, name, value))
00498                         break;
00499                 v[name] = value;
00500 
00501                 if (name == "Test")
00502                 {
00503                         bool failed = true;
00504                         std::string algType = GetRequiredDatum(v, "AlgorithmType");
00505 
00506                         if (lastAlgName != GetRequiredDatum(v, "Name"))
00507                         {
00508                                 lastAlgName = GetRequiredDatum(v, "Name");
00509                                 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00510                         }
00511 
00512                         try
00513                         {
00514                                 if (algType == "Signature")
00515                                         TestSignatureScheme(v);
00516                                 else if (algType == "SymmetricCipher")
00517                                         TestSymmetricCipher(v);
00518                                 else if (algType == "AsymmetricCipher")
00519                                         TestAsymmetricCipher(v);
00520                                 else if (algType == "MessageDigest")
00521                                         TestDigestOrMAC(v, true);
00522                                 else if (algType == "MAC")
00523                                         TestDigestOrMAC(v, false);
00524                                 else if (algType == "FileList")
00525                                         TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00526                                 else
00527                                         SignalTestError();
00528                                 failed = false;
00529                         }
00530                         catch (TestFailure &)
00531                         {
00532                                 cout << "\nTest failed.\n";
00533                         }
00534                         catch (CryptoPP::Exception &e)
00535                         {
00536                                 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00537                         }
00538                         catch (std::exception &e)
00539                         {
00540                                 cout << "\nstd::exception caught: " << e.what() << endl;
00541                         }
00542 
00543                         if (failed)
00544                         {
00545                                 cout << "Skipping to next test.\n";
00546                                 failedTests++;
00547                         }
00548                         else
00549                                 cout << "." << flush;
00550 
00551                         totalTests++;
00552                 }
00553         }
00554 }
00555 
00556 bool RunTestDataFile(const char *filename)
00557 {
00558         unsigned int totalTests = 0, failedTests = 0;
00559         TestDataFile(filename, totalTests, failedTests);
00560         cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00561         if (failedTests != 0)
00562                 cout << "SOME TESTS FAILED!\n";
00563         return failedTests == 0;
00564 }

Generated on Fri Jun 1 11:11:20 2007 for Crypto++ by  doxygen 1.5.2