Skip to content

Commit

Permalink
[feat] Update HKV to support more eviction strategies
Browse files (browse the repository at this point in the history)
  • Loading branch information
LinGeLin authored and rhdong committed May 6, 2024
1 parent 21c6235 commit 53b5ac8
Show file tree
Hide file tree
Showing 15 changed files with 959 additions and 306 deletions.
52 changes: 3 additions & 49 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,55 +57,9 @@ http_archive(
http_archive(
name = "hkv",
build_file = "//build_deps/toolchains/hkv:hkv.BUILD",
# TODO(LinGeLin): remove these patches once hkv is updated
patch_cmds = [
"""sed -i.bak '1772i\\'$'\\n ThrustAllocator<uint8_t> thrust_allocator_;\\n' include/merlin_hashtable.cuh""",
"""sed -i.bak '225i\\'$'\\n thrust_allocator_.set_allocator(allocator_);\\n' include/merlin_hashtable.cuh""",
"sed -i.bak 's/thrust::sort_by_key(thrust_par.on(stream)/thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream)/' include/merlin_hashtable.cuh",
"sed -i.bak 's/reduce(thrust_par.on(stream)/reduce(thrust_par(thrust_allocator_).on(stream)/' include/merlin_hashtable.cuh",
"""sed -i.bak '125i\\'$'\\n template <typename T>\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '126i\\'$'\\n struct ThrustAllocator : thrust::device_malloc_allocator<T> {\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '127i\\'$'\\n public:\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '128i\\'$'\\n typedef thrust::device_malloc_allocator<T> super_t;\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '129i\\'$'\\n typedef typename super_t::pointer pointer;\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '130i\\'$'\\n typedef typename super_t::size_type size_type;\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '131i\\'$'\\n public:\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '132i\\'$'\\n pointer allocate(size_type n) {\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '133i\\'$'\\n void* ptr = nullptr;\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '134i\\'$'\\n MERLIN_CHECK(\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '135i\\'$'\\n allocator_ != nullptr,\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '136i\\'$'\\n "[ThrustAllocator] set_allocator should be called in advance!");\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '137i\\'$'\\n allocator_->alloc(MemoryType::Device, &ptr, sizeof(T) * n);\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '138i\\'$'\\n return pointer(reinterpret_cast<T*>(ptr));\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '139i\\'$'\\n }\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '140i\\'$'\\n void deallocate(pointer p, size_type n) {\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '141i\\'$'\\n MERLIN_CHECK(\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '142i\\'$'\\n allocator_ != nullptr,\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '143i\\'$'\\n "[ThrustAllocator] set_allocator should be called in advance!");\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '144i\\'$'\\n allocator_->free(MemoryType::Device, reinterpret_cast<void*>(p.get()));\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '145i\\'$'\\n }\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '146i\\'$'\\n void set_allocator(BaseAllocator* allocator) { allocator_ = allocator; }\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '147i\\'$'\\n public:\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '148i\\'$'\\n BaseAllocator* allocator_ = nullptr;\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '149i\\'$'\\n };\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '20i\\'$'\\n #include <thrust/device_malloc_allocator.h>\\n' include/merlin/allocator.cuh""",
"""sed -i.bak '367i\\'$'\\n for (auto addr : (*table)->buckets_address) {\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '368i\\'$'\\n allocator->free(MemoryType::Device, addr);\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '369i\\'$'\\n }\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '370i\\'$'\\n /*\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '382i\\'$'\\n */\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '224i\\'$'\\n uint8_t* address = nullptr;\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '225i\\'$'\\n allocator->alloc(MemoryType::Device, (void**)&(address), bucket_memory_size * (end - start));\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '226i\\'$'\\n (*table)->buckets_address.push_back(address);\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '228i\\'$'\\n allocate_bucket_others<K, V, S><<<1, 1>>>((*table)->buckets, i, address + (bucket_memory_size * (i-start)), reserve_size, bucket_max_size);\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '229i\\'$'\\n /*\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '235i\\'$'\\n */\\n' include/merlin/core_kernels.cuh""",
"""sed -i.bak '22i\\'$'\\n#include <vector>\\n' include/merlin/types.cuh""",
"""sed -i.bak '143i\\'$'\\n std::vector<uint8_t*> buckets_address;\\n' include/merlin/types.cuh""",
],
sha256 = "f8179c445a06a558262946cda4d8ae7252d313e73f792586be9b1bc0c993b1cf",
strip_prefix = "HierarchicalKV-0.1.0-beta.6",
url = "https://github.com/NVIDIA-Merlin/HierarchicalKV/archive/refs/tags/v0.1.0-beta.6.tar.gz",
sha256 = "79c59b19c03b771cdcb6deb3c6a3213353482f4d07cb1ddb53c4b001a0f58b29",
strip_prefix = "HierarchicalKV-0.1.0-beta.10",
url = "https://github.com/NVIDIA-Merlin/HierarchicalKV/archive/refs/tags/v0.1.0-beta.10.tar.gz",
)

tf_configure(
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_recommenders_addons/dynamic_embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
'CuckooHashTable',
'CuckooHashTableConfig',
'CuckooHashTableCreator',
'HkvEvictStrategy',
'HkvHashTable',
'HkvHashTableConfig',
'HkvHashTableCreator',
Expand Down Expand Up @@ -55,7 +56,7 @@
from tensorflow_recommenders_addons.dynamic_embedding.python.ops import data_flow_ops as data_flow
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_creator import (
KVCreator, CuckooHashTableConfig, CuckooHashTableCreator,
HkvHashTableConfig, HkvHashTableCreator, RedisTableConfig,
HkvHashTableConfig, HkvHashTableCreator, HkvEvictStrategy, RedisTableConfig,
RedisTableCreator, FileSystemSaver)
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.cuckoo_hashtable_ops import (
CuckooHashTable,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,8 @@ class HashTableInsertOp : public HashTableOpKernel {
core::ScopedUnref unref_me(table);

DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
table->value_dtype()};
table->value_dtype(),
DataTypeToEnum<int64>::v()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

const Tensor& keys = ctx->input(1);
Expand Down Expand Up @@ -750,9 +751,9 @@ class HashTableAccumOp : public HashTableOpKernel {
hkv_table::HkvHashTableOfTensors<K, V>* table_cuckoo =
(hkv_table::HkvHashTableOfTensors<K, V>*)table;

DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
table->value_dtype(),
DataTypeToEnum<bool>::v()};
DataTypeVector expected_inputs = {
expected_input_0_, table->key_dtype(), table->value_dtype(),
DataTypeToEnum<bool>::v(), DataTypeToEnum<int64>::v()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

const Tensor& keys = ctx->input(1);
Expand Down Expand Up @@ -975,8 +976,8 @@ REGISTER_KERNEL_BUILDER(
HashTableAccumOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableFindWithExists)) \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("Tin") \
.TypeConstraint<value_dtype>("Tout"), \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER( \
Name(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem)) \
Expand Down

0 comments on commit 53b5ac8

Please sign in to comment.