494 lines
16 KiB
Python
494 lines
16 KiB
Python
# This file is licensed under the BSD 2-Clause License.
|
|
# See https://opensource.org/licenses/BSD-2-Clause for details.
|
|
|
|
import threading
|
|
|
|
from Crypto.Util.number import bytes_to_long, long_to_bytes
|
|
from Crypto.Util._raw_api import (VoidPointer, null_pointer,
|
|
SmartPointer, c_size_t, c_uint8_ptr,
|
|
c_ulonglong)
|
|
from Crypto.Math.Numbers import Integer
|
|
from Crypto.Random.random import getrandbits
|
|
|
|
|
|
class CurveID(object):
|
|
P192 = 1
|
|
P224 = 2
|
|
P256 = 3
|
|
P384 = 4
|
|
P521 = 5
|
|
ED25519 = 6
|
|
ED448 = 7
|
|
CURVE25519 = 8
|
|
CURVE448 = 9
|
|
|
|
|
|
class _Curves(object):
|
|
|
|
curves = {}
|
|
curves_lock = threading.RLock()
|
|
|
|
p192_names = ["p192", "NIST P-192", "P-192", "prime192v1", "secp192r1",
|
|
"nistp192"]
|
|
p224_names = ["p224", "NIST P-224", "P-224", "prime224v1", "secp224r1",
|
|
"nistp224"]
|
|
p256_names = ["p256", "NIST P-256", "P-256", "prime256v1", "secp256r1",
|
|
"nistp256"]
|
|
p384_names = ["p384", "NIST P-384", "P-384", "prime384v1", "secp384r1",
|
|
"nistp384"]
|
|
p521_names = ["p521", "NIST P-521", "P-521", "prime521v1", "secp521r1",
|
|
"nistp521"]
|
|
ed25519_names = ["ed25519", "Ed25519"]
|
|
ed448_names = ["ed448", "Ed448"]
|
|
curve25519_names = ["curve25519", "Curve25519", "X25519"]
|
|
curve448_names = ["curve448", "Curve448", "X448"]
|
|
|
|
all_names = p192_names + p224_names + p256_names + p384_names + p521_names + \
|
|
ed25519_names + ed448_names + curve25519_names + curve448_names
|
|
|
|
def __contains__(self, item):
|
|
return item in self.all_names
|
|
|
|
def __dir__(self):
|
|
return self.all_names
|
|
|
|
def load(self, name):
|
|
if name in self.p192_names:
|
|
from . import _nist_ecc
|
|
p192 = _nist_ecc.p192_curve()
|
|
p192.id = CurveID.P192
|
|
self.curves.update(dict.fromkeys(self.p192_names, p192))
|
|
elif name in self.p224_names:
|
|
from . import _nist_ecc
|
|
p224 = _nist_ecc.p224_curve()
|
|
p224.id = CurveID.P224
|
|
self.curves.update(dict.fromkeys(self.p224_names, p224))
|
|
elif name in self.p256_names:
|
|
from . import _nist_ecc
|
|
p256 = _nist_ecc.p256_curve()
|
|
p256.id = CurveID.P256
|
|
self.curves.update(dict.fromkeys(self.p256_names, p256))
|
|
elif name in self.p384_names:
|
|
from . import _nist_ecc
|
|
p384 = _nist_ecc.p384_curve()
|
|
p384.id = CurveID.P384
|
|
self.curves.update(dict.fromkeys(self.p384_names, p384))
|
|
elif name in self.p521_names:
|
|
from . import _nist_ecc
|
|
p521 = _nist_ecc.p521_curve()
|
|
p521.id = CurveID.P521
|
|
self.curves.update(dict.fromkeys(self.p521_names, p521))
|
|
elif name in self.ed25519_names:
|
|
from . import _edwards
|
|
ed25519 = _edwards.ed25519_curve()
|
|
ed25519.id = CurveID.ED25519
|
|
self.curves.update(dict.fromkeys(self.ed25519_names, ed25519))
|
|
elif name in self.ed448_names:
|
|
from . import _edwards
|
|
ed448 = _edwards.ed448_curve()
|
|
ed448.id = CurveID.ED448
|
|
self.curves.update(dict.fromkeys(self.ed448_names, ed448))
|
|
elif name in self.curve25519_names:
|
|
from . import _montgomery
|
|
curve25519 = _montgomery.curve25519_curve()
|
|
curve25519.id = CurveID.CURVE25519
|
|
self.curves.update(dict.fromkeys(self.curve25519_names, curve25519))
|
|
elif name in self.curve448_names:
|
|
from . import _montgomery
|
|
curve448 = _montgomery.curve448_curve()
|
|
curve448.id = CurveID.CURVE448
|
|
self.curves.update(dict.fromkeys(self.curve448_names, curve448))
|
|
else:
|
|
raise ValueError("Unsupported curve '%s'" % name)
|
|
return self.curves[name]
|
|
|
|
def __getitem__(self, name):
|
|
with self.curves_lock:
|
|
curve = self.curves.get(name)
|
|
if curve is None:
|
|
curve = self.load(name)
|
|
if name in self.curve25519_names or name in self.curve448_names:
|
|
curve.G = EccXPoint(curve.Gx, name)
|
|
else:
|
|
curve.G = EccPoint(curve.Gx, curve.Gy, name)
|
|
curve.is_edwards = curve.id in (CurveID.ED25519, CurveID.ED448)
|
|
curve.is_montgomery = curve.id in (CurveID.CURVE25519,
|
|
CurveID.CURVE448)
|
|
curve.is_weierstrass = not (curve.is_edwards or
|
|
curve.is_montgomery)
|
|
return curve
|
|
|
|
def items(self):
|
|
# Load all curves
|
|
for name in self.all_names:
|
|
_ = self[name]
|
|
return self.curves.items()
|
|
|
|
|
|
_curves = _Curves()
|
|
|
|
|
|
class EccPoint(object):
|
|
"""A class to model a point on an Elliptic Curve.
|
|
|
|
The class supports operators for:
|
|
|
|
* Adding two points: ``R = S + T``
|
|
* In-place addition: ``S += T``
|
|
* Negating a point: ``R = -T``
|
|
* Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
|
|
* Multiplying a point by a scalar: ``R = S*k``
|
|
* In-place multiplication by a scalar: ``T *= k``
|
|
|
|
:ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
|
|
:vartype curve: string
|
|
|
|
:ivar x: The affine X-coordinate of the ECC point
|
|
:vartype x: integer
|
|
|
|
:ivar y: The affine Y-coordinate of the ECC point
|
|
:vartype y: integer
|
|
|
|
:ivar xy: The tuple with affine X- and Y- coordinates
|
|
"""
|
|
|
|
def __init__(self, x, y, curve="p256"):
|
|
|
|
try:
|
|
self._curve = _curves[curve]
|
|
except KeyError:
|
|
raise ValueError("Unknown curve name %s" % str(curve))
|
|
self.curve = self._curve.canonical
|
|
|
|
if self._curve.id == CurveID.CURVE25519:
|
|
raise ValueError("EccPoint cannot be created for Curve25519")
|
|
|
|
modulus_bytes = self.size_in_bytes()
|
|
|
|
xb = long_to_bytes(x, modulus_bytes)
|
|
yb = long_to_bytes(y, modulus_bytes)
|
|
if len(xb) != modulus_bytes or len(yb) != modulus_bytes:
|
|
raise ValueError("Incorrect coordinate length")
|
|
|
|
new_point = self._curve.rawlib.new_point
|
|
free_func = self._curve.rawlib.free_point
|
|
|
|
self._point = VoidPointer()
|
|
try:
|
|
context = self._curve.context.get()
|
|
except AttributeError:
|
|
context = null_pointer
|
|
result = new_point(self._point.address_of(),
|
|
c_uint8_ptr(xb),
|
|
c_uint8_ptr(yb),
|
|
c_size_t(modulus_bytes),
|
|
context)
|
|
|
|
if result:
|
|
if result == 15:
|
|
raise ValueError("The EC point does not belong to the curve")
|
|
raise ValueError("Error %d while instantiating an EC point" % result)
|
|
|
|
# Ensure that object disposal of this Python object will (eventually)
|
|
# free the memory allocated by the raw library for the EC point
|
|
self._point = SmartPointer(self._point.get(), free_func)
|
|
|
|
def set(self, point):
|
|
clone = self._curve.rawlib.clone
|
|
free_func = self._curve.rawlib.free_point
|
|
|
|
self._point = VoidPointer()
|
|
result = clone(self._point.address_of(),
|
|
point._point.get())
|
|
|
|
if result:
|
|
raise ValueError("Error %d while cloning an EC point" % result)
|
|
|
|
self._point = SmartPointer(self._point.get(), free_func)
|
|
return self
|
|
|
|
def __eq__(self, point):
|
|
if not isinstance(point, EccPoint):
|
|
return False
|
|
|
|
cmp_func = self._curve.rawlib.cmp
|
|
return 0 == cmp_func(self._point.get(), point._point.get())
|
|
|
|
# Only needed for Python 2
|
|
def __ne__(self, point):
|
|
return not self == point
|
|
|
|
def __neg__(self):
|
|
neg_func = self._curve.rawlib.neg
|
|
np = self.copy()
|
|
result = neg_func(np._point.get())
|
|
if result:
|
|
raise ValueError("Error %d while inverting an EC point" % result)
|
|
return np
|
|
|
|
def copy(self):
|
|
"""Return a copy of this point."""
|
|
x, y = self.xy
|
|
np = EccPoint(x, y, self.curve)
|
|
return np
|
|
|
|
def is_point_at_infinity(self):
|
|
"""``True`` if this is the *point-at-infinity*."""
|
|
|
|
if self._curve.is_edwards:
|
|
return self.x == 0
|
|
else:
|
|
return self.xy == (0, 0)
|
|
|
|
def point_at_infinity(self):
|
|
"""Return the *point-at-infinity* for the curve."""
|
|
|
|
if self._curve.is_edwards:
|
|
return EccPoint(0, 1, self.curve)
|
|
else:
|
|
return EccPoint(0, 0, self.curve)
|
|
|
|
@property
|
|
def x(self):
|
|
return self.xy[0]
|
|
|
|
@property
|
|
def y(self):
|
|
return self.xy[1]
|
|
|
|
@property
|
|
def xy(self):
|
|
modulus_bytes = self.size_in_bytes()
|
|
xb = bytearray(modulus_bytes)
|
|
yb = bytearray(modulus_bytes)
|
|
get_xy = self._curve.rawlib.get_xy
|
|
result = get_xy(c_uint8_ptr(xb),
|
|
c_uint8_ptr(yb),
|
|
c_size_t(modulus_bytes),
|
|
self._point.get())
|
|
if result:
|
|
raise ValueError("Error %d while encoding an EC point" % result)
|
|
|
|
return (Integer(bytes_to_long(xb)), Integer(bytes_to_long(yb)))
|
|
|
|
def size_in_bytes(self):
|
|
"""Size of each coordinate, in bytes."""
|
|
return (self.size_in_bits() + 7) // 8
|
|
|
|
def size_in_bits(self):
|
|
"""Size of each coordinate, in bits."""
|
|
return self._curve.modulus_bits
|
|
|
|
def double(self):
|
|
"""Double this point (in-place operation).
|
|
|
|
Returns:
|
|
This same object (to enable chaining).
|
|
"""
|
|
|
|
double_func = self._curve.rawlib.double
|
|
result = double_func(self._point.get())
|
|
if result:
|
|
raise ValueError("Error %d while doubling an EC point" % result)
|
|
return self
|
|
|
|
def __iadd__(self, point):
|
|
"""Add a second point to this one"""
|
|
|
|
add_func = self._curve.rawlib.add
|
|
result = add_func(self._point.get(), point._point.get())
|
|
if result:
|
|
if result == 16:
|
|
raise ValueError("EC points are not on the same curve")
|
|
raise ValueError("Error %d while adding two EC points" % result)
|
|
return self
|
|
|
|
def __add__(self, point):
|
|
"""Return a new point, the addition of this one and another"""
|
|
|
|
np = self.copy()
|
|
np += point
|
|
return np
|
|
|
|
def __imul__(self, scalar):
|
|
"""Multiply this point by a scalar"""
|
|
|
|
scalar_func = self._curve.rawlib.scalar
|
|
if scalar < 0:
|
|
raise ValueError("Scalar multiplication is only defined for non-negative integers")
|
|
sb = long_to_bytes(scalar)
|
|
result = scalar_func(self._point.get(),
|
|
c_uint8_ptr(sb),
|
|
c_size_t(len(sb)),
|
|
c_ulonglong(getrandbits(64)))
|
|
if result:
|
|
raise ValueError("Error %d during scalar multiplication" % result)
|
|
return self
|
|
|
|
def __mul__(self, scalar):
|
|
"""Return a new point, the scalar product of this one"""
|
|
|
|
np = self.copy()
|
|
np *= scalar
|
|
return np
|
|
|
|
def __rmul__(self, left_hand):
|
|
return self.__mul__(left_hand)
|
|
|
|
|
|
class EccXPoint(object):
|
|
"""A class to model a point on an Elliptic Curve,
|
|
where only the X-coordinate is exposed.
|
|
|
|
The class supports operators for:
|
|
|
|
* Multiplying a point by a scalar: ``R = S*k``
|
|
* In-place multiplication by a scalar: ``T *= k``
|
|
|
|
:ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
|
|
:vartype curve: string
|
|
|
|
:ivar x: The affine X-coordinate of the ECC point
|
|
:vartype x: integer
|
|
"""
|
|
|
|
def __init__(self, x, curve):
|
|
# Once encoded, x must not exceed the length of the modulus,
|
|
# but its value may match or exceed the modulus itself
|
|
# (i.e., non-canonical value)
|
|
|
|
try:
|
|
self._curve = _curves[curve]
|
|
except KeyError:
|
|
raise ValueError("Unknown curve name %s" % str(curve))
|
|
self.curve = self._curve.canonical
|
|
|
|
if self._curve.id not in (CurveID.CURVE25519, CurveID.CURVE448):
|
|
raise ValueError("EccXPoint can only be created for Curve25519/Curve448")
|
|
|
|
new_point = self._curve.rawlib.new_point
|
|
free_func = self._curve.rawlib.free_point
|
|
|
|
self._point = VoidPointer()
|
|
try:
|
|
context = self._curve.context.get()
|
|
except AttributeError:
|
|
context = null_pointer
|
|
|
|
modulus_bytes = self.size_in_bytes()
|
|
|
|
if x is None:
|
|
xb = null_pointer
|
|
else:
|
|
xb = c_uint8_ptr(long_to_bytes(x, modulus_bytes))
|
|
if len(xb) != modulus_bytes:
|
|
raise ValueError("Incorrect coordinate length")
|
|
|
|
self._point = VoidPointer()
|
|
result = new_point(self._point.address_of(),
|
|
xb,
|
|
c_size_t(modulus_bytes),
|
|
context)
|
|
|
|
if result == 15:
|
|
raise ValueError("The EC point does not belong to the curve")
|
|
if result:
|
|
raise ValueError("Error %d while instantiating an EC point" % result)
|
|
|
|
# Ensure that object disposal of this Python object will (eventually)
|
|
# free the memory allocated by the raw library for the EC point
|
|
self._point = SmartPointer(self._point.get(), free_func)
|
|
|
|
def set(self, point):
|
|
clone = self._curve.rawlib.clone
|
|
free_func = self._curve.rawlib.free_point
|
|
|
|
self._point = VoidPointer()
|
|
result = clone(self._point.address_of(),
|
|
point._point.get())
|
|
if result:
|
|
raise ValueError("Error %d while cloning an EC point" % result)
|
|
|
|
self._point = SmartPointer(self._point.get(), free_func)
|
|
return self
|
|
|
|
def __eq__(self, point):
|
|
if not isinstance(point, EccXPoint):
|
|
return False
|
|
|
|
cmp_func = self._curve.rawlib.cmp
|
|
p1 = self._point.get()
|
|
p2 = point._point.get()
|
|
res = cmp_func(p1, p2)
|
|
return 0 == res
|
|
|
|
def copy(self):
|
|
"""Return a copy of this point."""
|
|
|
|
try:
|
|
x = self.x
|
|
except ValueError:
|
|
return self.point_at_infinity()
|
|
return EccXPoint(x, self.curve)
|
|
|
|
def is_point_at_infinity(self):
|
|
"""``True`` if this is the *point-at-infinity*."""
|
|
|
|
try:
|
|
_ = self.x
|
|
except ValueError:
|
|
return True
|
|
return False
|
|
|
|
def point_at_infinity(self):
|
|
"""Return the *point-at-infinity* for the curve."""
|
|
|
|
return EccXPoint(None, self.curve)
|
|
|
|
@property
|
|
def x(self):
|
|
modulus_bytes = self.size_in_bytes()
|
|
xb = bytearray(modulus_bytes)
|
|
get_x = self._curve.rawlib.get_x
|
|
result = get_x(c_uint8_ptr(xb),
|
|
c_size_t(modulus_bytes),
|
|
self._point.get())
|
|
if result == 19: # ERR_ECC_PAI
|
|
raise ValueError("No X coordinate for the point at infinity")
|
|
if result:
|
|
raise ValueError("Error %d while getting X of an EC point" % result)
|
|
return Integer(bytes_to_long(xb))
|
|
|
|
def size_in_bytes(self):
|
|
"""Size of each coordinate, in bytes."""
|
|
return (self.size_in_bits() + 7) // 8
|
|
|
|
def size_in_bits(self):
|
|
"""Size of each coordinate, in bits."""
|
|
return self._curve.modulus_bits
|
|
|
|
def __imul__(self, scalar):
|
|
"""Multiply this point by a scalar"""
|
|
|
|
scalar_func = self._curve.rawlib.scalar
|
|
if scalar < 0:
|
|
raise ValueError("Scalar multiplication is only defined for non-negative integers")
|
|
sb = long_to_bytes(scalar)
|
|
result = scalar_func(self._point.get(),
|
|
c_uint8_ptr(sb),
|
|
c_size_t(len(sb)),
|
|
c_ulonglong(getrandbits(64)))
|
|
if result:
|
|
raise ValueError("Error %d during scalar multiplication" % result)
|
|
return self
|
|
|
|
def __mul__(self, scalar):
|
|
"""Return a new point, the scalar product of this one"""
|
|
|
|
np = self.copy()
|
|
np *= scalar
|
|
return np
|
|
|
|
def __rmul__(self, left_hand):
|
|
return self.__mul__(left_hand)
|