Skip to content

Commit

Permalink
feat(core): add state management, closes #1655
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasfernog committed Apr 30, 2021
1 parent 894643c commit 94624ee
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 60 deletions.
5 changes: 5 additions & 0 deletions .changes/app-state.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"tauri": patch
---

Adds `manage` API to the `Builder` struct, which manages app state.
5 changes: 5 additions & 0 deletions .changes/command-state.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"tauri-macros": patch
---

Adds support to command state, triggered when a command argument is `arg: State<'_, StateType>`.
112 changes: 73 additions & 39 deletions core/tauri-macros/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,29 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse::Parser, punctuated::Punctuated, FnArg, Ident, ItemFn, Meta, NestedMeta, Pat, Path,
ReturnType, Token, Type,
parse::Parser, punctuated::Punctuated, FnArg, GenericArgument, Ident, ItemFn, Meta, NestedMeta,
Pat, Path, PathArguments, ReturnType, Token, Type, Visibility,
};

fn fn_wrapper(function: &ItemFn) -> (&Visibility, Ident) {
(
&function.vis,
format_ident!("{}_wrapper", function.sig.ident),
)
}

fn err(function: ItemFn, error_message: &str) -> TokenStream {
let (vis, wrap) = fn_wrapper(&function);
quote! {
#function

#vis fn #wrap<P: ::tauri::Params>(_message: ::tauri::InvokeMessage<P>) {
compile_error!(#error_message);
unimplemented!()
}
}
}

