Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/download UI #178

Merged
merged 10 commits into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
822 changes: 413 additions & 409 deletions package-lock.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions package.json
Expand Up @@ -37,10 +37,10 @@
"sharp": "0.33.2"
},
"devDependencies": {
"@captn/joy": "^0.8.0",
"@captn/react": "^0.8.0",
"@captn/theme": "^0.8.0",
"@captn/utils": "^0.8.0",
"@captn/joy": "^0.10.0",
"@captn/react": "^0.10.2",
"@captn/theme": "^0.10.0",
"@captn/utils": "^0.10.0",
"@commitlint/cli": "^19.2.1",
"@commitlint/config-conventional": "^19.1.0",
"@dnd-kit/core": "^6.1.0",
Expand Down
6 changes: 3 additions & 3 deletions src/client/apps/live-painting/components/index.tsx
@@ -1,3 +1,4 @@
import { useRequiredDownloads } from "@captn/react/use-required-downloads";
import { useSDK } from "@captn/react/use-sdk";
import { ClickAwayListener } from "@mui/base";
import CheckIcon from "@mui/icons-material/Check";
Expand Down Expand Up @@ -31,13 +32,12 @@ import {
import type { Except } from "type-fest";
import { v4 } from "uuid";

import { APP_ID } from "../constants";
import { allRequiredDownloads, APP_ID } from "../constants";
import { StyledColorInput } from "../styled";
import type { IllustrationStyles } from "../text-to-image";
import { illustrationStyles } from "../text-to-image";

import type { Repository } from "#/types";
import { useRequiredModels } from "@/apps/live-painting/required-models-alert";
import { FlagUs } from "@/atoms/flags/us";
import { useResettableState } from "@/ions/hooks/resettable-state";
import { getContrastColor } from "@/ions/utils/color";
Expand Down Expand Up @@ -364,7 +364,7 @@ export interface RunButtonProperties {

export function RunButton({ isLoading, isRunning, onStart, onStop }: RunButtonProperties) {
const { t } = useTranslation(["common", "labels"]);
const hasModelAndVae = useRequiredModels();
const { isCompleted: hasModelAndVae } = useRequiredDownloads(allRequiredDownloads);

return isRunning ? (
<Button
Expand Down
18 changes: 18 additions & 0 deletions src/client/apps/live-painting/constants.ts
@@ -1 +1,19 @@
import type { RequiredDownload } from "@captn/utils/types";

export const APP_ID = "live-painting";
export const allRequiredDownloads: RequiredDownload[] = [
{
label: "SD Turbo",
id: "stabilityai/sd-turbo/fp16",
source: "https://pub-aea7c308ba0147b69deba50a606e7743.r2.dev/stabilityai-sd-turbo-fp16.7z",
destination: "stable-diffusion/checkpoints",
unzip: true,
},
{
label: "Taesd",
id: "madebyollin/taesd",
source: "https://pub-aea7c308ba0147b69deba50a606e7743.r2.dev/taesd.7z",
destination: "stable-diffusion/vae",
unzip: true,
},
];
2 changes: 1 addition & 1 deletion src/client/apps/live-painting/index.tsx
Expand Up @@ -90,7 +90,7 @@ export function LivePainting() {

return (
<Box sx={{ display: "flex", flexDirection: "column", minHeight: "100%" }}>
<RequiredModelsAlert appId={APP_ID} />
<RequiredModelsAlert />
<StyledStickyHeader>
{/* Left Side of the header */}
<StyledButtonWrapper>
Expand Down
180 changes: 7 additions & 173 deletions src/client/apps/live-painting/required-models-alert/index.tsx
@@ -1,181 +1,23 @@
import { useRequiredDownloads } from "@captn/react/use-required-downloads";
import WarningIcon from "@mui/icons-material/Warning";
import Box from "@mui/joy/Box";
import Button from "@mui/joy/Button";
import LinearProgress from "@mui/joy/LinearProgress";
import Snackbar from "@mui/joy/Snackbar";
import Typography from "@mui/joy/Typography";
import { getProperty } from "dot-prop";
import { useTranslation } from "next-i18next";
import { useEffect, useState } from "react";

export interface DownloadTask {
label: string;
id: string;
source: string;
destination: string;
appId?: string;
unzip?: boolean;
}

export const allRequiredDownloads: DownloadTask[] = [
{
label: "SD Turbo",
id: "stabilityai/sd-turbo/fp16",
source: "https://pub-aea7c308ba0147b69deba50a606e7743.r2.dev/stabilityai-sd-turbo-fp16.7z",
destination: "stable-diffusion/checkpoints",
unzip: true,
},
{
label: "Taesd",
id: "madebyollin/taesd",
source: "https://pub-aea7c308ba0147b69deba50a606e7743.r2.dev/taesd.7z",
destination: "stable-diffusion/vae",
unzip: true,
},
];

export function useRequiredModels() {
const [isCompleted, setIsCompleted] = useState(false);
import { allRequiredDownloads } from "../constants";

useEffect(() => {
const unsubscribeAllInventory = window.ipc.on(
"allInventory",
(inventory: Record<string, unknown>) => {
const done = allRequiredDownloads.every(item => {
const keyPath = item.destination.replaceAll("/", ".");
const inventoryCollection = getProperty<
Record<string, unknown>,
string,
{
id: string;
}[]
>(inventory, keyPath);
if (Array.isArray(inventoryCollection)) {
return inventoryCollection.some(
inventoryItem => inventoryItem.id === item.id
);
}

return false;
});
setIsCompleted(done);
}
);

return () => {
unsubscribeAllInventory();
};
}, []);

useEffect(() => {
Promise.all(
allRequiredDownloads.map(async requiredDownload => {
const keyPath = requiredDownload.destination.replaceAll("/", ".");
const value = await window.ipc.inventoryStore.get<
{
id: string;
modelPath: string;
label: string;
}[]
>(keyPath);
return value?.some(({ id }) => id === requiredDownload.id);
})
).then(results => {
setIsCompleted(results.every(Boolean));
});
for (const requiredDownload of allRequiredDownloads) {
const keyPath = requiredDownload.destination.replaceAll("/", ".");
window.ipc.inventoryStore
.get<
{
id: string;
modelPath: string;
label: string;
}[]
>(keyPath)
.then(value => {
if (value?.some(({ id }) => id === requiredDownload.id)) {
console.log(requiredDownload.id);
}
});
}
}, []);

return isCompleted;
}

export function RequiredModelsAlert({ inline, appId }: { inline?: boolean; appId: string }) {
export function RequiredModelsAlert({ inline }: { inline?: boolean }) {
const { t } = useTranslation(["common", "labels"]);
const [downloadCount, setDownloadCount] = useState(0);
const [percent, setPercent] = useState(0);
const [isDownloading, setIsDownloading] = useState(false);
const [requiredDownloads, setRequiredDownloads] = useState<DownloadTask[]>([]);
const [isCompleted, setIsCompleted] = useState(downloadCount >= requiredDownloads.length);

useEffect(() => {
const unsubscribeDownload = window.ipc.on("download", progress => {
setPercent(progress.percent);
});
const unsubscribeDownloadComplete = window.ipc.on("downloadComplete", () => {
setDownloadCount(previousState => previousState + 1);
});

return () => {
unsubscribeDownload();
unsubscribeDownloadComplete();
};
}, []);

useEffect(() => {
const unsubscribeAllDownloads = window.ipc.on("allDownloads", downloads => {
const activities = requiredDownloads
.map(downloadItem => ({
id: downloadItem.id,
state: downloads[downloadItem.id],
}))
.filter(activity => Boolean(activity.state));
if (activities.length > 0) {
setIsDownloading(true);
}
});
return () => {
unsubscribeAllDownloads();
};
}, [requiredDownloads]);

useEffect(() => {
if (downloadCount >= requiredDownloads.length) {
setIsCompleted(true);
setIsDownloading(false);
} else {
setIsCompleted(false);
}
}, [downloadCount, requiredDownloads]);

useEffect(() => {
for (const requiredDownload of allRequiredDownloads) {
const keyPath = requiredDownload.destination.replaceAll("/", ".");
window.ipc.inventoryStore
.get<
{
id: string;
modelPath: string;
label: string;
}[]
>(keyPath)
.then(value => {
if (!value || !value.some(({ id }) => id === requiredDownload.id)) {
setRequiredDownloads(previousState => [...previousState, requiredDownload]);
}
});
}
}, []);

const isCompleted_ = useRequiredModels();
const { isCompleted, downloadCount, isDownloading, percent, requiredDownloads, download } =
useRequiredDownloads(allRequiredDownloads);

return (
<Snackbar
open={!isCompleted && !isCompleted_}
open={!isCompleted}
variant="soft"
color="warning"
startDecorator={<WarningIcon />}
Expand All @@ -198,15 +40,7 @@ export function RequiredModelsAlert({ inline, appId }: { inline?: boolean; appId
size="sm"
variant="solid"
color="warning"
onClick={async () => {
setIsDownloading(true);
await window.ipc.downloadFiles(
requiredDownloads.map(requiredDownload => ({
...requiredDownload,
appId,
}))
);
}}
onClick={download}
>
{t("labels:download")}
</Button>
Expand Down
1 change: 1 addition & 0 deletions src/client/organisms/layout/core.tsx
Expand Up @@ -23,6 +23,7 @@ export function CoreLayout({ children }: { children?: ReactNode }) {
>
<TabButton href="/core/dashboard">{t("labels:dashboard")}</TabButton>
<TabButton href="/core/settings">{t("common:settings")}</TabButton>
<TabButton href="/core/downloads">{t("labels:downloads")}</TabButton>
</Box>
</TitleBar>
}
Expand Down
2 changes: 1 addition & 1 deletion src/client/pages/[locale]/core/dashboard.tsx
Expand Up @@ -17,7 +17,7 @@ export default function Page(_properties: InferGetStaticPropsType<typeof getStat
<title>{t("labels:dashboard")}</title>
</Head>
<Box sx={{ p: 2 }}>
<RequiredModelsAlert inline appId="core" />
<RequiredModelsAlert inline />
<Alert color="primary" variant="soft" sx={{ m: 4, p: 4 }}>
<Typography level="title-lg">{t("texts:howToUseCaptain")}</Typography>
</Alert>
Expand Down