Skip to content

Commit

Permalink
ComparableOKPKey work
Browse files Browse the repository at this point in the history
  • Loading branch information
atombrella committed Aug 11, 2021
1 parent 6c2a6d7 commit 590fc50
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 97 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Changelog
* Added support for Ed25519, Ed448, X25519 and X448 keys (see `RFC 8037 <https://tools.ietf.org/html/rfc8037>`_).
These are also known as Bernstein curves.
* Added support for signing with Ed25519, Ed448, X25519 and X448 keys
(see `RFC 8032 <https://datatracker.ietf.org/doc/html/rfc8032>`_).
(see `RFC 8032 <https://datatracker.ietf.org/doc/html/rfc8032>`_). See JWA.
* Minimum requirement of ``cryptography`` is now 2.6+.

1.8.0 (2021-03-15)
Expand Down
1 change: 1 addition & 0 deletions src/josepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
ES256,
ES384,
ES512,
EdDSA,
)

from josepy.jwk import (
Expand Down
2 changes: 1 addition & 1 deletion src/josepy/json_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def register(cls, type_cls, typ=None):
def get_type_cls(cls, jobj):
"""Get the registered class for ``jobj``."""
if cls in cls.TYPES.values():
if cls.type_field_name not in jobj:
if cls.type_field_name not in jobj: # noqa
raise errors.DeserializationError(
"Missing type field ({0})".format(cls.type_field_name))
# cls is already registered type_cls, force to use it
Expand Down
34 changes: 0 additions & 34 deletions src/josepy/jwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,31 +224,6 @@ def _verify(self, key, msg, asn1sig):
return True


class _JWAOKP(JWASignature):
kty = jwk.JWKOKP

def __init__(self, name, hash_):
super().__init__(name)
self.hash = hash_()

@classmethod
def register(cls, signature_cls):
# might need to overwrite this, so I can get the argument in
return super().register(signature_cls)

def sign(self, key, msg: bytes):
return key.sign(msg)

def verify(self, key, msg: bytes, sig: bytes):
try:
key.verify(signature=sig, data=msg)
except cryptography.exceptions.InvalidSignature as error:
logger.debug(error, exc_info=True)
return False
else:
return True


#: HMAC using SHA-256
HS256 = JWASignature.register(_JWAHS('HS256', hashes.SHA256))
#: HMAC using SHA-384
Expand Down Expand Up @@ -276,12 +251,3 @@ def verify(self, key, msg: bytes, sig: bytes):
ES384 = JWASignature.register(_JWAEC('ES384', hashes.SHA384))
#: ECDSA using P-521 and SHA-512
ES512 = JWASignature.register(_JWAEC('ES512', hashes.SHA512))

#: Ed25519 uses SHA512
ES25519 = JWASignature.register(_JWAOKP('ES25519', hashes.SHA512))
#: Ed448 uses SHA3/SHAKE256
# ES448 = JWASignature.register(_JWAOKP('ES448', hashes.SHAKE256))
# #: X25519 uses SHA3/SHAKE256
# X22519 = JWASignature.register(_JWAOKP('X22519', hashes.SHAKE256))
# #: X448 uses SHA3/SHAKE256
# X448 = JWASignature.register(_JWAOKP('X448', hashes.SHAKE256))
18 changes: 0 additions & 18 deletions src/josepy/jwa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
EC_P256_KEY = test_util.load_ec_private_key('ec_p256_key.pem')
EC_P384_KEY = test_util.load_ec_private_key('ec_p384_key.pem')
EC_P521_KEY = test_util.load_ec_private_key('ec_p521_key.pem')
OKP_ED25519_KEY = test_util.load_ec_private_key('ed25519_key.pem')
OKP_ED448_KEY = test_util.load_ec_private_key('ed448_key.pem')
OKP_X25519_KEY = test_util.load_ec_private_key('x25519_key.pem')
OKP_X448_KEY = test_util.load_ec_private_key('x448_key.pem')


class JWASignatureTest(unittest.TestCase):
Expand Down Expand Up @@ -230,19 +226,5 @@ def test_signature_size(self):
self.assertEqual(len(sig), 2 * 66)


# class JWAOKPTests(JWASignatureTest):
# # look up the signature sizes in the RFC
#
# def test_sign_no_private_part(self):
# from josepy.jwa import ES25519
# self.assertRaises(errors.Error, ES25519.sign, OKP_ED25519_KEY, b'foo')
#
# # def test_can_size_ed25519(self):
# # ES25519.sign(b'foo'), OKP_ED25519_KEY,
#
# def test_signature_size(self):
# pass


if __name__ == '__main__':
unittest.main() # pragma: no cover
37 changes: 18 additions & 19 deletions src/josepy/jwk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""JSON Web Key."""
import abc
import collections
import json
import logging
import math
Expand Down Expand Up @@ -257,7 +258,7 @@ def fields_to_partial_json(self):

@JWK.register
class JWKEC(JWK):
"""EC JWK.
"""RSA JWK.
:ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
or :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey`
Expand Down Expand Up @@ -389,24 +390,25 @@ class JWKOKP(JWK):
or :class:`~cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey`
or :class:`~cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey`
or :class:`~cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey`
wrapped in :class:`~josepy.util.ComparableOKPKey`
This class requires ``cryptography>=2.6`` to be installed.
"""
typ = 'OKP'
__slots__ = ('key', )

__slots__ = ('key',)
cryptography_key_types = (
ed25519.Ed25519PrivateKey, ed25519.Ed25519PrivateKey,
ed448.Ed448PublicKey, ed448.Ed448PrivateKey,
x25519.X25519PrivateKey, x25519.X25519PublicKey,
x448.X448PrivateKey, x448.X448PublicKey,
)
required = ('crv', JWK.type_field_name, 'x')
okp_curve = collections.namedtuple('okp_curve', 'pubkey privkey')
crv_to_pub_priv = {
"Ed25519": (ed25519.Ed25519PublicKey, ed25519.Ed25519PrivateKey),
"Ed448": (ed448.Ed448PublicKey, ed448.Ed448PrivateKey),
"X25519": (x25519.X25519PublicKey, x25519.X25519PrivateKey),
"X448": (x448.X448PublicKey, x448.X448PrivateKey),
"Ed25519": okp_curve(pubkey=ed25519.Ed25519PublicKey, privkey=ed25519.Ed25519PrivateKey),
"Ed448": okp_curve(pubkey=ed448.Ed448PublicKey, privkey=ed448.Ed448PrivateKey),
"X25519": okp_curve(pubkey=x25519.X25519PublicKey, privkey=x25519.X25519PrivateKey),
"X448": okp_curve(pubkey=x448.X448PublicKey, privkey=x448.X448PrivateKey),
}

def __init__(self, *args, **kwargs):
Expand All @@ -428,20 +430,20 @@ def _key_to_crv(self):
return "X448"
return NotImplemented

def fields_to_partial_json(self) -> Dict:
def fields_to_partial_json(self):
params = {}
if self.key.is_private():
params['d'] = json_util.encode_b64jose(self.key.private_bytes(
params['d'] = json_util.encode_b64jose(self.key._wrapped.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption()
))
params['x'] = self.key.public_key().public_bytes(
params['x'] = self.key._wrapped.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
else:
params['x'] = json_util.encode_b64jose(self.key.public_bytes(
params['x'] = json_util.encode_b64jose(self.key._wrapped.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
))
Expand All @@ -460,16 +462,13 @@ def fields_from_json(cls, jobj):
except ValueError:
raise errors.DeserializationError("Key is not valid JSON")

if obj.get("kty") != "OKP":
raise errors.DeserializationError("Not an Octet Key Pair")

curve = obj.get("crv")
curve = obj["crv"]
if curve not in cls.crv_to_pub_priv:
raise errors.DeserializationError(f"Invalid curve: {curve}")

if "x" not in obj:
raise errors.DeserializationError('OKP should have "x" parameter')
x = json_util.decode_b64jose(jobj.get("x"))
x = json_util.decode_b64jose(jobj["x"])

try:
if "d" not in obj: # public key
Expand All @@ -478,16 +477,16 @@ def fields_from_json(cls, jobj):
ed448.Ed448PublicKey,
x25519.X25519PublicKey,
x448.X448PublicKey,
]] = cls.crv_to_pub_priv[curve][0]
]] = cls.crv_to_pub_priv[curve].pubkey
return cls(key=pub_class.from_public_bytes(x))
else: # private key
d = json_util.decode_b64jose(obj.get("d"))
d = json_util.decode_b64jose(obj["d"])
priv_key_class: Type[Union[
ed25519.Ed25519PrivateKey,
ed448.Ed448PrivateKey,
x25519.X25519PrivateKey,
x448.X448PrivateKey,
]] = cls.crv_to_pub_priv[curve][1]
]] = cls.crv_to_pub_priv[curve].privkey
return cls(key=priv_key_class.from_private_bytes(d))
except ValueError as err:
raise errors.DeserializationError("Invalid key parameter") from err
21 changes: 21 additions & 0 deletions src/josepy/jwk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def test_fields_to_json(self):
key = JWK.load(data)
data = key.fields_to_partial_json()
self.assertEqual(data['crv'], "Ed25519")
self.assertIsInstance(data['x'], bytes)

