Skip to content

Commit

Permalink
Change api for makePeerCert to return a unique ptr
Browse files Browse the repository at this point in the history
Summary: Given there is no shared ownership, having the factory return a unique ptr provides strictly more flexibility than the shared ptr it returns now.

Reviewed By: frqiu

Differential Revision: D55444127

fbshipit-source-id: 3cd1503263eb5c312e9df51ceee8144e8d821a6a
  • Loading branch information
Ajanthan Asogamoorthy authored and facebook-github-bot committed Apr 25, 2024
1 parent b3c1d80 commit d19f56d
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 47 deletions.
33 changes: 19 additions & 14 deletions fizz/client/test/ClientProtocolTest.cpp
Expand Up @@ -50,8 +50,8 @@ class ClientProtocolTest : public ProtocolTest<ClientTypes, Actions> {
pskCache_ = std::make_shared<MockPskCache>();
context_->setPskCache(pskCache_);
context_->setSendEarlyData(true);
mockLeaf_ = std::make_shared<MockPeerCert>();
mockClientCert_ = std::make_shared<MockSelfCert>();
mockLeaf_ = std::make_unique<MockPeerCert>();
mockClientCert_ = std::make_unique<MockSelfCert>();
mockClock_ = std::make_shared<MockClock>();
context_->setClock(mockClock_);
ON_CALL(*mockClock_, getCurrentTime())
Expand Down Expand Up @@ -3738,12 +3738,14 @@ TEST_F(ClientProtocolTest, TestCertificateFlow) {
setupExpectingCertificate();
EXPECT_CALL(
*mockHandshakeContext_, appendToTranscript(BufMatches("certencoding")));
mockLeaf_ = std::make_shared<MockPeerCert>();
mockIntermediate_ = std::make_shared<MockPeerCert>();
auto mockLeafCert = std::make_unique<MockPeerCert>();
auto mockIntermediateCert = std::make_unique<MockPeerCert>();
auto mockLeafPtr = mockLeafCert.get();
auto mockIntermediatePtr = mockIntermediateCert.get();
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true))
.WillOnce(Return(mockLeaf_));
.WillOnce(Return(std::move(mockLeafCert)));
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false))
.WillOnce(Return(mockIntermediate_));
.WillOnce(Return(std::move(mockIntermediateCert)));

auto certificate = TestMessages::certificate();
CertificateEntry entry1;
Expand All @@ -3757,8 +3759,8 @@ TEST_F(ClientProtocolTest, TestCertificateFlow) {
expectActions<MutateState>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.unverifiedCertChain()->size(), 2);
EXPECT_EQ(state_.unverifiedCertChain()->at(0), mockLeaf_);
EXPECT_EQ(state_.unverifiedCertChain()->at(1), mockIntermediate_);
EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), mockLeafPtr);
EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), mockIntermediatePtr);
EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify);
}

Expand Down Expand Up @@ -3878,12 +3880,15 @@ TEST_F(ClientProtocolTest, TestCompressedCertificateFlow) {
EXPECT_CALL(
*mockHandshakeContext_,
appendToTranscript(BufMatches("compcertencoding")));
mockLeaf_ = std::make_shared<MockPeerCert>();
mockIntermediate_ = std::make_shared<MockPeerCert>();

auto mockLeafCert = std::make_unique<MockPeerCert>();
auto mockIntermediateCert = std::make_unique<MockPeerCert>();
auto mockLeafPtr = mockLeafCert.get();
auto mockIntermediatePtr = mockIntermediateCert.get();
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true))
.WillOnce(Return(mockLeaf_));
.WillOnce(Return(std::move(mockLeafCert)));
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false))
.WillOnce(Return(mockIntermediate_));
.WillOnce(Return(std::move(mockIntermediateCert)));

