Crypto++ 5.6.3
Free C++ class library of cryptographic schemes
ecp.cpp
1 // ecp.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 #ifndef CRYPTOPP_IMPORTS
6 
7 #include "ecp.h"
8 #include "asn.h"
9 #include "integer.h"
10 #include "nbtheory.h"
11 #include "modarith.h"
12 #include "filters.h"
13 #include "algebra.cpp"
14 
15 NAMESPACE_BEGIN(CryptoPP)
16 
17 ANONYMOUS_NAMESPACE_BEGIN
18 static inline ECP::Point ToMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
19 {
20  return P.identity ? P : ECP::Point(mr.ConvertIn(P.x), mr.ConvertIn(P.y));
21 }
22 
23 static inline ECP::Point FromMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
24 {
25  return P.identity ? P : ECP::Point(mr.ConvertOut(P.x), mr.ConvertOut(P.y));
26 }
27 NAMESPACE_END
28 
29 ECP::ECP(const ECP &ecp, bool convertToMontgomeryRepresentation)
30 {
31  if (convertToMontgomeryRepresentation && !ecp.GetField().IsMontgomeryRepresentation())
32  {
33  m_fieldPtr.reset(new MontgomeryRepresentation(ecp.GetField().GetModulus()));
34  m_a = GetField().ConvertIn(ecp.m_a);
35  m_b = GetField().ConvertIn(ecp.m_b);
36  }
37  else
38  operator=(ecp);
39 }
40 
41 ECP::ECP(BufferedTransformation &bt)
42  : m_fieldPtr(new Field(bt))
43 {
44  BERSequenceDecoder seq(bt);
45  GetField().BERDecodeElement(seq, m_a);
46  GetField().BERDecodeElement(seq, m_b);
47  // skip optional seed
48  if (!seq.EndReached())
49  {
50  SecByteBlock seed;
51  unsigned int unused;
52  BERDecodeBitString(seq, seed, unused);
53  }
54  seq.MessageEnd();
55 }
56 
57 void ECP::DEREncode(BufferedTransformation &bt) const
58 {
59  GetField().DEREncode(bt);
60  DERSequenceEncoder seq(bt);
61  GetField().DEREncodeElement(seq, m_a);
62  GetField().DEREncodeElement(seq, m_b);
63  seq.MessageEnd();
64 }
65 
66 bool ECP::DecodePoint(ECP::Point &P, const byte *encodedPoint, size_t encodedPointLen) const
67 {
68  StringStore store(encodedPoint, encodedPointLen);
69  return DecodePoint(P, store, encodedPointLen);
70 }
71 
72 bool ECP::DecodePoint(ECP::Point &P, BufferedTransformation &bt, size_t encodedPointLen) const
73 {
74  byte type;
75  if (encodedPointLen < 1 || !bt.Get(type))
76  return false;
77 
78  switch (type)
79  {
80  case 0:
81  P.identity = true;
82  return true;
83  case 2:
84  case 3:
85  {
86  if (encodedPointLen != EncodedPointSize(true))
87  return false;
88 
89  Integer p = FieldSize();
90 
91  P.identity = false;
92  P.x.Decode(bt, GetField().MaxElementByteLength());
93  P.y = ((P.x*P.x+m_a)*P.x+m_b) % p;
94 
95  if (Jacobi(P.y, p) !=1)
96  return false;
97 
98  P.y = ModularSquareRoot(P.y, p);
99 
100  if ((type & 1) != P.y.GetBit(0))
101  P.y = p-P.y;
102 
103  return true;
104  }
105  case 4:
106  {
107  if (encodedPointLen != EncodedPointSize(false))
108  return false;
109 
110  unsigned int len = GetField().MaxElementByteLength();
111  P.identity = false;
112  P.x.Decode(bt, len);
113  P.y.Decode(bt, len);
114  return true;
115  }
116  default:
117  return false;
118  }
119 }
120 
121 void ECP::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
122 {
123  if (P.identity)
124  NullStore().TransferTo(bt, EncodedPointSize(compressed));
125  else if (compressed)
126  {
127  bt.Put(2 + P.y.GetBit(0));
128  P.x.Encode(bt, GetField().MaxElementByteLength());
129  }
130  else
131  {
132  unsigned int len = GetField().MaxElementByteLength();
133  bt.Put(4); // uncompressed
134  P.x.Encode(bt, len);
135  P.y.Encode(bt, len);
136  }
137 }
138 
139 void ECP::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
140 {
141  ArraySink sink(encodedPoint, EncodedPointSize(compressed));
142  EncodePoint(sink, P, compressed);
143  assert(sink.TotalPutLength() == EncodedPointSize(compressed));
144 }
145 
146 ECP::Point ECP::BERDecodePoint(BufferedTransformation &bt) const
147 {
148  SecByteBlock str;
149  BERDecodeOctetString(bt, str);
150  Point P;
151  if (!DecodePoint(P, str, str.size()))
152  BERDecodeError();
153  return P;
154 }
155 
156 void ECP::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
157 {
158  SecByteBlock str(EncodedPointSize(compressed));
159  EncodePoint(str, P, compressed);
160  DEREncodeOctetString(bt, str);
161 }
162 
163 bool ECP::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
164 {
165  Integer p = FieldSize();
166 
167  bool pass = p.IsOdd();
168  pass = pass && !m_a.IsNegative() && m_a<p && !m_b.IsNegative() && m_b<p;
169 
170  if (level >= 1)
171  pass = pass && ((4*m_a*m_a*m_a+27*m_b*m_b)%p).IsPositive();
172 
173  if (level >= 2)
174  pass = pass && VerifyPrime(rng, p);
175 
176  return pass;
177 }
178 
179 bool ECP::VerifyPoint(const Point &P) const
180 {
181  const FieldElement &x = P.x, &y = P.y;
182  Integer p = FieldSize();
183  return P.identity ||
184  (!x.IsNegative() && x<p && !y.IsNegative() && y<p
185  && !(((x*x+m_a)*x+m_b-y*y)%p));
186 }
187 
188 bool ECP::Equal(const Point &P, const Point &Q) const
189 {
190  if (P.identity && Q.identity)
191  return true;
192 
193  if (P.identity && !Q.identity)
194  return false;
195 
196  if (!P.identity && Q.identity)
197  return false;
198 
199  return (GetField().Equal(P.x,Q.x) && GetField().Equal(P.y,Q.y));
200 }
201 
202 const ECP::Point& ECP::Identity() const
203 {
204  return Singleton<Point>().Ref();
205 }
206 
207 const ECP::Point& ECP::Inverse(const Point &P) const
208 {
209  if (P.identity)
210  return P;
211  else
212  {
213  m_R.identity = false;
214  m_R.x = P.x;
215  m_R.y = GetField().Inverse(P.y);
216  return m_R;
217  }
218 }
219 
220 const ECP::Point& ECP::Add(const Point &P, const Point &Q) const
221 {
222  if (P.identity) return Q;
223  if (Q.identity) return P;
224  if (GetField().Equal(P.x, Q.x))
225  return GetField().Equal(P.y, Q.y) ? Double(P) : Identity();
226 
227  FieldElement t = GetField().Subtract(Q.y, P.y);
228  t = GetField().Divide(t, GetField().Subtract(Q.x, P.x));
229  FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), Q.x);
230  m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
231 
232  m_R.x.swap(x);
233  m_R.identity = false;
234  return m_R;
235 }
236 
237 const ECP::Point& ECP::Double(const Point &P) const
238 {
239  if (P.identity || P.y==GetField().Identity()) return Identity();
240 
241  FieldElement t = GetField().Square(P.x);
242  t = GetField().Add(GetField().Add(GetField().Double(t), t), m_a);
243  t = GetField().Divide(t, GetField().Double(P.y));
244  FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), P.x);
245  m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
246 
247  m_R.x.swap(x);
248  m_R.identity = false;
249  return m_R;
250 }
251 
252 template <class T, class Iterator> void ParallelInvert(const AbstractRing<T> &ring, Iterator begin, Iterator end)
253 {
254  size_t n = end-begin;
255  if (n == 1)
256  *begin = ring.MultiplicativeInverse(*begin);
257  else if (n > 1)
258  {
259  std::vector<T> vec((n+1)/2);
260  unsigned int i;
261  Iterator it;
262 
263  for (i=0, it=begin; i<n/2; i++, it+=2)
264  vec[i] = ring.Multiply(*it, *(it+1));
265  if (n%2 == 1)
266  vec[n/2] = *it;
267 
268  ParallelInvert(ring, vec.begin(), vec.end());
269 
270  for (i=0, it=begin; i<n/2; i++, it+=2)
271  {
272  if (!vec[i])
273  {
274  *it = ring.MultiplicativeInverse(*it);
275  *(it+1) = ring.MultiplicativeInverse(*(it+1));
276  }
277  else
278  {
279  std::swap(*it, *(it+1));
280  *it = ring.Multiply(*it, vec[i]);
281  *(it+1) = ring.Multiply(*(it+1), vec[i]);
282  }
283  }
284  if (n%2 == 1)
285  *it = vec[n/2];
286  }
287 }
288 
289 struct ProjectivePoint
290 {
291  ProjectivePoint() {}
292  ProjectivePoint(const Integer &x, const Integer &y, const Integer &z)
293  : x(x), y(y), z(z) {}
294 
295  Integer x,y,z;
296 };
297 
// Holds a point P in Jacobian projective coordinates and doubles it in place
// on each Double() call. It caches two values carried over from the previous
// step: aZ4 (a*Z^4) and sixteenY4 (16*Y^4). With these, a*Z^4 for the current
// point costs a single multiplication. firstDoubling and negated are
// initialized here but not read anywhere in this file.
298 class ProjectiveDoubling
299 {
300 public:
// Starts from the affine point Q; m_b is unused. An identity Q is represented
// with Z = mr.Identity() (zero), which Double() keeps at zero.
301  ProjectiveDoubling(const ModularArithmetic &m_mr, const Integer &m_a, const Integer &m_b, const ECPPoint &Q)
302  : mr(m_mr), firstDoubling(true), negated(false)
303  {
304  CRYPTOPP_UNUSED(m_b);
305  if (Q.identity)
306  {
307  sixteenY4 = P.x = P.y = mr.MultiplicativeIdentity();
308  aZ4 = P.z = mr.Identity();
309  }
310  else
311  {
312  P.x = Q.x;
313  P.y = Q.y;
314  sixteenY4 = P.z = mr.MultiplicativeIdentity(); // Z = 1, so a*Z^4 = a below
315  aZ4 = m_a;
316  }
317  }
318 
// Replaces P with 2P. This is the standard Jacobian doubling:
//   Z' = 2YZ, S = 4XY^2, M = 3X^2 + aZ^4, X' = M^2 - 2S, Y' = M(S - X') - 8Y^4
// NOTE(review): mr.Reduce(a, b) is used here as an in-place a = a - b (mod p),
// which is what makes the lines below match these formulas. Statement order
// matters: several members are overwritten and then reused.
319  void Double()
320  {
321  twoY = mr.Double(P.y); // 2Y
322  P.z = mr.Multiply(P.z, twoY); // Z' = 2YZ
323  fourY2 = mr.Square(twoY); // 4Y^2
324  S = mr.Multiply(fourY2, P.x); // S = 4XY^2
325  aZ4 = mr.Multiply(aZ4, sixteenY4); // a*Z^4 for the point being doubled
326  M = mr.Square(P.x);
327  M = mr.Add(mr.Add(mr.Double(M), M), aZ4); // M = 3X^2 + aZ^4
328  P.x = mr.Square(M);
329  mr.Reduce(P.x, S);
330  mr.Reduce(P.x, S); // X' = M^2 - 2S
331  mr.Reduce(S, P.x); // S - X'
332  P.y = mr.Multiply(M, S); // M(S - X')
333  sixteenY4 = mr.Square(fourY2); // 16Y^4, also reused by the next Double()
334  mr.Reduce(P.y, mr.Half(sixteenY4)); // Y' = M(S - X') - 8Y^4
335  }
336 
337  const ModularArithmetic &mr;
338  ProjectivePoint P;
339  bool firstDoubling, negated;
340  Integer sixteenY4, aZ4, twoY, fourY2, S, M;
341 };
342 
343 struct ZIterator
344 {
345  ZIterator() {}
346  ZIterator(std::vector<ProjectivePoint>::iterator it) : it(it) {}
347  Integer& operator*() {return it->z;}
348  int operator-(ZIterator it2) {return int(it-it2.it);}
349  ZIterator operator+(int i) {return ZIterator(it+i);}
350  ZIterator& operator+=(int i) {it+=i; return *this;}
351  std::vector<ProjectivePoint>::iterator it;
352 };
353 
354 ECP::Point ECP::ScalarMultiply(const Point &P, const Integer &k) const
355 {
356  Element result;
357  if (k.BitCount() <= 5)
359  else
360  ECP::SimultaneousMultiply(&result, P, &k, 1);
361  return result;
362 }
363 
// Computes results[i] = expBegin[i] * P for each i in [0, expCount).
// - If the field is not in Montgomery representation, the work is redone on a
//   Montgomery-form copy of the curve, and the results are converted back.
// - Otherwise P is doubled repeatedly in Jacobian coordinates
//   (ProjectiveDoubling). Whenever some exponent's next window (WindowSlider,
//   window parameter 5) begins at the current bit position, a copy of the
//   current multiple is saved as a base.
// - All saved bases are then converted to affine coordinates with a single
//   batched inversion of their Z values (ParallelInvert).
// - Finally, each exponent's bases are combined with their window values
//   through GeneralCascadeMultiplication.
// Exponents are expected to be non-negative; this is only checked by assert.
364 void ECP::SimultaneousMultiply(ECP::Point *results, const ECP::Point &P, const Integer *expBegin, unsigned int expCount) const
365 {
366  if (!GetField().IsMontgomeryRepresentation())
367  {
368  ECP ecpmr(*this, true);
369  const ModularArithmetic &mr = ecpmr.GetField();
370  ecpmr.SimultaneousMultiply(results, ToMontgomery(mr, P), expBegin, expCount);
371  for (unsigned int i=0; i<expCount; i++)
372  results[i] = FromMontgomery(mr, results[i]);
373  return;
374  }
375 
376  ProjectiveDoubling rd(GetField(), m_a, m_b, P);
377  std::vector<ProjectivePoint> bases;
378  std::vector<WindowSlider> exponents;
379  exponents.reserve(expCount);
// Per exponent: index into `bases`, whether to negate that base, and the
// window value used as its small exponent.
380  std::vector<std::vector<word32> > baseIndices(expCount);
381  std::vector<std::vector<bool> > negateBase(expCount);
382  std::vector<std::vector<word32> > exponentWindows(expCount);
383  unsigned int i;
384 
385  for (i=0; i<expCount; i++)
386  {
387  assert(expBegin->NotNegative());
388  exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 5));
389  exponents[i].FindNextWindow();
390  }
391 
392  unsigned int expBitPosition = 0;
393  bool notDone = true;
394 
// Each pass corresponds to one bit position, and rd.P holds
// 2^expBitPosition * P at that point. At most one base is saved per position;
// it is shared by every exponent whose window starts there.
395  while (notDone)
396  {
397  notDone = false;
398  bool baseAdded = false;
399  for (i=0; i<expCount; i++)
400  {
401  if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
402  {
403  if (!baseAdded)
404  {
405  bases.push_back(rd.P);
406  baseAdded =true;
407  }
408 
409  exponentWindows[i].push_back(exponents[i].expWindow);
410  baseIndices[i].push_back((word32)bases.size()-1);
411  negateBase[i].push_back(exponents[i].negateNext);
412 
413  exponents[i].FindNextWindow();
414  }
415  notDone = notDone || !exponents[i].finished;
416  }
417 
418  if (notDone)
419  {
420  rd.Double();
421  expBitPosition++;
422  }
423  }
424 
425  // convert from projective to affine coordinates
// After ParallelInvert each z holds 1/Z. Then y*(1/Z)^3 and x*(1/Z)^2 give the
// affine point. Bases with z == 0 are left alone; they are treated as the
// identity below.
426  ParallelInvert(GetField(), ZIterator(bases.begin()), ZIterator(bases.end()));
427  for (i=0; i<bases.size(); i++)
428  {
429  if (bases[i].z.NotZero())
430  {
431  bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
432  bases[i].z = GetField().Square(bases[i].z);
433  bases[i].x = GetField().Multiply(bases[i].x, bases[i].z);
434  bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
435  }
436  }
437 
// For each exponent, build the list of (affine base, window value) pairs,
// negating y where the window slider requested it, and combine them.
// finalCascade is reused across exponents and resized for each.
438  std::vector<BaseAndExponent<Point, Integer> > finalCascade;
439  for (i=0; i<expCount; i++)
440  {
441  finalCascade.resize(baseIndices[i].size());
442  for (unsigned int j=0; j<baseIndices[i].size(); j++)
443  {
444  ProjectivePoint &base = bases[baseIndices[i][j]];
445  if (base.z.IsZero())
446  finalCascade[j].base.identity = true;
447  else
448  {
449  finalCascade[j].base.identity = false;
450  finalCascade[j].base.x = base.x;
451  if (negateBase[i][j])
452  finalCascade[j].base.y = GetField().Inverse(base.y);
453  else
454  finalCascade[j].base.y = base.y;
455  }
456  finalCascade[j].exponent = Integer(Integer::POSITIVE, 0, exponentWindows[i][j]);
457  }
458  results[i] = GeneralCascadeMultiplication(*this, finalCascade.begin(), finalCascade.end());
459  }
460 }
461 
462 ECP::Point ECP::CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
463 {
464  if (!GetField().IsMontgomeryRepresentation())
465  {
466  ECP ecpmr(*this, true);
467  const ModularArithmetic &mr = ecpmr.GetField();
468  return FromMontgomery(mr, ecpmr.CascadeScalarMultiply(ToMontgomery(mr, P), k1, ToMontgomery(mr, Q), k2));
469  }
470  else
471  return AbstractGroup<Point>::CascadeScalarMultiply(P, k1, Q, k2);
472 }
473 
474 NAMESPACE_END
475 
476 #endif
virtual void SimultaneousMultiply(Element *results, const Element &base, const Integer *exponents, unsigned int exponentsCount) const
Multiplies a base to multiple exponents in a group.
Definition: algebra.cpp:256
bool Equal(const Point &P, const Point &Q) const
Compare two elements for equality.
void DEREncode(BufferedTransformation &bt) const
Encodes in DER format.
Definition: integer.cpp:4258
inline::Integer operator*(const ::Integer &a, const ::Integer &b)
Definition: integer.h:577
const Integer & GetModulus() const
Retrieves the modulus.
Definition: modarith.h:81
Restricts the instantiation of a class to one static object without locks.
Definition: misc.h:264
Elliptic Curve Point.
Definition: ecp.h:20
bool GetBit(size_t i) const
Provides the i-th bit of the Integer.
Definition: integer.cpp:3043
Integer & Reduce(Integer &a, const Integer &b) const
Subtracts b from a in place, modulo the ring's modulus, and returns a.
Definition: integer.cpp:4344
bool IsOdd() const
Determines if the Integer is odd parity.
Definition: integer.h:333
virtual const Element & Subtract(const Element &a, const Element &b) const
Subtracts elements in the group.
const Integer & Subtract(const Integer &a, const Integer &b) const
Subtracts elements in the ring.
Definition: integer.cpp:4327
Classes for Elliptic Curves over prime fields.
bool InversionIsFast() const
Determine if inversion is fast.
Definition: ecp.h:63
Elliptic Curve over GF(p), where p is prime.
Definition: ecp.h:42
const Integer & MultiplicativeIdentity() const
Retrieves the multiplicative identity.
Definition: modarith.h:164
const Point & Identity() const
Provides the Identity element.
size_type size() const
Provides the count of elements in the SecBlock.
Definition: secblock.h:521
Square block cipher.
Definition: square.h:24
const Integer & Square(const Integer &a) const
Square an element in the ring.
Definition: modarith.h:179
virtual const Element & MultiplicativeInverse(const Element &a) const =0
Calculate the multiplicative inverse of an element in the group.
Ring of congruence classes modulo n.
Definition: modarith.h:34
Interface for random number generators.
Definition: cryptlib.h:1186
SecBlock typedef.
Definition: secblock.h:728
BER Sequence Decoder.
Definition: asn.h:294
Interface for buffered transformations.
Definition: cryptlib.h:1352
void SimultaneousMultiply(Point *results, const Point &base, const Integer *exponents, unsigned int exponentsCount) const
Multiplies a base to multiple exponents in a group.
Abstract ring.
Definition: algebra.h:118
const Integer & Add(const Integer &a, const Integer &b) const
Adds elements in the ring.
Definition: integer.cpp:4287
unsigned int BitCount() const
Determines the number of bits required to represent the Integer.
Definition: integer.cpp:3277
virtual Integer ConvertIn(const Integer &a) const
Converts an Integer into the ring's internal representation (reducing it modulo the modulus).
Definition: modarith.h:97
const Point & Add(const Point &P, const Point &Q) const
Adds elements in the group.
Copy input to a memory buffer.
Definition: filters.h:1016
empty store
Definition: filters.h:1124
virtual Integer ConvertOut(const Integer &a) const
Converts an element from the ring's internal representation back to an ordinary Integer.
Definition: modarith.h:105
const Integer & Multiply(const Integer &a, const Integer &b) const
Multiplies elements in the ring.
Definition: modarith.h:172
bool IsPositive() const
Determines if the Integer is positive.
Definition: integer.h:324
lword TransferTo(BufferedTransformation &target, lword transferMax=LWORD_MAX, const std::string &channel=DEFAULT_CHANNEL)
Moves up to transferMax bytes of buffered output to target as input.
Definition: cryptlib.h:1656
size_t Put(byte inByte, bool blocking=true)
Input a byte for processing.
Definition: cryptlib.h:1378
size_t BERDecodeOctetString(BufferedTransformation &bt, SecByteBlock &str)
BER decode octet string.
Definition: asn.cpp:117
bool IsNegative() const
Determines if the Integer is negative.
Definition: integer.h:318
void swap(Integer &a)
Swaps this Integer with another Integer.
Definition: integer.cpp:3103
bool VerifyPrime(RandomNumberGenerator &rng, const Integer &p, unsigned int level=1)
Verifies a prime number.
Definition: nbtheory.cpp:249
Multiple precision integer with arithmetic operations.
Definition: integer.h:31
const Integer & Double(const Integer &a) const
Doubles an element in the ring.
Definition: modarith.h:158
OID operator+(const OID &lhs, unsigned long rhs)
Append a value to an OID.
const Integer & Inverse(const Integer &a) const
Inverts the element in the ring.
Definition: integer.cpp:4361
String-based implementation of Store interface.
Definition: filters.h:1066
Point CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
Computes k1*P + k2*Q.
void BERDecodeError()
Raises a BERDecodeErr.
Definition: asn.h:61
Abstract group.
Definition: algebra.h:26
const Integer & Divide(const Integer &a, const Integer &b) const
Divides elements in the ring.
Definition: modarith.h:200
Classes and functions for working with ASN.1 objects.
Implementation of BufferedTransformation's attachment interface.
Classes and functions for number theoretic operations.
const Integer & Half(const Integer &a) const
Divides an element by 2 in the ring.
Definition: integer.cpp:4276
DER Sequence Encoder.
Definition: asn.h:304
Performs modular arithmetic in Montgomery representation for increased speed.
Definition: modarith.h:274
Point ScalarMultiply(const Point &P, const Integer &k) const
Performs a scalar multiplication.
void Decode(const byte *input, size_t inputLen, Signedness sign=UNSIGNED)
Decode from big-endian byte array.
Definition: integer.cpp:3286
size_t DEREncodeOctetString(BufferedTransformation &bt, const byte *str, size_t strLen)
DER encode octet string.
Definition: asn.cpp:104
void DEREncodeElement(BufferedTransformation &out, const Element &a) const
Encodes element in DER format.
Definition: integer.cpp:4266
size_t BERDecodeBitString(BufferedTransformation &bt, SecByteBlock &str, unsigned int &unusedBits)
BER decode bit string.
Definition: asn.cpp:182
virtual size_t Get(byte &outByte)
Retrieve a 8-bit byte.
Definition: cryptlib.cpp:519
Class file for performing modular arithmetic.
Crypto++ library namespace.
virtual bool IsMontgomeryRepresentation() const
Retrieves the representation.
Definition: modarith.h:90
const Integer & Identity() const
Provides the Identity element.
Definition: modarith.h:122
bool Equal(const Integer &a, const Integer &b) const
Compare two elements for equality.
Definition: modarith.h:117
unsigned int MaxElementByteLength() const
Provides the maximum byte size of an element in the ring.
Definition: modarith.h:230
virtual Element CascadeScalarMultiply(const Element &x, const Integer &e1, const Element &y, const Integer &e2) const
Computes e1*x + e2*y.
Definition: algebra.cpp:97
const Point & Inverse(const Point &P) const
Inverts the element in the group.
const Point & Double(const Point &P) const
Doubles an element in the group.
virtual const Element & Multiply(const Element &a, const Element &b) const =0
Multiplies elements in the group.
the value is positive or 0
Definition: integer.h:57
bool NotNegative() const
Determines if the Integer is non-negative.
Definition: integer.h:321