Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sdxl-turbo #82

Draft
wants to merge 1 commit into
base: alpha
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
99 changes: 85 additions & 14 deletions resources/python/live-painting/main.py
Expand Up @@ -29,6 +29,8 @@ def suppress_print(debug=False):
with suppress_print():
from diffusers import (
StableDiffusionImg2ImgPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionPipeline,
AutoencoderTiny,
EulerAncestralDiscreteScheduler,
)
Expand All @@ -44,13 +46,46 @@ def suppress_print(debug=False):
from diffusers.utils import load_image
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
import argparse
from collections import OrderedDict


# Torch optimizations
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Maps the CLI `--model_type` value to the diffusers img2img pipeline class that
# loads it. A plain dict literal is used instead of OrderedDict: insertion order
# is guaranteed for dicts on Python 3.7+, so OrderedDict adds nothing here.
IMG2IMG_PIPELINES_MAPPING = {
    "stable-diffusion": StableDiffusionImg2ImgPipeline,
    "stable-diffusion-xl": StableDiffusionXLImg2ImgPipeline,
}

def get_pipeline(model_path, model_type):
    """Build an (unconfigured) img2img diffusion pipeline for the given model.

    Args:
        model_path: Either a local path to a ``.safetensors`` checkpoint file
            or a Hugging Face repo identifier / model directory.
        model_type: Key into ``IMG2IMG_PIPELINES_MAPPING`` selecting the
            pipeline class (e.g. ``"stable-diffusion"`` or
            ``"stable-diffusion-xl"``).

    Returns:
        The instantiated pipeline, loaded in fp16. The caller is expected to
        finish configuration (scheduler, VAE, device placement).

    Raises:
        ValueError: If ``model_type`` is not a supported pipeline key.
    """
    # Determine the pipeline class based on the model_type
    print(f">>>>>>>>> Using model {model_path} as type {model_type}.")

    PipelineClass = IMG2IMG_PIPELINES_MAPPING.get(model_type)
    if not PipelineClass:
        # Include the supported keys so a typo in --model_type is self-explanatory.
        raise ValueError(
            f"Unsupported model type: {model_type}. "
            f"Supported types: {', '.join(IMG2IMG_PIPELINES_MAPPING)}"
        )

    # Determine the loading method based on the model_path: a single
    # .safetensors file needs from_single_file(), everything else is treated
    # as a pretrained model directory / hub id.
    if model_path.endswith(".safetensors"):
        print("------ received safetensors")
        pipeline = PipelineClass.from_single_file(
            model_path,
            torch_dtype=torch.float16,
        )
    else:
        print("------ received pretrained")
        # NOTE(review): the safety checker is disabled only on this branch;
        # from_single_file() above keeps whatever the checkpoint ships with —
        # confirm this asymmetry is intended.
        pipeline = PipelineClass.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            variant="fp16",
            safety_checker=None,
            requires_safety_checker=False,
        )

    return pipeline


def parse_args():
parser = argparse.ArgumentParser(description="Control model and paths via CLI.")
Expand All @@ -60,6 +95,12 @@ def parse_args():
help="Path to the model directory or identifier.",
required=True,
)
parser.add_argument(
"--model_type",
type=str,
help="Type of model to use for diffusion.",
required=True,
)
parser.add_argument(
"--vae_path",
type=str,
Expand Down Expand Up @@ -154,14 +195,9 @@ def calculate_min_inference_steps(steps, strength):
return math.ceil(steps / strength)


def prepare_pipeline(model_path, vae_path, disable_stablefast=False):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
variant="fp16",
safety_checker=None,
requires_safety_checker=False,
)
def prepare_pipeline(model_path, model_type, vae_path, disable_stablefast=False):
print(f"+++++++ Using model {model_path} as type {model_type}.")
pipe = get_pipeline(model_path, model_type)

pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.safety_checker = None
Expand Down Expand Up @@ -236,10 +272,21 @@ def main(pipe, input_image_path, output_image_path, shutdown_event):
# Initial/default values for parameters
prompt = "a captain with white beard, teal hat and uniform"
seed = 1
strength = 0.95
steps = 3
guidance_scale = 1.5