def test_init_auto_comparable(self):
self.assertIsInstance(self.x448_key.key, util.ComparableOKPKey)
Expand All @@ -421,10 +422,30 @@ def test_unknown_crv_name(self):
}
)

def test_no_x_name(self):
from josepy.jwk import JWK
with self.assertRaises(errors.DeserializationError) as warn:
JWK.from_json(
{
'kty': 'OKP',
'crv': 'Ed448',
}
)
self.assertEqual(
warn.exception.__str__(),
'Deserialization error: OKP should have "x" parameter'
)

def test_from_json_hashable(self):
from josepy.jwk import JWK
hash(JWK.from_json(self.jwked25519json))

def test_deserialize_public_key(self):
# should target jwk.py:474-484, but those lines are still marked as missing
# in the coverage report
from josepy.jwk import JWKOKP
JWKOKP.fields_from_json(self.jwked25519json)


if __name__ == '__main__':
unittest.main() # pragma: no cover
11 changes: 10 additions & 1 deletion src/josepy/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cryptography.hazmat.primitives import serialization

from josepy import ComparableRSAKey, ComparableX509
from josepy.util import ComparableECKey
from josepy.util import ComparableECKey, ComparableOKPKey


def vector_path(*names):
Expand Down Expand Up @@ -77,6 +77,15 @@ def load_ec_private_key(*names):
load_vector(*names), password=None, backend=default_backend()))