pub fn generate_command(attrs: Vec<NestedMeta>, function: ItemFn) -> TokenStream {
// Check if "with_window" attr was passed to macro
let with_window = attrs.iter().any(|a| {
Expand Down Expand Up @@ -40,40 +59,55 @@ pub fn generate_command(attrs: Vec<NestedMeta>, function: ItemFn) -> TokenStream
ReturnType::Default => false,
};

// Split function args into names and types
let (mut names, mut types): (Vec<Ident>, Vec<Path>) = function
.sig
.inputs
.iter()
.map(|param| {
let mut arg_name = None;
let mut arg_type = None;
if let FnArg::Typed(arg) = param {
if let Pat::Ident(ident) = arg.pat.as_ref() {
arg_name = Some(ident.ident.clone());
}
if let Type::Path(path) = arg.ty.as_ref() {
arg_type = Some(path.path.clone());
let mut invoke_arg_names: Vec<Ident> = Default::default();
let mut invoke_arg_types: Vec<Path> = Default::default();
let mut call_arguments = Vec::new();

for (i, param) in function.sig.inputs.clone().into_iter().enumerate() {
let mut arg_name = None;
let mut arg_type = None;
if let FnArg::Typed(arg) = param {
if let Pat::Ident(ident) = arg.pat.as_ref() {
arg_name = Some(ident.ident.clone());
}
if let Type::Path(path) = arg.ty.as_ref() {
arg_type = Some(path.path.clone());
}
}

if i == 0 && with_window {
call_arguments.push(quote!(_window));
continue;
}

let arg_name_ = arg_name.clone().unwrap();
let arg_type =
arg_type.unwrap_or_else(|| panic!("Invalid type for arg \"{}\"", arg_name.unwrap()));

let mut path_as_string = String::new();
for segment in &arg_type.segments {
path_as_string.push_str(&segment.ident.to_string());
path_as_string.push_str("::");
}

if ["State::", "tauri::State::"].contains(&path_as_string.as_str()) {
let last_segment = match arg_type.segments.last() {
Some(last) => last,
None => return err(function, "found a type path without any segments (how?)"),
};
if let PathArguments::AngleBracketed(angle) = &last_segment.arguments {
if let Some(GenericArgument::Type(ty)) = angle.args.last() {
call_arguments.push(quote!(state_manager.get::<#ty>()));
continue;
}
}
(
arg_name.clone().unwrap(),
arg_type.unwrap_or_else(|| panic!("Invalid type for arg \"{}\"", arg_name.unwrap())),
)
})
.unzip();

let window_arg_maybe = match types.first() {
Some(_) if with_window => {
// Remove window arg from list so it isn't expected as arg from JS
types.drain(0..1);
names.drain(0..1);
// Tell wrapper to pass `window` to original function
quote!(_window,)
}
// Tell wrapper not to pass `window` to original function
_ => quote!(),
};

invoke_arg_names.push(arg_name_.clone());
invoke_arg_types.push(arg_type);
call_arguments.push(quote!(parsed_args.#arg_name_));
}

let await_maybe = if function.sig.asyncness.is_some() {
quote!(.await)
} else {
Expand All @@ -86,22 +120,22 @@ pub fn generate_command(attrs: Vec<NestedMeta>, function: ItemFn) -> TokenStream
// note that all types must implement `serde::Serialize`.
let return_value = if returns_result {
quote! {
match #fn_name(#window_arg_maybe #(parsed_args.#names),*)#await_maybe {
match #fn_name(#(#call_arguments),*)#await_maybe {
Ok(value) => ::core::result::Result::Ok(value),
Err(e) => ::core::result::Result::Err(e),
}
}
} else {
quote! { ::core::result::Result::<_, ()>::Ok(#fn_name(#window_arg_maybe #(parsed_args.#names),*)#await_maybe) }
quote! { ::core::result::Result::<_, ()>::Ok(#fn_name(#(#call_arguments),*)#await_maybe) }
};

quote! {
#function
pub fn #fn_wrapper<P: ::tauri::Params>(message: ::tauri::InvokeMessage<P>) {
pub fn #fn_wrapper<P: ::tauri::Params>(message: ::tauri::InvokeMessage<P>, state_manager: ::std::sync::Arc<::tauri::StateManager>) {
#[derive(::serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ParsedArgs {
#(#names: #types),*
#(#invoke_arg_names: #invoke_arg_types),*
}
let _window = message.window();
match ::serde_json::from_value::<ParsedArgs>(message.payload()) {
Expand Down Expand Up @@ -134,10 +168,10 @@ pub fn generate_handler(item: proc_macro::TokenStream) -> TokenStream {
});

quote! {
move |message| {
move |message, state_manager| {
let cmd = message.command().to_string();
match cmd.as_str() {
#(stringify!(#fn_names) => #fn_wrappers(message),)*
#(stringify!(#fn_names) => #fn_wrappers(message, state_manager),)*
_ => {
message.reject(format!("command {} not found", cmd))
},
Expand Down
1 change: 1 addition & 0 deletions core/tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ shared_child = "0.3"
os_pipe = "0.9"
minisign-verify = "0.1.8"
image = "0.23"
state = "0.4"

[build-dependencies]
cfg_aliases = "0.1.1"
Expand Down
6 changes: 3 additions & 3 deletions core/tauri/src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
use crate::{
api::rpc::{format_callback, format_callback_result},
runtime::app::App,
Params, Window,
Params, StateManager, Window,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::{future::Future, sync::Arc};

/// A closure that is run when the Tauri application is setting up.
pub type SetupHook<M> = Box<dyn Fn(&mut App<M>) -> Result<(), Box<dyn std::error::Error>> + Send>;

/// A closure that is run everytime Tauri receives a message it doesn't explicitly handle.
pub type InvokeHandler<M> = dyn Fn(InvokeMessage<M>) + Send + Sync + 'static;
pub type InvokeHandler<M> = dyn Fn(InvokeMessage<M>, Arc<StateManager>) + Send + Sync + 'static;

/// A closure that is run once every time a window is created and loaded.
pub type OnPageLoad<M> = dyn Fn(Window<M>, PageLoadPayload) + Send + Sync + 'static;
Expand Down
66 changes: 57 additions & 9 deletions core/tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

/// The Tauri error enum.
pub use error::Error;
use state::Container;
pub use tauri_macros::{command, generate_handler};

/// Core API.
Expand All @@ -36,15 +37,17 @@ pub type Result<T> = std::result::Result<T, Error>;
/// A task to run on the main thread.
pub type SyncTask = Box<dyn FnOnce() + Send>;

use crate::api::assets::Assets;
use crate::api::config::Config;
use crate::event::{Event, EventHandler};
use crate::runtime::tag::{Tag, TagRef};
use crate::runtime::window::PendingWindow;
use crate::runtime::{Dispatch, Runtime};
use crate::{
api::{assets::Assets, config::Config},
event::{Event, EventHandler},
runtime::{
tag::{Tag, TagRef},
window::PendingWindow,
Dispatch, Runtime,
},
};
use serde::Serialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::{borrow::Borrow, collections::HashMap, path::PathBuf};

// Export types likely to be used by the application.
pub use {
Expand All @@ -56,7 +59,52 @@ pub use {
runtime::window::export::Window,
};

use std::borrow::Borrow;
/// A guard for a state value.
pub struct State<'r, T>(&'r T);

impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// Retrieve a borrow to the underlying value with a lifetime of `'r`.
/// Using this method is typically unnecessary as `State` implements
/// [`Deref`] with a [`Deref::Target`] of `T`.
#[inline(always)]
pub fn inner(&self) -> &'r T {
self.0
}
}

impl<T: Send + Sync + 'static> std::ops::Deref for State<'_, T> {
type Target = T;

#[inline(always)]
fn deref(&self) -> &T {
self.0
}
}

impl<T: Send + Sync + 'static> Clone for State<'_, T> {
fn clone(&self) -> Self {
State(self.0)
}
}

/// The Tauri state manager.
pub struct StateManager(pub(crate) Container);

impl StateManager {
pub(crate) fn new() -> Self {
Self(Container::new())
}

pub(crate) fn set<T: Send + Sync + 'static>(&self, state: T) -> bool {
self.0.set(state)
}

/// Gets the state associated with the specified type.
pub fn get<T: Send + Sync + 'static>(&self) -> State<'_, T> {
State(self.0.get())
}
}

/// Reads the config file at compile time and generates a [`Context`] based on its content.
///
/// The default config file path is a `tauri.conf.json` file inside the Cargo manifest directory of
Expand Down

0 comments on commit 94624ee

Please sign in to comment.