auto decompressor = std::make_shared<MockCertificateDecompressor>();
decompressor->setDefaults();
Expand Down Expand Up @@ -3914,8 +3919,8 @@ TEST_F(ClientProtocolTest, TestCompressedCertificateFlow) {
expectActions<MutateState>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.unverifiedCertChain()->size(), 2);
EXPECT_EQ(state_.unverifiedCertChain()->at(0), mockLeaf_);
EXPECT_EQ(state_.unverifiedCertChain()->at(1), mockIntermediate_);
EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), mockLeafPtr);
EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), mockIntermediatePtr);
EXPECT_EQ(state_.serverCertCompAlgo(), CertificateCompressionAlgorithm::zlib);
EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify);
}
Expand Down
4 changes: 2 additions & 2 deletions fizz/experimental/protocol/BatchSignatureFactory.h
Expand Up @@ -90,10 +90,10 @@ class BatchSignatureFactory : public Factory {
* Since batch signature is only for verifying the leaf of the certificate
* chain, so BatchSignaturePeerCert is turned only when @param leaf is true.
*/
std::shared_ptr<PeerCert> makePeerCert(CertificateEntry certEntry, bool leaf)
std::unique_ptr<PeerCert> makePeerCert(CertificateEntry certEntry, bool leaf)
const override {
if (leaf) {
return std::make_shared<BatchSignaturePeerCert>(
return std::make_unique<BatchSignaturePeerCert>(
original_->makePeerCert(std::move(certEntry), leaf));
}
return original_->makePeerCert(std::move(certEntry), leaf);
Expand Down
18 changes: 9 additions & 9 deletions fizz/extensions/delegatedcred/DelegatedCredentialFactory.cpp
Expand Up @@ -14,7 +14,7 @@ namespace fizz {
namespace extensions {

namespace {
std::shared_ptr<PeerCert> makeCredential(
std::unique_ptr<PeerCert> makeCredential(
DelegatedCredential&& credential,
folly::ssl::X509UniquePtr cert) {
VLOG(4) << "Making delegated credential";
Expand All @@ -31,19 +31,19 @@ std::shared_ptr<PeerCert> makeCredential(

switch (CertUtils::getKeyType(pubKey)) {
case KeyType::RSA:
return std::make_shared<PeerDelegatedCredentialImpl<KeyType::RSA>>(
return std::make_unique<PeerDelegatedCredentialImpl<KeyType::RSA>>(
std::move(cert), std::move(pubKey), std::move(credential));
case KeyType::P256:
return std::make_shared<PeerDelegatedCredentialImpl<KeyType::P256>>(
return std::make_unique<PeerDelegatedCredentialImpl<KeyType::P256>>(
std::move(cert), std::move(pubKey), std::move(credential));
case KeyType::P384:
return std::make_shared<PeerDelegatedCredentialImpl<KeyType::P384>>(
return std::make_unique<PeerDelegatedCredentialImpl<KeyType::P384>>(
std::move(cert), std::move(pubKey), std::move(credential));
case KeyType::P521:
return std::make_shared<PeerDelegatedCredentialImpl<KeyType::P521>>(
return std::make_unique<PeerDelegatedCredentialImpl<KeyType::P521>>(
std::move(cert), std::move(pubKey), std::move(credential));
case KeyType::ED25519:
return std::make_shared<PeerDelegatedCredentialImpl<KeyType::ED25519>>(
return std::make_unique<PeerDelegatedCredentialImpl<KeyType::ED25519>>(
std::move(cert), std::move(pubKey), std::move(credential));
}

Expand All @@ -53,7 +53,7 @@ std::shared_ptr<PeerCert> makeCredential(
}
} // namespace

std::shared_ptr<PeerCert> DelegatedCredentialFactory::makePeerCertStatic(
std::unique_ptr<PeerCert> DelegatedCredentialFactory::makePeerCertStatic(
CertificateEntry entry,
bool leaf) {
if (!leaf || entry.extensions.empty()) {
Expand All @@ -65,14 +65,14 @@ std::shared_ptr<PeerCert> DelegatedCredentialFactory::makePeerCertStatic(

// No credential, just leave as is
if (!credential) {
return std::move(parentCert);
return parentCert;
}

// Create credential
return makeCredential(std::move(credential.value()), std::move(parentX509));
}

std::shared_ptr<PeerCert> DelegatedCredentialFactory::makePeerCert(
std::unique_ptr<PeerCert> DelegatedCredentialFactory::makePeerCert(
CertificateEntry entry,
bool leaf) const {
return makePeerCertStatic(std::move(entry), leaf);
Expand Down
4 changes: 2 additions & 2 deletions fizz/extensions/delegatedcred/DelegatedCredentialFactory.h
Expand Up @@ -21,10 +21,10 @@ class DelegatedCredentialFactory : public OpenSSLFactory {
public:
~DelegatedCredentialFactory() override = default;

std::shared_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf)
std::unique_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf)
const override;

static std::shared_ptr<PeerCert> makePeerCertStatic(
static std::unique_ptr<PeerCert> makePeerCertStatic(
CertificateEntry entry,
bool leaf);
};
Expand Down
2 changes: 1 addition & 1 deletion fizz/extensions/javacrypto/JavaCryptoFactory.h
Expand Up @@ -20,7 +20,7 @@ class JavaCryptoFactory : public OpenSSLFactory {
public:
~JavaCryptoFactory() override = default;

std::shared_ptr<PeerCert> makePeerCert(
std::unique_ptr<PeerCert> makePeerCert(
CertificateEntry certEntry,
bool /*leaf*/) const override {
if (certEntry.cert_data->empty()) {
Expand Down
2 changes: 1 addition & 1 deletion fizz/protocol/Factory.h
Expand Up @@ -71,7 +71,7 @@ class Factory {
[[nodiscard]] virtual std::unique_ptr<folly::IOBuf> makeRandomBytes(
size_t count) const = 0;

virtual std::shared_ptr<PeerCert> makePeerCert(
virtual std::unique_ptr<PeerCert> makePeerCert(
CertificateEntry certEntry,
bool /*leaf*/) const = 0;

Expand Down
2 changes: 1 addition & 1 deletion fizz/protocol/OpenSSLFactory.cpp
Expand Up @@ -113,7 +113,7 @@ std::unique_ptr<HandshakeContext> OpenSSLFactory::makeHandshakeContext(
}
}

std::shared_ptr<PeerCert> OpenSSLFactory::makePeerCert(
std::unique_ptr<PeerCert> OpenSSLFactory::makePeerCert(
CertificateEntry certEntry,
bool /*leaf*/) const {
return CertUtils::makePeerCert(std::move(certEntry.cert_data));
Expand Down
2 changes: 1 addition & 1 deletion fizz/protocol/OpenSSLFactory.h
Expand Up @@ -38,7 +38,7 @@ class OpenSSLFactory : public DefaultFactory {
std::unique_ptr<HandshakeContext> makeHandshakeContext(
CipherSuite cipher) const override;

[[nodiscard]] std::shared_ptr<PeerCert> makePeerCert(
[[nodiscard]] std::unique_ptr<PeerCert> makePeerCert(
CertificateEntry certEntry,
bool /*leaf*/) const override;
};
Expand Down
4 changes: 2 additions & 2 deletions fizz/protocol/test/Mocks.h
Expand Up @@ -257,11 +257,11 @@ class MockFactory : public OpenSSLFactory {
MOCK_METHOD(uint32_t, makeTicketAgeAdd, (), (const));

MOCK_METHOD(
std::shared_ptr<PeerCert>,
std::unique_ptr<PeerCert>,
_makePeerCert,
(CertificateEntry & entry, bool leaf),
(const));
std::shared_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf)
std::unique_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf)
const override {
return _makePeerCert(entry, leaf);
}
Expand Down
28 changes: 16 additions & 12 deletions fizz/server/test/ServerProtocolTest.cpp
Expand Up @@ -6193,12 +6193,14 @@ TEST_F(ServerProtocolTest, TestCertificate) {
setUpExpectingCertificate();
EXPECT_CALL(
*mockHandshakeContext_, appendToTranscript(BufMatches("certencoding")));
clientLeafCert_ = std::make_shared<MockPeerCert>();
clientIntCert_ = std::make_shared<MockPeerCert>();
auto clientLeafCert = std::make_unique<MockPeerCert>();
auto clientIntCert = std::make_unique<MockPeerCert>();
auto clientLeafPtr = clientLeafCert.get();
auto clientIntPtr = clientIntCert.get();
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true))
.WillOnce(Return(clientLeafCert_));
.WillOnce(Return(std::move(clientLeafCert)));
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false))
.WillOnce(Return(clientIntCert_));
.WillOnce(Return(std::move(clientIntCert)));

auto certificate = TestMessages::certificate();
CertificateEntry entry1;
Expand All @@ -6213,8 +6215,8 @@ TEST_F(ServerProtocolTest, TestCertificate) {
expectActions<MutateState>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.unverifiedCertChain()->size(), 2);
EXPECT_EQ(state_.unverifiedCertChain()->at(0), clientLeafCert_);
EXPECT_EQ(state_.unverifiedCertChain()->at(1), clientIntCert_);
EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), clientLeafPtr);
EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), clientIntPtr);
EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify);
}

Expand Down Expand Up @@ -6290,12 +6292,14 @@ TEST_F(ServerProtocolTest, TestCertificateExtensionsNotSupported) {

TEST_F(ServerProtocolTest, TestCertificateExtensionsSupported) {
setUpExpectingCertificate();
clientLeafCert_ = std::make_shared<MockPeerCert>();
clientIntCert_ = std::make_shared<MockPeerCert>();
auto clientLeafCert = std::make_unique<MockPeerCert>();
auto clientLeafPtr = clientLeafCert.get();
auto clientIntCert = std::make_unique<MockPeerCert>();
auto clientIntPtr = clientIntCert.get();
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert1"), true))
.WillOnce(Return(clientLeafCert_));
.WillOnce(Return(std::move(clientLeafCert)));
EXPECT_CALL(*factory_, _makePeerCert(CertEntryBufMatches("cert2"), false))
.WillOnce(Return(clientIntCert_));
.WillOnce(Return(std::move(clientIntCert)));

auto certificate = TestMessages::certificate();
CertificateEntry entry1;
Expand All @@ -6314,8 +6318,8 @@ TEST_F(ServerProtocolTest, TestCertificateExtensionsSupported) {
expectActions<MutateState>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.unverifiedCertChain()->size(), 2);
EXPECT_EQ(state_.unverifiedCertChain()->at(0), clientLeafCert_);
EXPECT_EQ(state_.unverifiedCertChain()->at(1), clientIntCert_);
EXPECT_EQ(state_.unverifiedCertChain()->at(0).get(), clientLeafPtr);
EXPECT_EQ(state_.unverifiedCertChain()->at(1).get(), clientIntPtr);
EXPECT_EQ(state_.state(), StateEnum::ExpectingCertificateVerify);
}

Expand Down
5 changes: 3 additions & 2 deletions fizz/server/test/TicketCodecTest.cpp
Expand Up @@ -127,9 +127,10 @@ TEST(TicketCodecTest, TestFactoryCert) {
}));
auto factory = std::make_unique<MockFactory>();
auto certManager = std::make_unique<MockCertManager>();
auto factoryCert = std::make_shared<MockPeerCert>();
EXPECT_CALL(*factory, _makePeerCert(_, _)).WillOnce(Return(factoryCert));
auto factoryCert = std::make_unique<MockPeerCert>();
EXPECT_CALL(*factoryCert, getIdentity()).WillOnce(Return("factory clientid"));
EXPECT_CALL(*factory, _makePeerCert(_, _))
.WillOnce(Return(std::move(factoryCert)));
EXPECT_CALL(*certManager, getCert(_)).WillOnce(Return(nullptr));
auto encoded = TicketCodec<CertificateStorage::X509>::encode(std::move(rs));
auto drs = TicketCodec<CertificateStorage::X509>::decode(
Expand Down

0 comments on commit d19f56d

Please sign in to comment.