def load_okp_private_key(*names):
"""Load OKP private key."""
loader = _guess_loader(
names[-1], serialization.load_pem_private_key,
serialization.load_der_private_key,
)
return ComparableOKPKey(loader(load_vector(*names), password=None, backend=default_backend()))


def load_pyopenssl_private_key(*names):
"""Load pyOpenSSL private key."""
loader = _guess_loader(
Expand Down
35 changes: 14 additions & 21 deletions src/josepy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import OpenSSL
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import (
ec,
ed25519, ed448,
rsa,
x25519, x448,
)
from cryptography.hazmat.primitives.asymmetric import ec, rsa


class abstractclassmethod(classmethod):
Expand Down Expand Up @@ -167,7 +162,7 @@ def public_key(self):
class ComparableOKPKey(ComparableKey):
"""Wrapper for ``cryptography`` OKP keys.
Wraps around:
Wraps around any of these available with the compilation
- :class:`~cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey`
- :class:`~cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey`
- :class:`~cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey`
Expand All @@ -179,24 +174,22 @@ class ComparableOKPKey(ComparableKey):
"""

def __hash__(self):
# Computed using the thumbprint
# https://datatracker.ietf.org/doc/html/rfc7638#section-3
if self.is_private():
priv = self._wrapped.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
)
pub = priv.public_key
return hash((self.__class__, pub.curve.name, priv))
else:
pub = self._wrapped.public_key()
return hash((self.__class__, pub.curve.name, pub))
else:
pub = self._wrapped
return hash(pub.public_bytes(
format=serialization.PublicFormat.Raw,
encoding=serialization.Encoding.Raw,
)[:32])

def is_private(self) -> bool:
return isinstance(
self._wrapped, (
ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey,
x25519.X25519PrivateKey, x448.X448PrivateKey
)
)
# Not all of the curves may be available with OpenSSL,
# so instead of doing instance checks against the private
# key classes, we do this
return hasattr(self._wrapped, "private_bytes")


class ImmutableMap(Mapping, Hashable):
Expand Down
26 changes: 25 additions & 1 deletion src/josepy/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import unittest


from josepy import test_util


Expand Down Expand Up @@ -136,6 +135,31 @@ def test_public_key(self):
self.assertIsInstance(self.p256_key.public_key(), ComparableECKey)


class ComparableOKPKeyTests(unittest.TestCase):
def setUp(self):
# test_utl.load_ec_private_key return ComparableECKey
self.ed25519_key = test_util.load_okp_private_key('ed25519_key.pem')
self.ed25519_key_same = test_util.load_okp_private_key('ed25519_key.pem')
self.ed448_key = test_util.load_okp_private_key('ed448_key.pem')
self.x25519_key = test_util.load_okp_private_key('x25519_key.pem')
self.x448_key = test_util.load_okp_private_key('x448_key.pem')

def test_repr(self):
self.assertIs(repr(self.ed25519_key).startswith(
'<ComparableOKPKey(<cryptography.hazmat.'), True)

def test_public_key(self):
from josepy.util import ComparableOKPKey
self.assertIsInstance(self.ed25519_key.public_key(), ComparableOKPKey)

def test_hash(self):
self.assertIsInstance(hash(self.ed25519_key), int)
self.assertEqual(hash(self.ed25519_key), hash(self.ed25519_key_same))
self.assertNotEqual(hash(self.ed25519_key), hash(self.ed448_key))
self.assertNotEqual(hash(self.ed25519_key), hash(self.x25519_key))
self.assertNotEqual(hash(self.x25519_key), hash(self.ed448_key))


class ImmutableMapTest(unittest.TestCase):
"""Tests for josepy.util.ImmutableMap."""

Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ envlist =

[testenv]
commands =
py.test {posargs}
py.test -s {posargs}
deps =
-cconstraints.txt
-e .[tests]
Expand Down

0 comments on commit 590fc50

Please sign in to comment.