strength = 1
steps = 1
guidance_scale = 0.0
guidance_scale = 0.0
# Attempted model switch
# model_path = None
# model_type = None
# vae_path = None
# Attempted model switch
# last_model_type = None
# last_model_path = None
# last_vae_path = None
last_strength = None
last_steps = None
last_guidance_scale = None
last_prompt = None
last_seed = None
last_input_image = None
Expand All @@ -259,6 +306,9 @@ def main(pipe, input_image_path, output_image_path, shutdown_event):
while not params_queue.empty():
parameters = params_queue.get_nowait()
prompt = parameters.get("prompt", prompt)
# model_path = parameters.get("model_path", model_path)
# model_type = parameters.get("model_type", model_type)
# vae_path = parameters.get("vae_path", vae_path)
seed = parameters.get("seed", seed)
strength = parameters.get("strength", strength)
guidance_scale = parameters.get("guidance_scale", guidance_scale)
Expand All @@ -270,14 +320,35 @@ def main(pipe, input_image_path, output_image_path, shutdown_event):
# Get the current modified time of the input image
current_input_image = os.path.getmtime(input_image_path)

# Attempted model switch
# reload_pipeline = (
# model_type != last_model_type
# or model_path != last_model_path
# or vae_path != last_vae_path
# )

# if reload_pipeline:
# last_model_type = model_type
# last_model_path = model_path
# last_vae_path = vae_path
# pipe = prepare_pipeline(model_path, model_type, vae_path, True)
# warmup(pipe, input_image_path)


# Determine if image generation should be triggered
trigger_generation = (
prompt != last_prompt
strength != last_strength
or steps != last_steps
or guidance_scale != last_guidance_scale
or prompt != last_prompt
or seed != last_seed
or current_input_image != last_input_image
)

if trigger_generation:
last_guidance_scale = guidance_scale
last_steps = steps
last_strength = strength
last_prompt = prompt
last_seed = seed
last_input_image = current_input_image
Expand Down Expand Up @@ -333,7 +404,7 @@ def main(pipe, input_image_path, output_image_path, shutdown_event):

with suppress_print(args.debug):
pipe = prepare_pipeline(
args.model_path, args.vae_path, args.disable_stablefast
args.model_path, args.model_type, args.vae_path, args.disable_stablefast
)
warmup(pipe, args.input_image_path)

Expand Down
132 changes: 124 additions & 8 deletions src/client/apps/live-painting/index.tsx
Expand Up @@ -49,9 +49,21 @@ export function LivePainting() {
const [brushSizeOpen, setBrushSizeOpen] = useState(false);
const [prompt, setPrompt] = useState("");
const [illustrationStyle, setIllustrationStyle] = useState<IllustrationStyles>("childrensBook");
const [seed, setSeed] = useState(randomSeed());
const [settings, setSettings] = useState({
seed: randomSeed(),
strength: 0.95,
steps: 3,
guidance_scale: 1,
});
const [modelSettings, setModelSettings] = useState({
id: "sd-turbo",
model_type: "stable-diffusion",
model_path: "stabilityai/sd-turbo",
vae_path: "madebyollin/taesd",
});
const [image] = useAtom(imageAtom);
const [running, setRunning] = useState(false);
const [shouldRestart, setShouldRestart] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [saved, setSaved] = useResettableState(false, 3000);

Expand Down Expand Up @@ -144,11 +156,22 @@ export function LivePainting() {
action: "livePainting:settings",
payload: {
prompt: [prompt, illustrationStyles[illustrationStyle]].join(", "),
seed,
...settings,
// Attempted model switch
// model_type: modelSettings.model_type,
// model_path: modelSettings.model_path,
// vae_path: modelSettings.vae_path,
},
});
}
}, [send, prompt, seed, running, illustrationStyle]);
}, [send, prompt, settings, running, illustrationStyle]);

useEffect(() => {
if (shouldRestart) {
setIsLoading(true);
send({ action: "livePainting:start", payload: modelSettings });
}
}, [shouldRestart, send, modelSettings]);

