00001
00002
00003 #include "pch.h"
00004 #include "ida.h"
00005
00006 #include "algebra.h"
00007 #include "gf2_32.h"
00008 #include "polynomi.h"
00009 #include <functional>
00010
00011 #include "polynomi.cpp"
00012
ANONYMOUS_NAMESPACE_BEGIN
// File-local GF2_32 field object. It is passed as the field argument to every
// polynomial interpolation helper called in this file
// (PrepareBulkPolynomialInterpolation, PrepareBulkPolynomialInterpolationAt,
// BulkPolynomialInterpolateAt).
static const CryptoPP::GF2_32 field;
NAMESPACE_END
00016
00017 using namespace std;
00018
00019 NAMESPACE_BEGIN(CryptoPP)
00020
00021 void RawIDA::IsolatedInitialize(const NameValuePairs ¶meters)
00022 {
00023 if (!parameters.GetIntValue("RecoveryThreshold", m_threshold))
00024 throw InvalidArgument("RawIDA: missing RecoveryThreshold argument");
00025
00026 if (m_threshold <= 0)
00027 throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0");
00028
00029 m_lastMapPosition = m_inputChannelMap.end();
00030 m_channelsReady = 0;
00031 m_channelsFinished = 0;
00032 m_w.New(m_threshold);
00033 m_y.New(m_threshold);
00034 m_inputQueues.reserve(m_threshold);
00035
00036 m_outputChannelIds.clear();
00037 m_outputChannelIdStrings.clear();
00038 m_outputQueues.clear();
00039
00040 word32 outputChannelID;
00041 if (parameters.GetValue("OutputChannelID", outputChannelID))
00042 AddOutputChannel(outputChannelID);
00043 else
00044 {
00045 int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold);
00046 for (int i=0; i<nShares; i++)
00047 AddOutputChannel(i);
00048 }
00049 }
00050
00051 unsigned int RawIDA::InsertInputChannel(word32 channelId)
00052 {
00053 if (m_lastMapPosition != m_inputChannelMap.end())
00054 {
00055 if (m_lastMapPosition->first == channelId)
00056 goto skipFind;
00057 ++m_lastMapPosition;
00058 if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId)
00059 goto skipFind;
00060 }
00061 m_lastMapPosition = m_inputChannelMap.find(channelId);
00062
00063 skipFind:
00064 if (m_lastMapPosition == m_inputChannelMap.end())
00065 {
00066 if (m_inputChannelIds.size() == m_threshold)
00067 return m_threshold;
00068
00069 m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first;
00070 m_inputQueues.push_back(MessageQueue());
00071 m_inputChannelIds.push_back(channelId);
00072
00073 if (m_inputChannelIds.size() == m_threshold)
00074 PrepareInterpolation();
00075 }
00076 return m_lastMapPosition->second;
00077 }
00078
00079 unsigned int RawIDA::LookupInputChannel(word32 channelId) const
00080 {
00081 map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId);
00082 if (it == m_inputChannelMap.end())
00083 return m_threshold;
00084 else
00085 return it->second;
00086 }
00087
00088 void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd)
00089 {
00090 int i = InsertInputChannel(channelId);
00091 if (i < m_threshold)
00092 {
00093 lword size = m_inputQueues[i].MaxRetrievable();
00094 m_inputQueues[i].Put(inString, length);
00095 if (size < 4 && size + length >= 4)
00096 {
00097 m_channelsReady++;
00098 if (m_channelsReady == m_threshold)
00099 ProcessInputQueues();
00100 }
00101
00102 if (messageEnd)
00103 {
00104 m_inputQueues[i].MessageEnd();
00105 if (m_inputQueues[i].NumberOfMessages() == 1)
00106 {
00107 m_channelsFinished++;
00108 if (m_channelsFinished == m_threshold)
00109 {
00110 m_channelsReady = 0;
00111 for (i=0; i<m_threshold; i++)
00112 m_channelsReady += m_inputQueues[i].AnyRetrievable();
00113 ProcessInputQueues();
00114 }
00115 }
00116 }
00117 }
00118 }
00119
00120 lword RawIDA::InputBuffered(word32 channelId) const
00121 {
00122 int i = LookupInputChannel(channelId);
00123 return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0;
00124 }
00125
00126 void RawIDA::ComputeV(unsigned int i)
00127 {
00128 if (i >= m_v.size())
00129 {
00130 m_v.resize(i+1);
00131 m_outputToInput.resize(i+1);
00132 }
00133
00134 m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]);
00135 if (m_outputToInput[i] == m_threshold && i * m_threshold <= 1000*1000)
00136 {
00137 m_v[i].resize(m_threshold);
00138 PrepareBulkPolynomialInterpolationAt(field, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
00139 }
00140 }
00141
00142 void RawIDA::AddOutputChannel(word32 channelId)
00143 {
00144 m_outputChannelIds.push_back(channelId);
00145 m_outputChannelIdStrings.push_back(WordToString(channelId));
00146 m_outputQueues.push_back(ByteQueue());
00147 if (m_inputChannelIds.size() == m_threshold)
00148 ComputeV((unsigned int)m_outputChannelIds.size() - 1);
00149 }
00150
00151 void RawIDA::PrepareInterpolation()
00152 {
00153 assert(m_inputChannelIds.size() == m_threshold);
00154 PrepareBulkPolynomialInterpolation(field, m_w.begin(), &(m_inputChannelIds[0]), m_threshold);
00155 for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
00156 ComputeV(i);
00157 }
00158
// Core processing loop. Each iteration reads one 32-bit word from each of the
// m_threshold input queues into m_y and writes one 32-bit word to every
// output queue.
//
// The mode is fixed on entry:
//  - streaming (m_channelsFinished != m_threshold): iterate while all
//    m_threshold inputs still have at least 4 bytes or a completed message;
//  - finished (m_channelsFinished == m_threshold): iterate while any input
//    still has data. Afterwards, emit message ends, reset per-message state,
//    and forward each input queue's remaining contents to the attached
//    transformation.
void RawIDA::ProcessInputQueues()
{
	bool finished = (m_channelsFinished == m_threshold);
	int i;

	while (finished ? m_channelsReady > 0 : m_channelsReady == m_threshold)
	{
		// Recount readiness while consuming this round's words.
		m_channelsReady = 0;
		for (i=0; i<m_threshold; i++)
		{
			MessageQueue &queue = m_inputQueues[i];
			queue.GetWord32(m_y[i]);

			if (finished)
				m_channelsReady += queue.AnyRetrievable();
			else
				m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4;
		}

		for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++)
		{
			// Output id equals an input id: copy that input's word unchanged.
			if (m_outputToInput[i] != m_threshold)
				m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]);
			// Interpolate using coefficients cached by ComputeV.
			else if (m_v[i].size() == m_threshold)
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_v[i].begin(), m_threshold));
			// Not cached (ComputeV's size limit): compute coefficients into the
			// scratch buffer m_u for this word.
			else
			{
				m_u.resize(m_threshold);
				PrepareBulkPolynomialInterpolationAt(field, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_u.begin(), m_threshold));
			}
		}
	}

	if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable())
		FlushOutputQueues();

	if (finished)
	{
		OutputMessageEnds();

		// Reset per-message state so the next message starts from scratch.
		m_channelsReady = 0;
		m_channelsFinished = 0;
		m_v.clear();

		// Move the current input queues and ids into locals, leaving the members
		// empty so channels are registered afresh for the next message.
		vector<MessageQueue> inputQueues;
		vector<word32> inputChannelIds;

		inputQueues.swap(m_inputQueues);
		inputChannelIds.swap(m_inputChannelIds);
		m_inputChannelMap.clear();
		m_lastMapPosition = m_inputChannelMap.end();

		// Advance each old queue past the message just processed and send what
		// remains to the attached transformation, on the channel named after
		// that input's id.
		// NOTE(review): leftovers go to AttachedTransformation(), not back into
		// this object's ChannelData -- confirm that is intended.
		for (i=0; i<m_threshold; i++)
		{
			inputQueues[i].GetNextMessage();
			inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i]));
		}
	}
}
00219
00220 void RawIDA::FlushOutputQueues()
00221 {
00222 for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
00223 m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]);
00224 }
00225
00226 void RawIDA::OutputMessageEnds()
00227 {
00228 if (GetAutoSignalPropagation() != 0)
00229 {
00230 for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
00231 AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1);
00232 }
00233 }
00234
00235
00236
00237 void SecretSharing::IsolatedInitialize(const NameValuePairs ¶meters)
00238 {
00239 m_pad = parameters.GetValueWithDefault("AddPadding", true);
00240 m_ida.IsolatedInitialize(parameters);
00241 }
00242
00243 size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
00244 {
00245 if (!blocking)
00246 throw BlockingInputOnly("SecretSharing");
00247
00248 SecByteBlock buf(UnsignedMin(256, length));
00249 unsigned int threshold = m_ida.GetThreshold();
00250 while (length > 0)
00251 {
00252 size_t len = STDMIN(length, buf.size());
00253 m_ida.ChannelData(0xffffffff, begin, len, false);
00254 for (unsigned int i=0; i<threshold-1; i++)
00255 {
00256 m_rng.GenerateBlock(buf, len);
00257 m_ida.ChannelData(i, buf, len, false);
00258 }
00259 length -= len;
00260 begin += len;
00261 }
00262
00263 if (messageEnd)
00264 {
00265 m_ida.SetAutoSignalPropagation(messageEnd-1);
00266 if (m_pad)
00267 {
00268 SecretSharing::Put(1);
00269 while (m_ida.InputBuffered(0xffffffff) > 0)
00270 SecretSharing::Put(0);
00271 }
00272 m_ida.ChannelData(0xffffffff, NULL, 0, true);
00273 for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++)
00274 m_ida.ChannelData(i, NULL, 0, true);
00275 }
00276
00277 return 0;
00278 }
00279
00280 void SecretRecovery::IsolatedInitialize(const NameValuePairs ¶meters)
00281 {
00282 m_pad = parameters.GetValueWithDefault("RemovePadding", true);
00283 RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff)));
00284 }
00285
00286 void SecretRecovery::FlushOutputQueues()
00287 {
00288 if (m_pad)
00289 m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4);
00290 else
00291 m_outputQueues[0].TransferTo(*AttachedTransformation());
00292 }
00293
00294 void SecretRecovery::OutputMessageEnds()
00295 {
00296 if (m_pad)
00297 {
00298 PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
00299 m_outputQueues[0].TransferAllTo(paddingRemover);
00300 }
00301
00302 if (GetAutoSignalPropagation() != 0)
00303 AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
00304 }
00305
00306
00307
00308 void InformationDispersal::IsolatedInitialize(const NameValuePairs ¶meters)
00309 {
00310 m_nextChannel = 0;
00311 m_pad = parameters.GetValueWithDefault("AddPadding", true);
00312 m_ida.IsolatedInitialize(parameters);
00313 }
00314
00315 size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
00316 {
00317 if (!blocking)
00318 throw BlockingInputOnly("InformationDispersal");
00319
00320 while (length--)
00321 {
00322 m_ida.ChannelData(m_nextChannel, begin, 1, false);
00323 begin++;
00324 m_nextChannel++;
00325 if (m_nextChannel == m_ida.GetThreshold())
00326 m_nextChannel = 0;
00327 }
00328
00329 if (messageEnd)
00330 {
00331 m_ida.SetAutoSignalPropagation(messageEnd-1);
00332 if (m_pad)
00333 InformationDispersal::Put(1);
00334 for (word32 i=0; i<m_ida.GetThreshold(); i++)
00335 m_ida.ChannelData(i, NULL, 0, true);
00336 }
00337
00338 return 0;
00339 }
00340
00341 void InformationRecovery::IsolatedInitialize(const NameValuePairs ¶meters)
00342 {
00343 m_pad = parameters.GetValueWithDefault("RemovePadding", true);
00344 RawIDA::IsolatedInitialize(parameters);
00345 }
00346
00347 void InformationRecovery::FlushOutputQueues()
00348 {
00349 while (m_outputQueues[0].AnyRetrievable())
00350 {
00351 for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
00352 m_outputQueues[i].TransferTo(m_queue, 1);
00353 }
00354
00355 if (m_pad)
00356 m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold);
00357 else
00358 m_queue.TransferTo(*AttachedTransformation());
00359 }
00360
00361 void InformationRecovery::OutputMessageEnds()
00362 {
00363 if (m_pad)
00364 {
00365 PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
00366 m_queue.TransferAllTo(paddingRemover);
00367 }
00368
00369 if (GetAutoSignalPropagation() != 0)
00370 AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
00371 }
00372
00373 size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
00374 {
00375 if (!blocking)
00376 throw BlockingInputOnly("PaddingRemover");
00377
00378 const byte *const end = begin + length;
00379
00380 if (m_possiblePadding)
00381 {
00382 size_t len = find_if(begin, end, bind2nd(not_equal_to<byte>(), 0)) - begin;
00383 m_zeroCount += len;
00384 begin += len;
00385 if (begin == end)
00386 return 0;
00387
00388 AttachedTransformation()->Put(1);
00389 while (m_zeroCount--)
00390 AttachedTransformation()->Put(0);
00391 AttachedTransformation()->Put(*begin++);
00392 m_possiblePadding = false;
00393 }
00394
00395 #if defined(_MSC_VER) && !defined(__MWERKS__) && (_MSC_VER <= 1300)
00396
00397 typedef reverse_bidirectional_iterator<const byte *, const byte> RevIt;
00398 #elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC)
00399 typedef reverse_iterator<const byte *, random_access_iterator_tag, const byte> RevIt;
00400 #else
00401 typedef reverse_iterator<const byte *> RevIt;
00402 #endif
00403 const byte *x = find_if(RevIt(end), RevIt(begin), bind2nd(not_equal_to<byte>(), 0)).base();
00404 if (x != begin && *(x-1) == 1)
00405 {
00406 AttachedTransformation()->Put(begin, x-begin-1);
00407 m_possiblePadding = true;
00408 m_zeroCount = end - x;
00409 }
00410 else
00411 AttachedTransformation()->Put(begin, end-begin);
00412
00413 if (messageEnd)
00414 {
00415 m_possiblePadding = false;
00416 Output(0, begin, length, messageEnd, blocking);
00417 }
00418 return 0;
00419 }
00420
00421 NAMESPACE_END