diff --git a/fizz/client/test/ClientProtocolTest.cpp b/fizz/client/test/ClientProtocolTest.cpp index ad18b18cef..1a0cc224d0 100644 --- a/fizz/client/test/ClientProtocolTest.cpp +++ b/fizz/client/test/ClientProtocolTest.cpp @@ -50,8 +50,8 @@ class ClientProtocolTest : public ProtocolTest { pskCache_ = std::make_shared(); context_->setPskCache(pskCache_); context_->setSendEarlyData(true); - mockLeaf_ = std::make_shared(); - mockClientCert_ = std::make_shared(); + mockLeaf_ = std::make_unique(); + mockClientCert_ = std::make_unique(); mockClock_ = std::make_shared(); context_->setClock(mockClock_); ON_CALL(*mockClock_, getCurrentTime()) @@ -3738,12 +3738,14 @@ TEST_F(ClientProtocolTest, TestCertificateFlow) { setupExpectingCertificate(); EXPECT_CALL( *mockHandshakeContext_, appendToTranscript(BufMatches("certencoding"))); - mockLeaf_ = std::make_shared(); - mockIntermediate_ = std::make_shared(); + auto mockLeafCert = std::make_unique(); + auto mockIntermediateCert = std::make_unique(); + auto mockLeafPtr = mockLeafCert.get(); + auto mockIntermediatePtr = mockIntermediateCert.get(); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true)) - .WillOnce(Return(mockLeaf_)); + .WillOnce(Return(std::move(mockLeafCert))); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false)) - .WillOnce(Return(mockIntermediate_)); + .WillOnce(Return(std::move(mockIntermediateCert))); auto certificate = TestMessages::certificate(); CertificateEntry entry1; @@ -3757,8 +3759,8 @@ TEST_F(ClientProtocolTest, TestCertificateFlow) { expectActions(actions); processStateMutations(actions); EXPECT_EQ(state_.unverifiedCertChain()->size(), 2); - EXPECT_EQ(state_.unverifiedCertChain()->at(0), mockLeaf_); - EXPECT_EQ(state_.unverifiedCertChain()->at(1), mockIntermediate_); + EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), mockLeafPtr); + EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), mockIntermediatePtr); EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify); } @@ -3878,12 +3880,15 @@ TEST_F(ClientProtocolTest, TestCompressedCertificateFlow) { EXPECT_CALL( *mockHandshakeContext_, appendToTranscript(BufMatches("compcertencoding"))); - mockLeaf_ = std::make_shared(); - mockIntermediate_ = std::make_shared(); + + auto mockLeafCert = std::make_unique(); + auto mockIntermediateCert = std::make_unique(); + auto mockLeafPtr = mockLeafCert.get(); + auto mockIntermediatePtr = mockIntermediateCert.get(); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true)) - .WillOnce(Return(mockLeaf_)); + .WillOnce(Return(std::move(mockLeafCert))); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false)) - .WillOnce(Return(mockIntermediate_)); + .WillOnce(Return(std::move(mockIntermediateCert))); auto decompressor = std::make_shared(); decompressor->setDefaults(); @@ -3914,8 +3919,8 @@ TEST_F(ClientProtocolTest, TestCompressedCertificateFlow) { expectActions(actions); processStateMutations(actions); EXPECT_EQ(state_.unverifiedCertChain()->size(), 2); - EXPECT_EQ(state_.unverifiedCertChain()->at(0), mockLeaf_); - EXPECT_EQ(state_.unverifiedCertChain()->at(1), mockIntermediate_); + EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), mockLeafPtr); + EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), mockIntermediatePtr); EXPECT_EQ(state_.serverCertCompAlgo(), CertificateCompressionAlgorithm::zlib); EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify); } diff --git a/fizz/experimental/protocol/BatchSignatureFactory.h b/fizz/experimental/protocol/BatchSignatureFactory.h index 654a4be55f..19766811a2 100644 --- a/fizz/experimental/protocol/BatchSignatureFactory.h +++ b/fizz/experimental/protocol/BatchSignatureFactory.h @@ -90,10 +90,10 @@ class BatchSignatureFactory : public Factory { * Since batch signature is only for verifying the leaf of the certificate * chain, so BatchSignaturePeerCert is turned only when @param leaf is true. */ - std::shared_ptr makePeerCert(CertificateEntry certEntry, bool leaf) + std::unique_ptr makePeerCert(CertificateEntry certEntry, bool leaf) const override { if (leaf) { - return std::make_shared( + return std::make_unique( original_->makePeerCert(std::move(certEntry), leaf)); } return original_->makePeerCert(std::move(certEntry), leaf); diff --git a/fizz/extensions/delegatedcred/DelegatedCredentialFactory.cpp b/fizz/extensions/delegatedcred/DelegatedCredentialFactory.cpp index 225ba22f54..c27308b8fc 100644 --- a/fizz/extensions/delegatedcred/DelegatedCredentialFactory.cpp +++ b/fizz/extensions/delegatedcred/DelegatedCredentialFactory.cpp @@ -14,7 +14,7 @@ namespace fizz { namespace extensions { namespace { -std::shared_ptr makeCredential( +std::unique_ptr makeCredential( DelegatedCredential&& credential, folly::ssl::X509UniquePtr cert) { VLOG(4) << "Making delegated credential"; @@ -31,19 +31,19 @@ std::shared_ptr makeCredential( switch (CertUtils::getKeyType(pubKey)) { case KeyType::RSA: - return std::make_shared>( + return std::make_unique>( std::move(cert), std::move(pubKey), std::move(credential)); case KeyType::P256: - return std::make_shared>( + return std::make_unique>( std::move(cert), std::move(pubKey), std::move(credential)); case KeyType::P384: - return std::make_shared>( + return std::make_unique>( std::move(cert), std::move(pubKey), std::move(credential)); case KeyType::P521: - return std::make_shared>( + return std::make_unique>( std::move(cert), std::move(pubKey), std::move(credential)); case KeyType::ED25519: - return std::make_shared>( + return std::make_unique>( std::move(cert), std::move(pubKey), std::move(credential)); } @@ -53,7 +53,7 @@ std::shared_ptr makeCredential( } } // namespace -std::shared_ptr DelegatedCredentialFactory::makePeerCertStatic( +std::unique_ptr DelegatedCredentialFactory::makePeerCertStatic( CertificateEntry entry, bool leaf) { if (!leaf || entry.extensions.empty()) { @@ -65,14 +65,14 @@ std::shared_ptr DelegatedCredentialFactory::makePeerCertStatic( // No credential, just leave as is if (!credential) { - return std::move(parentCert); + return parentCert; } // Create credential return makeCredential(std::move(credential.value()), std::move(parentX509)); } -std::shared_ptr DelegatedCredentialFactory::makePeerCert( +std::unique_ptr DelegatedCredentialFactory::makePeerCert( CertificateEntry entry, bool leaf) const { return makePeerCertStatic(std::move(entry), leaf); diff --git a/fizz/extensions/delegatedcred/DelegatedCredentialFactory.h b/fizz/extensions/delegatedcred/DelegatedCredentialFactory.h index 965ac11ec1..73e495dbca 100644 --- a/fizz/extensions/delegatedcred/DelegatedCredentialFactory.h +++ b/fizz/extensions/delegatedcred/DelegatedCredentialFactory.h @@ -21,10 +21,10 @@ class DelegatedCredentialFactory : public OpenSSLFactory { public: ~DelegatedCredentialFactory() override = default; - std::shared_ptr makePeerCert(CertificateEntry entry, bool leaf) + std::unique_ptr makePeerCert(CertificateEntry entry, bool leaf) const override; - static std::shared_ptr makePeerCertStatic( + static std::unique_ptr makePeerCertStatic( CertificateEntry entry, bool leaf); }; diff --git a/fizz/extensions/javacrypto/JavaCryptoFactory.h b/fizz/extensions/javacrypto/JavaCryptoFactory.h index 6297206b9e..43bdd6f72e 100644 --- a/fizz/extensions/javacrypto/JavaCryptoFactory.h +++ b/fizz/extensions/javacrypto/JavaCryptoFactory.h @@ -20,7 +20,7 @@ class JavaCryptoFactory : public OpenSSLFactory { public: ~JavaCryptoFactory() override = default; - std::shared_ptr makePeerCert( + std::unique_ptr makePeerCert( CertificateEntry certEntry, bool /*leaf*/) const override { if (certEntry.cert_data->empty()) { diff --git a/fizz/protocol/Factory.h b/fizz/protocol/Factory.h index 1cef31dc9e..69c9ef5fd4 100644 --- a/fizz/protocol/Factory.h +++ b/fizz/protocol/Factory.h @@ -71,7 +71,7 @@ class Factory { [[nodiscard]] virtual std::unique_ptr makeRandomBytes( size_t count) const = 0; - virtual std::shared_ptr makePeerCert( + virtual std::unique_ptr makePeerCert( CertificateEntry certEntry, bool /*leaf*/) const = 0; diff --git a/fizz/protocol/OpenSSLFactory.cpp b/fizz/protocol/OpenSSLFactory.cpp index 76c70c75e5..abefccaf03 100644 --- a/fizz/protocol/OpenSSLFactory.cpp +++ b/fizz/protocol/OpenSSLFactory.cpp @@ -113,7 +113,7 @@ std::unique_ptr OpenSSLFactory::makeHandshakeContext( } } -std::shared_ptr OpenSSLFactory::makePeerCert( +std::unique_ptr OpenSSLFactory::makePeerCert( CertificateEntry certEntry, bool /*leaf*/) const { return CertUtils::makePeerCert(std::move(certEntry.cert_data)); diff --git a/fizz/protocol/OpenSSLFactory.h b/fizz/protocol/OpenSSLFactory.h index 6a1ff2fe54..984b5eb91f 100644 --- a/fizz/protocol/OpenSSLFactory.h +++ b/fizz/protocol/OpenSSLFactory.h @@ -38,7 +38,7 @@ class OpenSSLFactory : public DefaultFactory { std::unique_ptr makeHandshakeContext( CipherSuite cipher) const override; - [[nodiscard]] std::shared_ptr makePeerCert( + [[nodiscard]] std::unique_ptr makePeerCert( CertificateEntry certEntry, bool /*leaf*/) const override; }; diff --git a/fizz/protocol/test/Mocks.h b/fizz/protocol/test/Mocks.h index b51f402f7b..51c15ae930 100644 --- a/fizz/protocol/test/Mocks.h +++ b/fizz/protocol/test/Mocks.h @@ -257,11 +257,11 @@ class MockFactory : public OpenSSLFactory { MOCK_METHOD(uint32_t, makeTicketAgeAdd, (), (const)); MOCK_METHOD( - std::shared_ptr, + std::unique_ptr, _makePeerCert, (CertificateEntry & entry, bool leaf), (const)); - std::shared_ptr makePeerCert(CertificateEntry entry, bool leaf) + std::unique_ptr makePeerCert(CertificateEntry entry, bool leaf) const override { return _makePeerCert(entry, leaf); } diff --git a/fizz/server/test/ServerProtocolTest.cpp b/fizz/server/test/ServerProtocolTest.cpp index 2aefabaff6..bd3549b7a6 100644 --- a/fizz/server/test/ServerProtocolTest.cpp +++ b/fizz/server/test/ServerProtocolTest.cpp @@ -6193,12 +6193,14 @@ TEST_F(ServerProtocolTest, TestCertificate) { setUpExpectingCertificate(); EXPECT_CALL( *mockHandshakeContext_, appendToTranscript(BufMatches("certencoding"))); - clientLeafCert_ = std::make_shared(); - clientIntCert_ = std::make_shared(); + auto clientLeafCert = std::make_unique(); + auto clientIntCert = std::make_unique(); + auto clientLeafPtr = clientLeafCert.get(); + auto clientIntPtr = clientIntCert.get(); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true)) - .WillOnce(Return(clientLeafCert_)); + .WillOnce(Return(std::move(clientLeafCert))); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false)) - .WillOnce(Return(clientIntCert_)); + .WillOnce(Return(std::move(clientIntCert))); auto certificate = TestMessages::certificate(); CertificateEntry entry1; @@ -6213,8 +6215,8 @@ TEST_F(ServerProtocolTest, TestCertificate) { expectActions(actions); processStateMutations(actions); EXPECT_EQ(state_.unverifiedCertChain()->size(), 2); - EXPECT_EQ(state_.unverifiedCertChain()->at(0), clientLeafCert_); - EXPECT_EQ(state_.unverifiedCertChain()->at(1), clientIntCert_); + EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), clientLeafPtr); + EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), clientIntPtr); EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify); } @@ -6290,12 +6292,14 @@ TEST_F(ServerProtocolTest, TestCertificateExtensionsNotSupported) { TEST_F(ServerProtocolTest, TestCertificateExtensionsSupported) { setUpExpectingCertificate(); - clientLeafCert_ = std::make_shared(); - clientIntCert_ = std::make_shared(); + auto clientLeafCert = std::make_unique(); + auto clientLeafPtr = clientLeafCert.get(); + auto clientIntCert = std::make_unique(); + auto clientIntPtr = clientIntCert.get(); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true)) - .WillOnce(Return(clientLeafCert_)); + .WillOnce(Return(std::move(clientLeafCert))); EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false)) - .WillOnce(Return(clientIntCert_)); + .WillOnce(Return(std::move(clientIntCert))); auto certificate = TestMessages::certificate(); CertificateEntry entry1; @@ -6314,8 +6318,8 @@ TEST_F(ServerProtocolTest, TestCertificateExtensionsSupported) { expectActions(actions); processStateMutations(actions); EXPECT_EQ(state_.unverifiedCertChain()->size(), 2); - EXPECT_EQ(state_.unverifiedCertChain()->at(0), clientLeafCert_); - EXPECT_EQ(state_.unverifiedCertChain()->at(1), clientIntCert_); + EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), clientLeafPtr); + EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), clientIntPtr); EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify); } diff --git a/fizz/server/test/TicketCodecTest.cpp b/fizz/server/test/TicketCodecTest.cpp index 1edc776d8c..a35a932af5 100644 --- a/fizz/server/test/TicketCodecTest.cpp +++ b/fizz/server/test/TicketCodecTest.cpp @@ -127,9 +127,10 @@ TEST(TicketCodecTest, TestFactoryCert) { })); auto factory = std::make_unique(); auto certManager = std::make_unique(); - auto factoryCert = std::make_shared(); - EXPECT_CALL(*factory, _makePeerCert(_, _)).WillOnce(Return(factoryCert)); + auto factoryCert = std::make_unique(); EXPECT_CALL(*factoryCert, getIdentity()).WillOnce(Return("factory clientid")); + EXPECT_CALL(*factory, _makePeerCert(_, _)) + .WillOnce(Return(std::move(factoryCert))); EXPECT_CALL(*certManager, getCert(_)).WillOnce(Return(nullptr)); auto encoded = TicketCodec::encode(std::move(rs)); auto drs = TicketCodec::decode(