possibility to update state via thread and have dependencies executed #1317

Open
alefminus opened this issue May 4, 2024 · 7 comments

@alefminus

Description

I want to create state (mo.state) and update it via a thread and have it work as usual (i.e. as it works from within a cell), i.e.

import marimo as mo


app = mo.App()


@app.cell
def __():
    getter, setter = mo.state(0)
    return getter, setter

@app.cell
def __():
  from threading import Thread
  from time import sleep
  def update():
    while True:
      sleep(1)
      setter(getter() + 1)

  thread = Thread(target=update, name='update_at_1_hz')
  thread.start()

@app.cell
def __():
  # This should update at 1 Hz. It does not.
  getter()

Suggested solution

I'm not sure this is even a wanted feature. It is only loosely related to the open bug about parallelism. I want this so that I can run long SQL server queries in a background thread, change the running query from the notebook, pass partial results back to the marimo notebook via some IPC (e.g. queue.Queue), and use mo.state as a signal for a cell (e.g. a plot or table) to rerun.

Alternative

No response

Additional context

No response

@dmadisetti
Contributor

You can do this with mo.ui.refresh: https://marimo.app/l/z7cb7q

It's not multi-threaded, but will let you poll

@alefminus
Author

Right, that would work. But then I have to forgo having the logic in sequence. An async cell running concurrently would be nicer.

@dmadisetti
Contributor

dmadisetti commented May 4, 2024

Not sure I understand, if you have:

graph TD
  root --> A[expensive procedure]
  A --> output
  root --> refresh
  refresh --> output

"Expensive procedure" isn't rerun when "refresh" changes, only when "root" changes. It is already async in that sense.
But to bring it to your question, is "Expensive procedure" just a SQL server connection in the foreground? What if you daemonized it, and used refresh as a poll?

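# Polling cell: referencing `refresh` makes this cell rerun on every tick
if daemon.has_updates():
    set_changed(True)
refresh

# Downstream cell: reruns only when the state changes
get_changed()
daemon.run_query()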

I think event-listeners might be a better async pattern, but maybe that would work? https://marimo.app/l/eh7a67
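
To make the point about the graph above concrete, here is a minimal sketch that computes which nodes sit downstream of a changed node. It is a toy traversal, not marimo's actual scheduler; the `edges` dict and the `cells_to_rerun` helper are names made up for illustration.

```python
from collections import deque

# Edges copied from the graph above: X -> Y means Y depends on X.
edges = {
    "root": ["expensive procedure", "refresh"],
    "expensive procedure": ["output"],
    "refresh": ["output"],
    "output": [],
}


def cells_to_rerun(changed):
    """Return every node reachable from `changed` (everything downstream of it)."""
    seen = set()
    pending = deque(edges[changed])
    while pending:
        node = pending.popleft()
        if node not in seen:
            seen.add(node)
            pending.extend(edges[node])
    return seen


# A refresh tick only reaches "output"; the expensive node is untouched.
rerun_on_refresh = cells_to_rerun("refresh")
# A change to "root" reaches everything below it, including the expensive node.
rerun_on_root = cells_to_rerun("root")
```

Under these assumptions, a refresh tick never reaches "expensive procedure", which is why polling with a refresh element does not repeat the expensive work.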

@alefminus
Author

alefminus commented May 5, 2024

This is what I wanted to achieve. It works, although it is clunky.

Screencast:
Screencast from 2024-05-05 10-25-40.webm

Code:

import marimo

__generated_with = "0.4.11"
app = marimo.App()


@app.cell
def imports():
    import marimo as mo
    import sqlalchemy
    import polars as pl
    import pandas as pd
    import threading
    from time import sleep
    from sqlalchemy.exc import ProgrammingError
    from threading import Thread
    from queue import Queue, Empty
    return (
        Empty,
        ProgrammingError,
        Queue,
        Thread,
        mo,
        pd,
        pl,
        sleep,
        sqlalchemy,
        threading,
    )


@app.cell
def connect_to_db(sqlalchemy):
    con = sqlalchemy.create_engine('postgresql:///backup')
    return con,


@app.cell
def __():
    history = []
    return history,


@app.cell
def __(sql):
    sql
    return


@app.cell
def __(pd, result):
    pd.concat([df.to_pandas() for df in result()]) if result() is not None else None
    return


@app.cell
def __(ProgrammingError, Queue, Thread, con, mo, pl, sleep, threading):
    sql = mo.ui.text_area()
    sql_result_queue = Queue()

    def sql_main(batch_size=5):
        """
        Read from sql (a mo.ui.text_area); whenever its value changes, start executing
        the new query, reading results in batches of batch_size rows every 0.1 seconds.
        Write each batch to sql_result_queue.
        """
        prev_value = None
        i = 0
        it = None
        while True:
            if sql.value != prev_value:
                # change the iterator; signify to refresh consumer via None - only do it if
                # we actually produced a result before (use "i" to attest to that)
                if i > 0:
                    sql_result_queue.put((prev_value, -1, None))
                try:
                    it = pl.read_database(sql.value, con, iter_batches=True, batch_size=batch_size)
                except ProgrammingError:
                    it = None
                prev_value = sql.value
                i = 0
            if it is not None:
                try:
                    df = next(it)
                    sql_result_queue.put((prev_value, i, df))
                    i += 1
                except StopIteration:
                    it = None
            sleep(0.1)

    SQL_THREAD_NAME = 'sql'
    _threads = threading.enumerate()
    _sql_threads = [x for x in _threads if x.name == SQL_THREAD_NAME ]
    if len(_sql_threads) > 0:
        sql_thread = _sql_threads[0]
    else:
        sql_thread = Thread(name=SQL_THREAD_NAME, target=sql_main)
        sql_thread.start()
    sql_thread
    return SQL_THREAD_NAME, sql, sql_main, sql_result_queue, sql_thread


@app.cell
def __(history):
    history
    return


@app.cell
def cell_refresh(history, pl, refresh, result, result_set, sql_result_queue):
    refresh

    if not sql_result_queue.empty():
        _query, _count, _df = sql_result_queue.get()
        if _df is None:
            # clear result
            result_set(None)
        else:
            assert isinstance(_df, pl.DataFrame)
            _last = result()
            if _last is None:
                _last = [_df]
            else:
                # check if columns changed - if so reset
                _last_df = _last[-1]
                assert isinstance(_last_df, pl.DataFrame)
                if _last_df.columns != _df.columns:
                    _last = [_df]
                else:
                    _last.append(_df)
            history.append(_last.copy())
            result_set(_last)
    #_last
    return


@app.cell
def __(mo):
    result, result_set = mo.state(None)
    return result, result_set


@app.cell
def __(mo):
    refresh = mo.ui.refresh(default_interval='0.1s')
    refresh
    return refresh,


@app.cell
def __(mo):
    mo.md('''
    A change to the SQL query causes the SQL iterator to reset.
    The SQL iterator is read on each refresh.
    Since read_database does not stop, we run it in a thread instead.
    ''')
    return


if __name__ == "__main__":
    app.run()

Additionally, I saw two bugs. Just letting you know; I'll try to report them as separate issues once I have reproductions:

  1. The refresh cell (the cell called cell_refresh) would generate a cannot find cell "cell_refresh_last" error.
  2. I got the error "Invalid session id" a number of times, the last one when quitting the marimo edit server via "Ctrl-C":
❯ ./marimo.sh -p 3010

        Create or edit notebooks in your browser 📝

        URL: http://0.0.0.0:3010

        Are you sure you want to quit? (y/n): y
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/uvicorn/protocols/http/h11_impl.py", line 407, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py", line 93, in __call__
    await self.simple_response(scope, receive, send, request_headers=headers)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py", line 148, in simple_response
    await self.app(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py", line 49, in __call__
    await self.app(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/middleware.py", line 68, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 485, in handle
    await self.app(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 72, in app
    response = await func(request)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/router.py", line 53, in wrapper_func
    response = await func(request=request)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/endpoints/execution.py", line 44, in set_ui_element_values
    app_state.require_current_session().put_control_request(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/deps.py", line 71, in require_current_session
    raise ValueError(f"Invalid session id: {session_id}")
ValueError: Invalid session id: s_qhku1x

        Thanks for using marimo! 🌊🍃

(marimo.sh just runs marimo run with a few canned switches)

@alefminus
Author

@dmadisetti yes, you got it right. I reached the same solution you suggested. The thread is only there because the API I'm wrapping (polars.read_database) has no poll option.

@akshayka
Contributor

akshayka commented May 6, 2024

I want to create state (mo.state) and update it via a thread and have it work as usual (i.e. as it works from within a cell), i.e.

This doesn't work today because of an implementation detail: state setters need to reach into global state, but that global state is a Python thread-local object. In run mode we don't want different sessions (each of which runs in the same process, but on its own thread) to share kernels.
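
For readers unfamiliar with thread-locals, here is a small standard-library sketch of the behavior described above. It does not use marimo's internals; `_context` and its `kernel` attribute are hypothetical stand-ins for the per-thread global state.

```python
import threading

# Hypothetical stand-in for a per-thread "global" context object.
_context = threading.local()
_context.kernel = "kernel-for-this-session"

seen_in_child = []


def child():
    # A new thread starts with an empty thread-local namespace,
    # so the attribute set by the parent thread is not visible here.
    seen_in_child.append(getattr(_context, "kernel", None))


t = threading.Thread(target=child)
t.start()
t.join()

# The parent thread still sees its own value.
seen_in_parent = getattr(_context, "kernel", None)
```

The child thread sees nothing where the parent sees its kernel, which is the situation a setter called from a user-spawned thread would be in.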

So we'd need a way for user spawned threads to inherit the global state of their parent thread.

One way we could do this is to have an API that subclasses the Thread class but passes in the parent's global state, and expose this as mo.Thread. Seems a little complex though. Open to other suggestions.
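
A rough sketch of what such a subclass could look like, using only the standard library. This is an assumption-laden illustration, not marimo's code: `InheritingThread`, `_context` and `kernel` are hypothetical names, and the sketch covers only the context-passing part of the idea.

```python
import threading

# Hypothetical stand-in for the per-thread global state.
_context = threading.local()


class InheritingThread(threading.Thread):
    """Thread that copies the parent's thread-local `kernel` into the child."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # __init__ runs in the parent thread, so this reads the parent's value.
        self._parent_kernel = getattr(_context, "kernel", None)

    def run(self):
        # run() executes in the new thread: install the captured value
        # before the target function is called.
        _context.kernel = self._parent_kernel
        super().run()


_context.kernel = "kernel-for-this-session"
result = []

t = InheritingThread(target=lambda: result.append(getattr(_context, "kernel", None)))
t.start()
t.join()
```

Capturing in `__init__` and installing in `run()` matters because the two methods execute on different threads.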

@akshayka
Contributor

akshayka commented May 6, 2024

One way we could do this is to have an API that subclasses the Thread class but passes in the parent's global state, and expose this as mo.Thread

Tried this and realized it's not that simple. We don't have a way for background threads to trigger execution of cells, so we'd need to add that as well.
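
One common shape for the missing piece is a thread-safe request queue that background threads write to and the kernel thread drains. The sketch below is a toy model under that assumption, not a description of how marimo works or plans to work; `MiniKernel`, `request_run` and `process_pending` are invented names.

```python
import queue
import threading


class MiniKernel:
    """Toy stand-in for a notebook kernel that owns cell execution."""

    def __init__(self):
        self.requests = queue.Queue()
        self.executed = []  # (cell name, name of the thread that ran it)

    def request_run(self, cell_name):
        # Safe to call from any thread: queue.Queue is thread-safe.
        self.requests.put(cell_name)

    def process_pending(self):
        # Meant to be called on the kernel's own thread; "runs" each
        # requested cell by recording which thread handled it.
        while True:
            try:
                cell = self.requests.get_nowait()
            except queue.Empty:
                return
            self.executed.append((cell, threading.current_thread().name))


kernel = MiniKernel()

# The background thread only asks for a rerun; it does not run the cell itself.
worker = threading.Thread(target=kernel.request_run, args=("plot",), name="worker")
worker.start()
worker.join()

# The kernel thread picks up the request and executes it.
kernel.process_pending()
```

The design choice here is that the background thread never executes cells directly, so cell execution stays on a single thread.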