useEffect(() => {
function beforeUnload() {
Expand Down Expand Up @@ -205,7 +228,7 @@ export function LivePainting() {
startDecorator={isLoading ? <CircularProgress /> : <PlayIcon />}
onClick={() => {
setIsLoading(true);
send({ action: "livePainting:start", payload: APP_ID });
send({ action: "livePainting:start", payload: modelSettings });
}}
>
{t("labels:start")}
Expand Down Expand Up @@ -335,7 +358,10 @@ export function LivePainting() {
aria-label={t("labels:randomize")}
sx={{ flexShrink: 0 }}
onClick={() => {
setSeed(randomSeed());
setSettings(previousState => ({
...previousState,
seed: randomSeed(),
}));
}}
>
<CasinoIcon />
Expand Down Expand Up @@ -367,12 +393,50 @@ export function LivePainting() {
}}
>
{hasModelAndVae ? (
<Select variant="plain" defaultValue={checkpoints[0].id}>
{checkpoints.map(checkpoint => (
<Select
variant="plain"
value={modelSettings.id}
onChange={(_event, value) => {
send({ action: "livePainting:stop", payload: APP_ID });
setIsLoading(true);
setShouldRestart(true);
switch (value) {
case "sd-turbo": {
setModelSettings({
id: "sd-turbo",
model_path: "stabilityai/sd-turbo",
model_type: "stable-diffusion",
vae_path: "madebyollin/taesd",
});
break;
}

case "sdxl-turbo": {
setModelSettings({
id: "sdxl-turbo",
model_path:
"stabilityai/sdxl-turbo/sd_xl_turbo_1.0_fp16.safetensors",
model_type: "stable-diffusion-xl",
vae_path: "madebyollin/taesdxl",
});
break;
}

default: {
break;
}
}
}}
>
{/*
checkpoints.map(checkpoint => (
<Option key={checkpoint.id} value={checkpoint.id}>
{checkpoint.label}
</Option>
))}
))
*/}
<Option value="sdxl-turbo">SDXL Turbo</Option>
<Option value="sd-turbo">SD Turbo</Option>
</Select>
) : (
<Button
Expand Down Expand Up @@ -400,6 +464,58 @@ export function LivePainting() {
{t("labels:download")}
</Button>
)}
<Button
sx={{ flexShrink: 0 }}
onClick={() => {
setSettings(previousState => ({
...previousState,
strength: 1,
steps: 2,
guidance_scale: 0.5,
}));
}}
>
Fast
</Button>
<Button
sx={{ flexShrink: 0 }}
onClick={() => {
setSettings(previousState => ({
...previousState,
strength: 0.95,
steps: 3,
guidance_scale: 1,
}));
}}
>
Default
</Button>
<Button
sx={{ flexShrink: 0 }}
onClick={() => {
setSettings(previousState => ({
...previousState,
strength: 0.9,
steps: 4,
guidance_scale: 1.25,
}));
}}
>
Precise
</Button>
<Button
sx={{ flexShrink: 0 }}
onClick={() => {
setSettings(previousState => ({
...previousState,
strength: 0.85,
steps: 5,
guidance_scale: 1.5,
}));
}}
>
Extreme
</Button>

<Box sx={{ flex: 1 }} />

Expand Down
18 changes: 15 additions & 3 deletions src/electron/future/ipc/sdk.ts
Expand Up @@ -54,7 +54,16 @@ let cache = "";

ipcMain.on(
APP_MESSAGE_KEY,
(event, { message, appId }: { message: SDKMessage<string>; appId: string }) => {
(
event,
{
message,
appId,
}: {
message: SDKMessage<{ model_type: string; model_path: string; vae_path: string }>;
appId: string;
}
) => {
if (message.action !== "livePainting:start") {
return;
}
Expand All @@ -67,11 +76,14 @@ ipcMain.on(

const pythonBinaryPath = getCaptainData("python-embedded/python.exe");
const scriptPath = getDirectory("python/live-painting/main.py");
const { model_type, model_path, vae_path } = message.payload;
const scriptArguments = [
"--model_path",
getCaptainDownloads("stable-diffusion/checkpoints/stabilityai/sd-turbo"),
getCaptainDownloads("stable-diffusion/checkpoints", model_path),
"--model_type",
model_type,
"--vae_path",
getCaptainDownloads("stable-diffusion/vae/madebyollin/taesd"),
getCaptainDownloads("stable-diffusion/vae", vae_path),
"--input_image_path",
getCaptainTemporary("live-painting/input.png"),
"--output_image_path",
Expand Down