Crash in Attention() with huge input #99

Closed
zeerd opened this issue Mar 14, 2024 · 13 comments
Labels
type:bug Something isn't working


@zeerd
Contributor

zeerd commented Mar 14, 2024

I am running a program that summarizes several texts, one part at a time.
(In fact, I am parsing a C++ source file that contains many functions, one function at a time.)

When the conversation gets too long, #98 occurs and Gemma stops calling Transformer().
However, the caller does not know this. It keeps calling GenerateGemma(), which calls Prefill() and Attention().
This eventually leads to a memory access error.

If there were some notification mechanism that let the caller know something went wrong so it could stop calling, things would be better.
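A callback is one possible shape for such a mechanism. The sketch below is hypothetical: `GenerateOptions`, `on_error`, and `GenerateStep` are illustrative names, not gemma.cpp APIs, and the simple capacity check stands in for whatever condition triggers #98.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>

// Hypothetical options struct; these names are NOT gemma.cpp APIs.
struct GenerateOptions {
  size_t max_tokens = 0;  // KV-cache capacity, in token positions
  // Called when generation cannot continue, so the caller can stop
  // issuing further requests instead of crashing later.
  std::function<void(const std::string&)> on_error;
};

// Hypothetical step that would process `num_tokens` tokens starting at `pos`.
// Instead of writing out of bounds, it reports the problem via `on_error`.
// Returns the new position, or `pos` unchanged if the request was refused.
size_t GenerateStep(const GenerateOptions& opts, size_t pos, size_t num_tokens) {
  if (pos > opts.max_tokens || num_tokens > opts.max_tokens - pos) {
    if (opts.on_error) {
      opts.on_error("conversation too long: reset before continuing");
    }
    return pos;
  }
  // ... Prefill / Transformer work would run here ...
  return pos + num_tokens;
}
```

The caller sets `on_error` once (for example, to flip a flag or log and break out of its loop) and never needs to inspect internal state.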

AddressSanitizer:DEADLYSIGNAL
=================================================================
AddressSanitizerAddressSanitizer:DEADLYSIGNAL
:DEADLYSIGNAL
==12302==ERROR: AddressSanitizer: SEGV on unknown address 0x7f8f73625100 (pc 0x5559a8b6b1cf bp 0x7f8f997679f0 sp 0x7f8f997679a0 T30)
==12302==The signal is caused by a WRITE memory access.
AddressSanitizer:DEADLYSIGNAL
    #0 0x5559a8b6b1ce in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const (/path/to/gemma.qt/gemma+0x541ce)
    #1 0x5559a8b44ffa in hwy::ParallelFor::WorkerRun(unsigned long, unsigned long, hwy::PoolMem&) gemma.cpp/build/_deps/highway-src/hwy/contrib/thread_pool/thread_pool.h:525
    #2 0x5559a8b451ef in hwy::ThreadPool::ThreadFunc(unsigned long, unsigned long, hwy::PoolMem*) gemma.cpp/build/_deps/highway-src/hwy/contrib/thread_pool/thread_pool.h:585
    #3 0x5559a8b4d238 in void std::__invoke_impl<void, void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*>(std::__invoke_other, void (*&&)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long&&, unsigned long&&, hwy::PoolMem*&&) (/path/to/gemma.qt/gemma+0x36238)
    #4 0x5559a8b4d115 in std::__invoke_result<void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*>::type std::__invoke<void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*>(void (*&&)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long&&, unsigned long&&, hwy::PoolMem*&&) (/path/to/gemma.qt/gemma+0x36115)
    #5 0x5559a8b4cff4 in void std::thread::_Invoker<std::tuple<void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*> >::_M_invoke<0ul, 1ul, 2ul, 3ul>(std::_Index_tuple<0ul, 1ul, 2ul, 3ul>) (/path/to/gemma.qt/gemma+0x35ff4)
    #6 0x5559a8b4cf73 in std::thread::_Invoker<std::tuple<void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*> >::operator()() (/path/to/gemma.qt/gemma+0x35f73)
    #7 0x5559a8b4cf53 in std::thread::_State_impl<std::thread::_Invoker<std::tuple<void (*)(unsigned long, unsigned long, hwy::PoolMem*), unsigned long, unsigned long, hwy::PoolMem*> > >::_M_run() (/path/to/gemma.qt/gemma+0x35f53)
    #8 0x7f8fb8cc4df3  (/lib/x86_64-linux-gnu/libstdc++.so.6+0xd6df3)
    #9 0x7f8fb8dd8608 in start_thread /build/glibc-wuryBv/glibc-2.31/nptl/pthread_create.c:477
    #10 0x7f8fb89a5352 in __clone (/lib/x86_64-linux-gnu/libc.so.6+0x11f352)

AddressSanitizer can not provide additional info.
SUMMARY: AddressSanitizer: SEGV (/path/to/gemma.qt/gemma+0x541ce) in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const
Thread T30 (worker003) created by T26 (GemmaThread) here:
AddressSanitizer:DEADLYSIGNAL
AddressSanitizer: nested bug in the same thread, aborting.
@jan-wassenberg
Member

It sounds like it could be useful to have GenerateImpl and Generate return something to indicate the error. Would an error code help, or do we just want a "succeeded or failed" boolean?
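For the boolean variant, a minimal sketch might look like the following. `Session` and `GenerateChecked` are hypothetical names, not the actual gemma.cpp `Generate` signature; the idea is only to refuse a request that would not fit into the remaining KV-cache positions instead of writing past the end.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the model state; not the gemma.cpp API.
struct Session {
  size_t pos = 0;       // next KV-cache position to write
  size_t capacity = 0;  // number of positions the KV cache can hold
};

// Returns false (and changes nothing) if the prompt plus the requested number
// of generated tokens would not fit into the remaining KV-cache positions.
bool GenerateChecked(Session& s, const std::vector<int>& prompt,
                     size_t max_generated_tokens) {
  const size_t needed = prompt.size() + max_generated_tokens;
  // Invariant: pos <= capacity, so this subtraction cannot wrap around.
  if (needed > s.capacity - s.pos) {
    return false;
  }
  // ... Prefill(prompt) and the decode loop would run here ...
  s.pos += needed;  // simplified: assumes all requested tokens are produced
  return true;
}
```

Writing the check as `needed > capacity - pos` rather than `pos + needed > capacity` avoids unsigned overflow for very large requests.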

@zeerd
Contributor Author

zeerd commented Mar 14, 2024

For now, a boolean works.

It's hard for the caller to recover using an error code. Maybe if error == 1, extend max_tokens and restart gcpp::Gemma automatically? But the caller has no way to know how many tokens are enough. Maybe that's another job for AI. :)
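Even without knowing how many tokens are enough, a caller can react to a boolean failure by starting a fresh conversation and retrying once. This is a hypothetical sketch of that loop for the per-function summarization use case; `Session`, `TryGenerate`, and `SummarizeAll` are illustrative names, and token counts are passed in directly rather than computed by a tokenizer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical session; `pos` counts KV-cache positions already used.
struct Session {
  size_t pos = 0;
  size_t capacity = 0;
};

// Hypothetical generate call: succeeds only if the request fits.
bool TryGenerate(Session& s, size_t request_tokens) {
  if (request_tokens > s.capacity - s.pos) return false;
  s.pos += request_tokens;
  return true;
}

// Processes each chunk (e.g. one C++ function) in turn. When a request does
// not fit, starts a fresh conversation (pos = 0) and retries once; a chunk
// that does not fit even into an empty cache is skipped and counted.
size_t SummarizeAll(Session& s, const std::vector<size_t>& chunk_tokens) {
  size_t skipped = 0;
  for (size_t tokens : chunk_tokens) {
    if (TryGenerate(s, tokens)) continue;
    s.pos = 0;  // reset: drop previous context instead of overflowing
    if (!TryGenerate(s, tokens)) ++skipped;
  }
  return skipped;
}
```

Resetting loses earlier context, which is usually acceptable when each function is summarized independently.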

@jan-wassenberg
Member

I agree a boolean is fine. Would you like to add one?

@zeerd
Contributor Author

zeerd commented Mar 15, 2024

I gave it a try, but there seem to be some other problems.
Per #88, I have not done further tests without debug information.

$ cat gemma.cc | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b

[ Reading prompt ] ................ (progress dots elided)
=================================================================
==24394==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x7f8346297480 at pc 0x5632a95b4246 bp 0x7f84840f8ff0 sp 0x7f84840f8fe0
WRITE of size 4 at 0x7f8346297480 thread T5 (worker004)
    #0 0x5632a95b4245 in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const (/path/to/gemma.cpp/build/gemma+0xcd245)
    #1 0x5632a950ff88 in hwy::ParallelFor::WorkerRun(unsigned long, unsigned long, hwy::PoolMem&) (/path/to/gemma.cpp/build/gemma+0x28f88)
    #2 0x5632a95102ec in hwy::ThreadPool::ThreadFunc(unsigned long, unsigned long, hwy::PoolMem*) (/path/to/gemma.cpp/build/gemma+0x292ec)
    #3 0x7f84898e9df3  (/lib/x86_64-linux-gnu/libstdc++.so.6+0xd6df3)
    #4 0x7f8489bcf608 in start_thread /build/glibc-wuryBv/glibc-2.31/nptl/pthread_create.c:477
    #5 0x7f84895ca352 in __clone (/lib/x86_64-linux-gnu/libc.so.6+0x11f352)

0x7f8346297480 is located 4992 bytes to the left of 1407360-byte region [0x7f8346298800,0x7f83463f0180)
allocated by thread T0 here:
    #0 0x7f8489cf7808 in __interceptor_malloc ../../../../src/libsanitizer/asan/asan_malloc_linux.cc:144
    #1 0x5632a978da3f in hwy::AllocateAlignedBytes(unsigned long, void* (*)(void*, unsigned long), void*) /path/to/gemma.cpp/build/_deps/highway-src/hwy/aligned_allocator.cc:91

Thread T5 (worker004) created by T0 here:
    #0 0x7f8489c24815 in __interceptor_pthread_create ../../../../src/libsanitizer/asan/asan_interceptors.cc:208
    #1 0x7f84898ea0c9 in std::thread::_M_start_thread(std::unique_ptr<std::thread::_State, std::default_delete<std::thread::_State> >, void (*)()) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xd70c9)

SUMMARY: AddressSanitizer: heap-buffer-overflow (/path/to/gemma.cpp/build/gemma+0xcd245) in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const
Shadow bytes around the buggy address:
  0x0ff0e8c4ae40: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4ae50: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4ae60: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4ae70: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4ae80: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
=>0x0ff0e8c4ae90:[fa]fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4aea0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4aeb0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4aec0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4aed0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff0e8c4aee0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07 
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
  Shadow gap:              cc
==24394==ABORTING
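ASan reports the faulting address relative to the nearest heap region. The small helper below (a hypothetical sketch, not part of any tool) redoes that arithmetic from the printed hex values. For the report above, 0x7f8346298800 - 0x7f8346297480 = 4992 bytes to the left of a region of 0x7f83463f0180 - 0x7f8346298800 = 1407360 bytes.

```cpp
#include <cassert>
#include <cstdint>

// Where an address lies relative to a half-open heap region [begin, end).
enum class Side { kLeft, kInside, kRight };

struct Location {
  Side side;
  // Distance to `begin` (left), offset from `begin` (inside),
  // or distance past `end` (right), in bytes.
  uint64_t bytes;
};

Location Locate(uint64_t addr, uint64_t begin, uint64_t end) {
  if (addr < begin) return {Side::kLeft, begin - addr};
  if (addr >= end) return {Side::kRight, addr - end};
  return {Side::kInside, addr - begin};
}
```

The later -g1 run can be checked the same way: its faulting address is 2048 bytes (512 floats, given the 4-byte WRITE) past the end of a 603980928-byte region.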

zeerd changed the title from "May need a callback to notify caller something wrong" to "Crash in Attention() with huge input" on Mar 15, 2024
@jan-wassenberg
Member

Thanks, running with ASan is helpful for finding issues.

Can you help narrow down exactly where this is happening, e.g. by building with -gmlt to get line numbers, or by adding a printf before each of the MatVecLoop / TwoOfsMatVecLoop / Rope / final MatVecLoop calls to see which one is the problem?
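Because Attention runs its per-head work on several pool threads, plain `std::cout << __LINE__ << std::endl` output from different workers is not guaranteed to stay on separate lines. A hypothetical helper like the one below (not part of gemma.cpp or Highway) prints one complete, thread-tagged line per trace point:

```cpp
#include <cassert>
#include <cstdio>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>

// Hypothetical debugging helper; formats "file:line" for one trace point.
inline std::string FormatTrace(const char* file, int line) {
  return std::string(file) + ":" + std::to_string(line);
}

// Prints one complete line per call, tagged with the calling thread's id.
// The mutex keeps lines from different pool workers from interleaving.
inline void TraceLine(const char* file, int line) {
  static std::mutex mu;
  std::ostringstream tid;
  tid << std::this_thread::get_id();
  const std::string msg =
      FormatTrace(file, line) + " [thread " + tid.str() + "]";
  std::lock_guard<std::mutex> lock(mu);
  std::fprintf(stderr, "%s\n", msg.c_str());
  std::fflush(stderr);  // make sure the line is out before a possible crash
}

#define TRACE_HERE() TraceLine(__FILE__, __LINE__)
```

Placing `TRACE_HERE();` before each call inside the `pool.Run` lambda shows, per thread, the last point reached before the crash.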

@zeerd
Contributor Author

zeerd commented Mar 17, 2024

330 is the line just before MatVecLoop.
338 is the line just before TwoOfsMatVecLoop.

Gemma uses 6 CPU cores.
So does that mean it crashed in MatVecLoop?

BTW: I am also interested in how to use '-gmlt'.

................330
330
330
330
330
330
338
=================================================================
338
338
338
338
338
==10292==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x7f4642197880 at pc 0x55e1048da31a bp 0x7ffebf17b1b0 sp 0x7ffebf17b1a0
WRITE of size 4 at 0x7f4642197880 thread T0
    #0 0x55e1048da319 in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const [clone .isra.0] (/home/howecn/codes/gemma.qt/gemma.cpp/build/gemma+0x11e319)
diff --git a/gemma.cc b/gemma.cc
index f2a2275..012a3f3 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -327,6 +327,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,

     const size_t batch_offset = batch_idx * kModelDim;

+std::cout << __LINE__ << std::endl;
     MatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, q_offset,
         activations.pre_att_rms_out.data() + batch_offset, q);
@@ -334,6 +335,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     const size_t kv_offset =
         pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;

+std::cout << __LINE__ << std::endl;
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, k_offset, v_offset,
         activations.pre_att_rms_out.data() + batch_offset,
@@ -345,8 +347,11 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                                    head * TConfig::kSeqLen +
                                    batch_idx * kHeads * kQKVDim;

@jan-wassenberg
Member

Thanks for sharing. Hm, the crash seems to happen later than I thought. We know it comes after TwoOfsMatVecLoop (line 338) but, based on the stack trace, before the end of Attention. Unfortunately, that range includes several functions that could be the culprit.

Maybe add more cout statements? Or try -gmlt: put CXXFLAGS=-gmlt before your CMake invocation, and remember to delete the entire build directory first. This passes the flag to the C++ compiler so that it adds a bit more debug information.

@zeerd
Contributor Author

zeerd commented Mar 18, 2024

My mistake, I missed the 388 before the line.

In fact, I had added a cout before each function (but did not post the full patch).
The next line number would be the one just before the first Rope call.

So maybe it crashed in TwoOfsMatVecLoop.

BTW: it seems my g++ version is too old to support '-gmlt'. I need to find out whether I can upgrade it.

@jan-wassenberg
Member

Got it, thanks! We know it's a write, and Rope always reads before it writes, so an out-of-bounds access inside Rope would have been reported as a read; the crash therefore cannot be inside Rope. Agree it's likely TwoOfsMatVecLoop. Is there an easy, reliable way for us to repro this?
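The argument about Rope can be illustrated with a simplified in-place rotation. This is a hypothetical sketch, not gemma.cpp's Rope: it rotates pairs (x[i], x[i + half]) and loads both inputs before storing either output. Any out-of-bounds element is therefore touched first by a load, which ASan would report as a READ rather than a WRITE.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Simplified, hypothetical RoPE-style rotation (not gemma.cpp's Rope).
// For each pair, both elements are read before either is written.
void RotatePairsInPlace(float* x, size_t dim, float angle) {
  const size_t half = dim / 2;
  const float c = std::cos(angle);
  const float s = std::sin(angle);
  for (size_t i = 0; i < half; ++i) {
    const float x0 = x[i];          // read
    const float x1 = x[i + half];   // read
    x[i] = x0 * c - x1 * s;         // write
    x[i + half] = x0 * s + x1 * c;  // write
  }
}
```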

> BTW: seems my g++ 's version is too low to have the '-gmlt'. I need to find if I could upgrade it.

Oh, I see from godbolt that this only got merged into the google/main branch of GCC. It's definitely supported in Clang. You might try -g1 on GCC instead?

@zeerd
Contributor Author

zeerd commented Mar 19, 2024

I did not modify anything on the main branch for this test (except adding the cout lines), so you can reproduce it with these commands:

$ alias gemma2b="/opt/gemma/gemma.cpp/build/gemma -- --tokenizer /opt/gemma/2b-it/tokenizer.spm --compressed_weights /opt/gemma/2b-it/2b-it.sbs --model 2b-it --verbosity 0"
$ cat gemma.cc | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
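The shell pipeline above flattens gemma.cc onto a single line and prepends the question, so the whole file becomes one very long prompt. For reproducing without a shell, the following hypothetical C++ sketch builds approximately the same string (it ignores echo's trailing newline and any escape handling by echo):

```cpp
#include <algorithm>
#include <cassert>
#include <fstream>
#include <sstream>
#include <string>

// Approximates `tr '\n' ' ' | xargs -0 echo "What does this C++ code do: "`:
// every newline becomes a space, and echo joins its two arguments with a
// single space.
std::string BuildPrompt(const std::string& source) {
  std::string flattened = source;
  std::replace(flattened.begin(), flattened.end(), '\n', ' ');
  return std::string("What does this C++ code do: ") + " " + flattened;
}

// Reads a whole file into a string; returns an empty string on failure.
std::string ReadFile(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  std::ostringstream ss;
  ss << in.rdbuf();
  return ss.str();
}
```

For example, `BuildPrompt(ReadFile("gemma.cc"))` yields the prompt text; its length makes clear why the prefill runs for so many positions.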

I will try -g1 later.

@zeerd
Contributor Author

zeerd commented Mar 19, 2024

With -g1, the report points to TwoOfsMatVecLoop:

#0 0x558d0ff27194 in void gcpp::N_AVX2::TwoOfsMatVecLoop<256ul, 2048ul, hwy::bfloat16_t, 12582912ul, float>(gcpp::CompressedArray<hwy::bfloat16_t, 12582912ul> const&, unsigned long, unsigned long, float const*, float*, float*) /opt/gemma/gemma.cpp/./ops.h:111
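The template arguments in this frame, `TwoOfsMatVecLoop<256ul, 2048ul, ...>`, correspond to `TwoOfsMatVecLoop<kQKVDim, kModelDim>` in the diff below. As a rough illustration only, here is a scalar reference of what a two-offset matrix-vector product plausibly computes, inferred from the signature (one matrix, two offsets, one input vector, two output pointers). The real implementation is SIMD, reads compressed bf16 weights, and may lay out rows differently; `TwoOfsMatVecRef` and the row-major layout are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar reference (not gemma.cpp code). Each output receives
// kOuter floats, so out0/out1 must each have room for kOuter elements; if
// either pointer is too close to the end of its buffer, these stores are the
// kind of WRITE that ASan reports.
template <size_t kOuter, size_t kInner>
void TwoOfsMatVecRef(const float* mat, size_t ofs0, size_t ofs1,
                     const float* vec, float* out0, float* out1) {
  for (size_t r = 0; r < kOuter; ++r) {
    float sum0 = 0.0f, sum1 = 0.0f;
    for (size_t c = 0; c < kInner; ++c) {
      sum0 += mat[ofs0 + r * kInner + c] * vec[c];
      sum1 += mat[ofs1 + r * kInner + c] * vec[c];
    }
    out0[r] = sum0;  // first output (e.g. key cache)
    out1[r] = sum1;  // second output (e.g. value cache)
  }
}
```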

317
339
339
339
339
339
339
345
345
345
345
345
345
339
345
339
345
317
339
339
339
339
339
339
345
345
345
345
345
345
339
345
339
345
................317
339
=================================================================
339
339
339
339
339
==10439==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x7fa0b9b19480 at pc 0x558d0ff27195 bp 0x7fa1f90fc010 sp 0x7fa1f90fc000
WRITE of size 4 at 0x7fa0b9b19480 thread T2 (worker001)
    #0 0x558d0ff27194 in void gcpp::N_AVX2::TwoOfsMatVecLoop<256ul, 2048ul, hwy::bfloat16_t, 12582912ul, float>(gcpp::CompressedArray<hwy::bfloat16_t, 12582912ul> const&, unsigned long, unsigned long, float const*, float*, float*) /opt/gemma/gemma.cpp/./ops.h:111
    #1 0x558d0ff27194 in gcpp::N_AVX2::Attention<gcpp::ConfigGemma2B, 16ul>(unsigned long, unsigned long, unsigned long, gcpp::Activations<gcpp::ConfigGemma2B, 16ul>&, gcpp::CompressedLayer<gcpp::ConfigGemma2B> const*, gcpp::KVCache&, hwy::ThreadPool&)::{lambda(unsigned long, unsigned long)#1}::operator()(unsigned long, unsigned long) const /opt/gemma/gemma.cpp/./gemma.cc:340
    #2 0x558d0fe80db8 in hwy::ParallelFor::WorkerRun(unsigned long, unsigned long, hwy::PoolMem&) /opt/gemma/gemma.cpp/build/_deps/highway-src/hwy/contrib/thread_pool/thread_pool.h:525
    #3 0x558d0fe8111c in hwy::ThreadPool::ThreadFunc(unsigned long, unsigned long, hwy::PoolMem*) /opt/gemma/gemma.cpp/build/_deps/highway-src/hwy/contrib/thread_pool/thread_pool.h:585
    #4 0x7fa1fd15b792  (/lib/x86_64-linux-gnu/libstdc++.so.6+0xe6792)
    #5 0x7fa1fd4bd608 in start_thread /build/glibc-wuryBv/glibc-2.31/nptl/pthread_create.c:477
    #6 0x7fa1fce2c352 in __clone (/lib/x86_64-linux-gnu/libc.so.6+0x11f352)

0x7fa0b9b19480 is located 2048 bytes to the right of 603980928-byte region [0x7fa095b18800,0x7fa0b9b18c80)
allocated by thread T0 here:
    #0 0x7fa1fd5e5808 in __interceptor_malloc ../../../../src/libsanitizer/asan/asan_malloc_linux.cc:144
    #1 0x558d100fed4f in hwy::AllocateAlignedBytes(unsigned long, void* (*)(void*, unsigned long), void*) /opt/gemma/gemma.cpp/build/_deps/highway-src/hwy/aligned_allocator.cc:91

Thread T2 (worker001) created by T0 here:
    #0 0x7fa1fd512815 in __interceptor_pthread_create ../../../../src/libsanitizer/asan/asan_interceptors.cc:208
    #1 0x7fa1fd15bb4b in std::thread::_M_start_thread(std::unique_ptr<std::thread::_State, std::default_delete<std::thread::_State> >, void (*)()) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xe6b4b)
    #2 0x7fa1fd606611 in __sanitizer::StackDepotBase<__sanitizer::StackDepotNode, 1, 20>::Put(__sanitizer::StackTrace, bool*) ../../../../src/libsanitizer/sanitizer_common/sanitizer_stackdepotbase.h:105

SUMMARY: AddressSanitizer: heap-buffer-overflow /opt/gemma/gemma.cpp/./ops.h:111 in void gcpp::N_AVX2::TwoOfsMatVecLoop<256ul, 2048ul, hwy::bfloat16_t, 12582912ul, float>(gcpp::CompressedArray<hwy::bfloat16_t, 12582912ul> const&, unsigned long, unsigned long, float const*, float*, float*)
Shadow bytes around the buggy address:
  0x0ff49735b240: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b250: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b260: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b270: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b280: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
=>0x0ff49735b290:[fa]fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b2a0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b2b0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b2c0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b2d0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x0ff49735b2e0: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07 
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
  Shadow gap:              cc
==10439==ABORTING
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1d2a7a0..e26e3fe 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,9 @@ endif()
 
 # Executable Target
 
+add_compile_options(-fsanitize=address -g1)
+add_link_options(-fsanitize=address)
+
 add_executable(gemma run.cc)
 target_sources(gemma PRIVATE ${SOURCES})
 set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
diff --git a/gemma.cc b/gemma.cc
index f2a2275..31378d3 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -314,6 +314,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
+std::cout << __LINE__ << std::endl;
+
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
     const size_t head_offset =
@@ -334,11 +336,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     const size_t kv_offset =
         pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
 
+std::cout << __LINE__ << std::endl;
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(
         c_layer->c_qkv_einsum_w, k_offset, v_offset,
         activations.pre_att_rms_out.data() + batch_offset,
         kv_cache.key_cache.get() + kv_offset,
         kv_cache.value_cache.get() + kv_offset);
+std::cout << __LINE__ << std::endl;
 
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
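The diff above shows that TwoOfsMatVecLoop writes to `kv_cache.key_cache.get() + kv_offset` and `kv_cache.value_cache.get() + kv_offset`, with `kv_offset = pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim`. Assuming each call writes kQKVDim floats per output (as its first template argument suggests), a position that grows past what the cache was sized for would push these writes beyond the allocation. The hypothetical check below only restates that arithmetic with made-up dimensions; it is a plausible reading of the symptoms, not the fix that was eventually committed.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical layout description; values are supplied by the caller,
// not taken from gemma.cpp's configs.
struct KVLayout {
  size_t cache_pos_size;    // floats per position (all layers)
  size_t cache_layer_size;  // floats per layer within one position
  size_t qkv_dim;           // floats written per head
  size_t cache_size;        // total floats allocated for one cache
};

// Mirrors the offset formula from the diff.
size_t KVOffset(const KVLayout& l, size_t pos, size_t layer, size_t head) {
  return pos * l.cache_pos_size + layer * l.cache_layer_size + head * l.qkv_dim;
}

// True if writing qkv_dim floats starting at the computed offset stays
// inside the allocation.
bool KVWriteInBounds(const KVLayout& l, size_t pos, size_t layer, size_t head) {
  const size_t offset = KVOffset(l, pos, layer, head);
  return offset <= l.cache_size && l.qkv_dim <= l.cache_size - offset;
}
```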

@jan-wassenberg
Member

Thanks, with MSan I am also able to reproduce a problem; looking into it.

copybara-service bot pushed a commit that referenced this issue Mar 27, 2024
Also remove init placeholder and move Sqrt to ops.h.

PiperOrigin-RevId: 619504447
copybara-service bot pushed a commit that referenced this issue Mar 27, 2024
Also remove init placeholder and move Sqrt to ops.h.

PiperOrigin-RevId: 619529202
@jan-wassenberg
Member

I believe this is now fixed in dev :) FYI: to run it with the current Kaggle 2b-it weights, please change kVocabSize to 256128 and kKVHeads to 8 in configs.h.
We are working on updating the weights.

tilakrayal added the type:bug (Something isn't working) label on Apr 24, 2024