datatest.cpp

00001 #include "factory.h"
00002 #include "integer.h"
00003 #include "filters.h"
00004 #include "hex.h"
00005 #include "randpool.h"
00006 #include "files.h"
00007 #include "trunhash.h"
00008 #include <iostream>
00009 #include <memory>
00010 
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013 
00014 RandomPool & GlobalRNG();
00015 void RegisterFactories();
00016 
00017 typedef std::map<std::string, std::string> TestData;
00018 
00019 class TestFailure : public Exception
00020 {
00021 public:
00022         TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00023 };
00024 
00025 static const TestData *s_currentTestData = NULL;
00026 
00027 static void OutputTestData(const TestData &v)
00028 {
00029         for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030         {
00031                 cerr << i->first << ": " << i->second << endl;
00032         }
00033 }
00034 
00035 static void SignalTestFailure()
00036 {
00037         OutputTestData(*s_currentTestData);
00038         throw TestFailure();
00039 }
00040 
00041 static void SignalTestError()
00042 {
00043         OutputTestData(*s_currentTestData);
00044         throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00045 }
00046 
00047 class TestDataNameValuePairs : public NameValuePairs
00048 {
00049 public:
00050         TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00051 
00052         virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00053         {
00054                 TestData::const_iterator i = m_data.find(name);
00055                 if (i == m_data.end())
00056                         return false;
00057                 
00058                 const std::string &value = i->second;
00059                 
00060                 if (valueType == typeid(int))
00061                         *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062                 else if (valueType == typeid(Integer))
00063                         *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00064                 else if (valueType == typeid(ConstByteArrayParameter))
00065                 {
00066                         m_temp.resize(0);
00067                         StringSource(value, true, new HexDecoder(new StringSink(m_temp)));
00068                         reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true);
00069                 }
00070                 else if (valueType == typeid(const byte *))
00071                 {
00072                         m_temp.resize(0);
00073                         StringSource(value, true, new HexDecoder(new StringSink(m_temp)));
00074                         *reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data();
00075                 }
00076                 else
00077                         throw ValueTypeMismatch(name, typeid(std::string), valueType);
00078 
00079                 return true;
00080         }
00081 
00082 private:
00083         const TestData &m_data;
00084         mutable std::string m_temp;
00085 };
00086 
00087 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00088 {
00089         TestData::const_iterator i = data.find(name);
00090         if (i == data.end())
00091                 SignalTestError();
00092         return i->second;
00093 }
00094 
00095 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00096 {
00097         std::string s1 = GetRequiredDatum(data, name), s2;
00098 
00099         int repeat = 1;
00100         if (s1[0] == 'r')
00101         {
00102                 repeat = atoi(s1.c_str()+1);
00103                 s1 = s1.substr(s1.find(' ')+1);
00104         }
00105         
00106         if (s1[0] == '\"')
00107                 s2 = s1.substr(1, s1.find('\"', 1)-1);
00108         else if (s1.substr(0, 2) == "0x")
00109                 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2)));
00110         else
00111                 StringSource(s1, true, new HexDecoder(new StringSink(s2)));
00112 
00113         while (repeat--)
00114                 target.Put((const byte *)s2.data(), s2.size());
00115 }
00116 
00117 std::string GetDecodedDatum(const TestData &data, const char *name)
00118 {
00119         std::string s;
00120         PutDecodedDatumInto(data, name, StringSink(s).Ref());
00121         return s;
00122 }
00123 
00124 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00125 {
00126         if (!pub.Validate(GlobalRNG(), 3))
00127                 SignalTestFailure();
00128         if (!priv.Validate(GlobalRNG(), 3))
00129                 SignalTestFailure();
00130 
00131 /*      EqualityComparisonFilter comparison;
00132         pub.Save(ChannelSwitch(comparison, "0"));
00133         pub.AssignFrom(priv);
00134         pub.Save(ChannelSwitch(comparison, "1"));
00135         comparison.ChannelMessageSeriesEnd("0");
00136         comparison.ChannelMessageSeriesEnd("1");
00137 */
00138 }
00139 
00140 void TestSignatureScheme(TestData &v)
00141 {
00142         std::string name = GetRequiredDatum(v, "Name");
00143         std::string test = GetRequiredDatum(v, "Test");
00144 
00145         std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00146         std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00147 
00148         TestDataNameValuePairs pairs(v);
00149         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00150 
00151         if (keyFormat == "DER")
00152                 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00153         else if (keyFormat == "Component")
00154                 verifier->AccessMaterial().AssignFrom(pairs);
00155 
00156         if (test == "Verify" || test == "NotVerify")
00157         {
00158                 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00159                 PutDecodedDatumInto(v, "Signature", verifierFilter);
00160                 PutDecodedDatumInto(v, "Message", verifierFilter);
00161                 verifierFilter.MessageEnd();
00162                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00163                         SignalTestFailure();
00164         }
00165         else if (test == "PublicKeyValid")
00166         {
00167                 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00168                         SignalTestFailure();
00169         }
00170         else
00171                 goto privateKeyTests;
00172 
00173         return;
00174 
00175 privateKeyTests:
00176         if (keyFormat == "DER")
00177                 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00178         else if (keyFormat == "Component")
00179                 signer->AccessMaterial().AssignFrom(pairs);
00180         
00181         if (test == "KeyPairValidAndConsistent")
00182         {
00183                 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00184         }
00185         else if (test == "Sign")
00186         {
00187                 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00188                 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00189                 SignalTestFailure();
00190         }
00191         else if (test == "DeterministicSign")
00192         {
00193                 SignalTestError();
00194                 assert(false);  // TODO: implement
00195         }
00196         else if (test == "RandomSign")
00197         {
00198                 SignalTestError();
00199                 assert(false);  // TODO: implement
00200         }
00201         else if (test == "GenerateKey")
00202         {
00203                 SignalTestError();
00204                 assert(false);
00205         }
00206         else
00207         {
00208                 SignalTestError();
00209                 assert(false);
00210         }
00211 }
00212 
00213 void TestAsymmetricCipher(TestData &v)
00214 {
00215         std::string name = GetRequiredDatum(v, "Name");
00216         std::string test = GetRequiredDatum(v, "Test");
00217 
00218         std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00219         std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00220 
00221         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00222 
00223         if (keyFormat == "DER")
00224         {
00225                 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00226                 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00227         }
00228         else if (keyFormat == "Component")
00229         {
00230                 TestDataNameValuePairs pairs(v);
00231                 decryptor->AccessMaterial().AssignFrom(pairs);
00232                 encryptor->AccessMaterial().AssignFrom(pairs);
00233         }
00234 
00235         if (test == "DecryptMatch")
00236         {
00237                 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00238                 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00239                 if (decrypted != expected)
00240                         SignalTestFailure();
00241         }
00242         else if (test == "KeyPairValidAndConsistent")
00243         {
00244                 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00245         }
00246         else
00247         {
00248                 SignalTestError();
00249                 assert(false);
00250         }
00251 }
00252 
00253 void TestSymmetricCipher(TestData &v)
00254 {
00255         std::string name = GetRequiredDatum(v, "Name");
00256         std::string test = GetRequiredDatum(v, "Test");
00257 
00258         std::string key = GetDecodedDatum(v, "Key");
00259         std::string ciphertext = GetDecodedDatum(v, "Ciphertext");
00260         std::string plaintext = GetDecodedDatum(v, "Plaintext");
00261 
00262         TestDataNameValuePairs pairs(v);
00263 
00264         if (test == "Encrypt")
00265         {
00266                 std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
00267                 ConstByteArrayParameter iv;
00268                 if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
00269                         SignalTestFailure();
00270                 encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00271                 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00272                 if (seek)
00273                         encryptor->Seek(seek);
00274                 std::string encrypted;
00275                 StringSource ss(plaintext, true, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING));
00276                 if (encrypted != ciphertext)
00277                         SignalTestFailure();
00278         }
00279         else if (test == "Decrypt")
00280         {
00281                 std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
00282                 ConstByteArrayParameter iv;
00283                 if (pairs.GetValue(Name::IV(), iv) && iv.size() != decryptor->IVSize())
00284                         SignalTestFailure();
00285                 decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
00286                 int seek = pairs.GetIntValueWithDefault("Seek", 0);
00287                 if (seek)
00288                         decryptor->Seek(seek);
00289                 std::string decrypted;
00290                 StringSource ss(ciphertext, true, new StreamTransformationFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING));
00291                 if (decrypted != plaintext)
00292                         SignalTestFailure();
00293         }
00294         else
00295         {
00296                 SignalTestError();
00297                 assert(false);
00298         }
00299 }
00300 
00301 void TestDigestOrMAC(TestData &v, bool testDigest)
00302 {
00303         std::string name = GetRequiredDatum(v, "Name");
00304         std::string test = GetRequiredDatum(v, "Test");
00305 
00306         member_ptr<MessageAuthenticationCode> mac;
00307         member_ptr<HashTransformation> hash;
00308         HashTransformation *pHash = NULL;
00309 
00310         if (testDigest)
00311         {
00312                 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00313                 pHash = hash.get();
00314         }
00315         else
00316         {
00317                 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00318                 pHash = mac.get();
00319                 std::string key = GetDecodedDatum(v, "Key");
00320                 mac->SetKey((const byte *)key.c_str(), key.size());
00321         }
00322 
00323         if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00324         {
00325                 int digestSize = pHash->DigestSize();
00326                 if (test == "VerifyTruncated")
00327                         digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00328                 TruncatedHashModule thash(*pHash, digestSize);
00329                 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00330                 PutDecodedDatumInto(v, "Digest", verifierFilter);
00331                 PutDecodedDatumInto(v, "Message", verifierFilter);
00332                 verifierFilter.MessageEnd();
00333                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00334                         SignalTestFailure();
00335         }
00336         else
00337         {
00338                 SignalTestError();
00339                 assert(false);
00340         }
00341 }
00342 
00343 bool GetField(std::istream &is, std::string &name, std::string &value)
00344 {
00345         name.resize(0);         // GCC workaround: 2.95.3 doesn't have clear()
00346         is >> name;
00347         if (name.empty())
00348                 return false;
00349 
00350         if (name[name.size()-1] != ':')
00351                 SignalTestError();
00352         name.erase(name.size()-1);
00353 
00354         while (is.peek() == ' ')
00355                 is.ignore(1);
00356 
00357         // VC60 workaround: getline bug
00358         char buffer[128];
00359         value.resize(0);        // GCC workaround: 2.95.3 doesn't have clear()
00360         bool continueLine;
00361 
00362         do
00363         {
00364                 do
00365                 {
00366                         is.get(buffer, sizeof(buffer));
00367                         value += buffer;
00368                 }
00369                 while (buffer[0] != 0);
00370                 is.clear();
00371                 is.ignore();
00372 
00373                 if (!value.empty() && value[value.size()-1] == '\r')
00374                         value.resize(value.size()-1);
00375 
00376                 if (!value.empty() && value[value.size()-1] == '\\')
00377                 {
00378                         value.resize(value.size()-1);
00379                         continueLine = true;
00380                 }
00381                 else
00382                         continueLine = false;
00383 
00384                 std::string::size_type i = value.find('#');
00385                 if (i != std::string::npos)
00386                         value.erase(i);
00387         }
00388         while (continueLine);
00389 
00390         return true;
00391 }
00392 
00393 void OutputPair(const NameValuePairs &v, const char *name)
00394 {
00395         Integer x;
00396         bool b = v.GetValue(name, x);
00397         assert(b);
00398         cout << name << ": \\\n    ";
00399         x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n    ").Ref(), x.MinEncodedSize());
00400         cout << endl;
00401 }
00402 
00403 void OutputNameValuePairs(const NameValuePairs &v)
00404 {
00405         std::string names = v.GetValueNames();
00406         string::size_type i = 0;
00407         while (i < names.size())
00408         {
00409                 string::size_type j = names.find_first_of (';', i);
00410 
00411                 if (j == string::npos)
00412                         return;
00413                 else
00414                 {
00415                         std::string name = names.substr(i, j-i);
00416                         if (name.find(':') == string::npos)
00417                                 OutputPair(v, name.c_str());
00418                 }
00419 
00420                 i = j + 1;
00421         }
00422 }
00423 
00424 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00425 {
00426         std::ifstream file(filename.c_str());
00427         if (!file.good())
00428                 throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
00429         TestData v;
00430         s_currentTestData = &v;
00431         std::string name, value, lastAlgName;
00432 
00433         while (file)
00434         {
00435                 while (file.peek() == '#')
00436                         file.ignore(INT_MAX, '\n');
00437 
00438                 if (file.peek() == '\n')
00439                         v.clear();
00440 
00441                 if (!GetField(file, name, value))
00442                         break;
00443                 v[name] = value;
00444 
00445                 if (name == "Test")
00446                 {
00447                         bool failed = true;
00448                         std::string algType = GetRequiredDatum(v, "AlgorithmType");
00449 
00450                         if (lastAlgName != GetRequiredDatum(v, "Name"))
00451                         {
00452                                 lastAlgName = GetRequiredDatum(v, "Name");
00453                                 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00454                         }
00455 
00456                         try
00457                         {
00458                                 if (algType == "Signature")
00459                                         TestSignatureScheme(v);
00460                                 else if (algType == "SymmetricCipher")
00461                                         TestSymmetricCipher(v);
00462                                 else if (algType == "AsymmetricCipher")
00463                                         TestAsymmetricCipher(v);
00464                                 else if (algType == "MessageDigest")
00465                                         TestDigestOrMAC(v, true);
00466                                 else if (algType == "MAC")
00467                                         TestDigestOrMAC(v, false);
00468                                 else if (algType == "FileList")
00469                                         TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00470                                 else
00471                                         SignalTestError();
00472                                 failed = false;
00473                         }
00474                         catch (TestFailure &)
00475                         {
00476                                 cout << "\nTest failed.\n";
00477                         }
00478                         catch (CryptoPP::Exception &e)
00479                         {
00480                                 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00481                         }
00482                         catch (std::exception &e)
00483                         {
00484                                 cout << "\nstd::exception caught: " << e.what() << endl;
00485                         }
00486 
00487                         if (failed)
00488                         {
00489                                 cout << "Skipping to next test.\n";
00490                                 failedTests++;
00491                         }
00492                         else
00493                                 cout << "." << flush;
00494 
00495                         totalTests++;
00496                 }
00497         }
00498 }
00499 
00500 bool RunTestDataFile(const char *filename)
00501 {
00502         RegisterFactories();
00503         unsigned int totalTests = 0, failedTests = 0;
00504         TestDataFile(filename, totalTests, failedTests);
00505         cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00506         if (failedTests != 0)
00507                 cout << "SOME TESTS FAILED!\n";
00508         return failedTests == 0;
00509 }

Generated on Sat Dec 23 02:07:06 2006 for Crypto++ by  doxygen 1.5.1-p1