Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use template dispatcher for vectortype choice #294

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
79 changes: 64 additions & 15 deletions AnnService/inc/Core/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cmath>
#include <stdexcept>
#include <tuple>
#include "inc/Helper/Logging.h"
#include "inc/Helper/DiskIO.h"

#ifndef _MSC_VER
#include <stdio.h>
Expand Down Expand Up @@ -152,17 +153,75 @@ enum class DistCalcMethod : std::uint8_t
};
static_assert(static_cast<std::uint8_t>(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!");


// Identifies the element type stored in a vector.
// The enumerators are generated by expanding DefineVectorValueType(Name, Type)
// for every entry of DefinitionList.h (only the Name is used here), so their
// numeric values follow the order of that list. Undefined is always the last
// enumerator and therefore equals the number of generated entries.
enum class VectorValueType : std::uint8_t
{
#define DefineVectorValueType(Name, Type) Name,
#include "DefinitionList.h"
#undef DefineVectorValueType

Undefined
};
// Undefined == 0 would mean the list expansion produced no enumerators at all.
static_assert(static_cast<std::uint8_t>(VectorValueType::Undefined) != 0, "Empty VectorValueType!");

// remove_last is by Vladimir Reshetnikov, https://stackoverflow.com/a/51805324
//
// Type-level helper: remove_last<std::tuple<A, B, C>>::type is std::tuple<A, B>.
template<class Tuple>
struct remove_last;

// Declared but deliberately left undefined: an empty tuple has no last element
// to drop, so remove_last_t<std::tuple<>> does not compile.
template<>
struct remove_last<std::tuple<>>;

template<class... Elements>
struct remove_last<std::tuple<Elements...>>
{
private:
    using Full = std::tuple<Elements...>;

    // Declaration only; it is never called. It exists so that decltype can name
    // the tuple made of the elements of Full at the given indices.
    template<std::size_t... Idx>
    static auto keep(std::index_sequence<Idx...>)
        -> std::tuple<typename std::tuple_element<Idx, Full>::type...>;

public:
    // Indices 0 .. N-2 select every element except the last one.
    using type = decltype(keep(std::make_index_sequence<sizeof...(Elements) - 1>{}));
};

// Convenience alias for remove_last<Tuple>::type.
template<class Tuple>
using remove_last_t = typename remove_last<Tuple>::type;

// std::tuple of the C++ element types listed in DefinitionList.h, in list order
// (the same expansion order used for the VectorValueType enumerators).
// The X-macro emits "Type," for every entry, which leaves a trailing comma; a
// `void` sentinel is appended to make the template argument list well-formed
// and remove_last_t strips it again.
using VectorValueTypeTuple = remove_last_t<std::tuple<
#define DefineVectorValueType(Name, Type) Type,
#include "DefinitionList.h"
#undef DefineVectorValueType
void>>;

// Dispatcher is based on https://stackoverflow.com/a/34046180
//
// Wraps `f` in a nullary callable. The returned std::function holds its own
// copy of `f`; invoking it calls that copy with a value-initialized T (T{}).
// Any value returned by `f` is discarded.
template <typename T, typename F>
std::function<void()> call_with_default(F&& f)
{
    std::function<void()> thunk = [f]() { f(T{}); };
    return thunk;
}

// Implementation detail of VectorValueTypeDispatch.
//
// Is... enumerates the indices of VectorValueTypeTuple. `f` is invoked exactly
// once, with a value-initialized object of the tuple element whose index equals
// the numeric value of `vectorType`; whatever `f` returns is discarded. If no
// index matches, `f` is not invoked at all.
//
// `f` is called directly rather than through a table of type-erased wrappers,
// so a dispatch neither copies the functor nor allocates.
template <typename F, std::size_t...Is>
void VectorValueTypeDispatch(VectorValueType vectorType, F&& f, std::index_sequence<Is...>)
{
    const std::size_t selected = static_cast<std::size_t>(vectorType);

    // C++14-compatible pack expansion: one array initializer per candidate type.
    // Only the initializer whose index matches evaluates the call to f; the
    // leading 0 keeps the array non-empty when Is... is empty.
    using Expand = int[];
    static_cast<void>(Expand{
        0,
        (selected == Is
            ? (static_cast<void>(f(std::tuple_element_t<Is, VectorValueTypeTuple>{})), 0)
            : 0)...
    });
}

// Runtime-to-compile-time dispatch on a VectorValueType.
//
// Calls `f` once with a value-initialized object of the element type found at
// position static_cast<size_t>(vectorType) in VectorValueTypeTuple, so a
// generic lambda can recover the type with decltype:
//
//     VectorValueTypeDispatch(type, [&](auto t) { using Type = decltype(t); ... });
//
// `f` must be callable with every type in VectorValueTypeTuple, because all
// candidate calls are instantiated. Its return value is discarded.
//
// Throws std::invalid_argument (derived from std::exception) when `vectorType`
// has no tuple entry, i.e. its numeric value is >= the tuple size.
template <typename F>
void VectorValueTypeDispatch(VectorValueType vectorType, F f)
{
    constexpr std::size_t VectorCount = std::tuple_size<VectorValueTypeTuple>::value;
    const std::size_t typeIndex = static_cast<std::size_t>(vectorType);

    // Reject values without an associated element type before dispatching.
    if (typeIndex >= VectorCount)
    {
        throw std::invalid_argument("VectorValueTypeDispatch: VectorValueType has no associated element type");
    }

    VectorValueTypeDispatch(vectorType, f, std::make_index_sequence<VectorCount>{});
}

enum class IndexAlgoType : std::uint8_t
{
Expand Down Expand Up @@ -214,20 +273,10 @@ constexpr VectorValueType GetEnumValueType<Type>() \

inline std::size_t GetValueTypeSize(VectorValueType p_valueType)
{
switch (p_valueType)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return sizeof(Type); \

#include "DefinitionList.h"
#undef DefineVectorValueType

default:
break;
}
std::size_t out = 0;
VectorValueTypeDispatch(p_valueType, [&](auto t) { out = sizeof(decltype(t)); });

return 0;
return out;
}

enum class QuantizerType : std::uint8_t
Expand Down
30 changes: 6 additions & 24 deletions AnnService/inc/Core/Common/BKTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,18 +421,7 @@ namespace SPTAG
float CountStd;
if (args.m_pQuantizer)
{
switch (args.m_pQuantizer->GetReconstructType())
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
CountStd = TryClustering<T, Type>(data, indices, first, last, args, samples, lambdaFactor, true); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}
VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t) { CountStd = TryClustering<T, decltype(t)>(data, indices, first, last, args, samples, lambdaFactor, true); });
}
else
{
Expand Down Expand Up @@ -469,18 +458,11 @@ break;

if (args.m_pQuantizer)
{
switch (args.m_pQuantizer->GetReconstructType())
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
TryClustering<T, Type>(data, indices, first, last, args, samples, lambdaFactor, debug, abort); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}
VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t)
{
using Type = decltype(t);
TryClustering<T, Type>(data, indices, first, last, args, samples, lambdaFactor, debug, abort);
});
}
else
{
Expand Down
33 changes: 10 additions & 23 deletions AnnService/inc/Core/Common/KDTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,11 @@ namespace SPTAG
{
if (m_pQuantizer)
{
switch (m_pQuantizer->GetReconstructType())
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
BuildTreesCore<T, Type>(data, numOfThreads, indices, abort); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}
VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t)
{
using Type = decltype(t);
BuildTreesCore<T, Type>(data, numOfThreads, indices, abort);
});
}
else
{
Expand Down Expand Up @@ -236,17 +229,11 @@ break;
{
if (m_pQuantizer)
{
switch (m_pQuantizer->GetReconstructType())
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return KDTSearchCore<T, Type>(p_data, fComputeDistance, p_query, p_space, node, distBound);

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}
return VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t)
{
using Type = decltype(t);
KDTSearchCore<T, Type>(p_data, fComputeDistance, p_query, p_space, node, distBound);
});
}
else
{
Expand Down
17 changes: 5 additions & 12 deletions AnnService/inc/Core/Common/NeighborhoodGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,11 @@ namespace SPTAG
{
if (index->m_pQuantizer)
{
switch (index->m_pQuantizer->GetReconstructType())
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
PartitionByTptreeCore<T, Type>(index, indices, first, last, leaves); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}
return VectorValueTypeDispatch(index->m_pQuantizer->GetReconstructType(), [&](auto t)
{
using Type = decltype(t);
PartitionByTptreeCore<T, Type>(index, indices, first, last, leaves);
});
}
else
{
Expand Down
40 changes: 17 additions & 23 deletions AnnService/src/Aggregator/AggregatorService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,29 +222,23 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID
size_t vectorSize;
SizeType vectorDimension = 0;
std::vector<BasicResult> servers;
switch (context->GetSettings()->m_valueType)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
if (!queryParser.GetVectorElements().empty()) { \
Service::ConvertVectorFromString<Type>(queryParser.GetVectorElements(), vector, vectorDimension); \
} else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { \
vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); \
Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); \
vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); \
} \
for (int i = 0; i < context->GetCenters()->Count(); i++) { \
servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), \
(Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); \
} \
break; \

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default:
break;
}
VectorValueTypeDispatch(context->GetSettings()->m_valueType, [&](auto t)
{
using Type = decltype(t);
if (!queryParser.GetVectorElements().empty()) {
Service::ConvertVectorFromString<Type>(queryParser.GetVectorElements(), vector, vectorDimension);
}
else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) {
vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length()));
Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize);
vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType));
}
for (int i = 0; i < context->GetCenters()->Count(); i++) {
servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(),
(Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod)));
}
});

