diff --git a/google/auth/jwt.py b/google/auth/jwt.py index a4f04f529..8165ddad7 100644 --- a/google/auth/jwt.py +++ b/google/auth/jwt.py @@ -95,10 +95,11 @@ def encode(signer, payload, header=None, key_id=None): header.update({"typ": "JWT"}) - if es256 is not None and isinstance(signer, es256.ES256Signer): - header.update({"alg": "ES256"}) - else: - header.update({"alg": "RS256"}) + if "alg" not in header: + if es256 is not None and isinstance(signer, es256.ES256Signer): + header.update({"alg": "ES256"}) + else: + header.update({"alg": "RS256"}) if key_id is not None: header["kid"] = key_id diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 7aa031ec5..7b5ba5cdc 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -73,6 +73,12 @@ def test_encode_extra_headers(signer): } +def test_encode_custom_alg_in_headers(signer): + encoded = jwt.encode(signer, {}, header={"alg": "foo"}) + header = jwt.decode_header(encoded) + assert header == {"typ": "JWT", "alg": "foo", "kid": signer.key_id} + + @pytest.fixture def es256_signer(): return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1")