std::sort(servers.begin(), servers.end(), [](const BasicResult& a, const BasicResult& b) { return a.Dist < b.Dist; });
for (int i = 0; i < context->GetSettings()->m_topK; i++) {
auto& server = context->GetRemoteServers().at(servers[i].VID);
Expand Down
48 changes: 4 additions & 44 deletions AnnService/src/Core/Common/IQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,11 @@ namespace SPTAG
case QuantizerType::Undefined:
break;
case QuantizerType::PQQuantizer:
switch (reconstructType) {
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
ret.reset(new PQQuantizer<Type>()); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType

default: break;
}

VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer<decltype(t)>()); });
if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset();
return ret;
case QuantizerType::OPQQuantizer:
switch (reconstructType) {
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
ret.reset(new OPQQuantizer<Type>()); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType
default: break;
}
VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new OPQQuantizer<decltype(t)>()); });
if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset();
return ret;
}
Expand All @@ -68,31 +48,11 @@ namespace SPTAG
case QuantizerType::Undefined:
return ret;
case QuantizerType::PQQuantizer:
switch (reconstructType) {
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
ret.reset(new PQQuantizer<Type>()); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType
default: break;
}

VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer<decltype(t)>()); });
if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset();
return ret;
case QuantizerType::OPQQuantizer:
switch (reconstructType) {
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
ret.reset(new OPQQuantizer<Type>()); \
break;

#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType
default: break;
}

// This is the OPQQuantizer case, so the concrete type must be OPQQuantizer
// (a PQQuantizer here would load OPQ data with the wrong quantizer).
VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new OPQQuantizer<decltype(t)>()); });
if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset();
return ret;
}
Expand Down