forked from replicate/cog-sdxl
/
test_lora.py
172 lines (140 loc) · 7.79 KB
/
test_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Bootstrapped from Huggingface diffuser's code.
import fnmatch
import gc
import json
import math
import os
import random
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.optimization import get_scheduler
from safetensors.torch import load_file
from tqdm import tqdm

from dataset_and_utils import (
    PreprocessedDataset,
    TokenEmbeddingsHandler,
    load_models,
    unet_attn_processors_state_dict,
)
def patch_pipe_with_lora(pipe, lora_path):
    """Patch the pipeline's UNet with LoRA attention processors from lora_path.

    Looks for "<concept_name>_lora.safetensors" under lora_path; when present,
    installs fresh LoRAAttnProcessor2_0 modules on every UNet attention layer.
    When absent, falls back to non-strictly loading a raw "unet.safetensors"
    state dict. Also builds a TokenEmbeddingsHandler for the text encoder(s).

    Args:
        pipe: a diffusers StableDiffusion(XL) pipeline, already on the device.
        lora_path: directory containing the LoRA / unet safetensors files.

    Returns:
        The same pipeline object, patched in place.
    """
    print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
    print('xxxxxxxxxxxxxxxxxxxxxxxx PATCHING PIPE WITH LORA xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')

    lora_rank = 8
    concept_name = "eden_concept_lora"
    # Make sure all weird delimiter characters are removed from concept_name
    # before using it as a filepath component:
    for bad_char in ' /\\:*?"<>|':
        concept_name = concept_name.replace(bad_char, "_")

    unet = pipe.unet
    # Fix: the computed path was previously clobbered by a hard-coded absolute
    # debug path, which made the concept_name logic above dead code.
    lora_safetensors_path = os.path.join(lora_path, f"{concept_name}_lora.safetensors")

    if os.path.exists(lora_safetensors_path):
        tensors = load_file(lora_safetensors_path)
        print("Loaded lora tensors from", lora_safetensors_path)
        # NOTE(review): `tensors` is loaded here but never applied to the
        # freshly-constructed processors below — confirm whether a
        # unet.load_state_dict(tensors, strict=False) call is missing.
        unet_lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            # attn1 layers are self-attention: no cross-attention dim.
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            else:
                # Fix: hidden_size was previously left unbound here (NameError
                # on the first unrecognized processor name).
                raise ValueError(f"Unexpected attention processor name: {name}")
            module = LoRAAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=lora_rank,
            )
            unet_lora_attn_procs[name] = module.to("cuda")
        unet.set_attn_processor(unet_lora_attn_procs)
    else:
        unet_path = os.path.join(lora_path, "unet.safetensors")
        tensors = load_file(unet_path)
        unet.load_state_dict(tensors, strict=False)

    try:  # SDXL: two text encoders / tokenizers.
        handler = TokenEmbeddingsHandler(
            [pipe.text_encoder, pipe.text_encoder_2],
            [pipe.tokenizer, pipe.tokenizer_2],
        )
    # Fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit; a
    # missing text_encoder_2 attribute is what actually signals SD15.
    except AttributeError:  # SD15: single text encoder.
        handler = TokenEmbeddingsHandler([pipe.text_encoder, None], [pipe.tokenizer, None])
    #handler.load_embeddings(os.path.join(lora_path, f"{concept_name}_embeddings.safetensors"))
    print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
    return pipe
@torch.no_grad()
def render_images(lora_path, train_step, seed, is_lora, pretrained_model, lora_scale = 0.7, n_imgs = 4, debug = False, device = "cuda:0"):
    """Render validation images with a LoRA-fused StableDiffusion pipeline.

    Loads the base checkpoint from pretrained_model['path'], fuses LoRA
    weights from the local "test_lora" directory, renders n_imgs prompts and
    saves them as JPEGs under lora_path.

    Args:
        lora_path: output directory for the rendered images.
        train_step: training step number, embedded in the output filenames.
        seed: RNG seed for prompt sampling and the torch generator.
        is_lora: > 0 enables LoRA cross-attention scaling at inference.
        pretrained_model: dict with a 'path' key — either a single
            .safetensors checkpoint file or a diffusers model directory.
        lora_scale: LoRA strength, used when fusing and at inference.
        n_imgs: number of images to render (capped by available prompts).
        debug: unused; kept for interface compatibility.
        device: torch device string to run inference on.

    Returns:
        The list of validation prompts that was used.
    """
    random.seed(seed)
    validation_prompts = [
        'a beautiful mountainous landscape, boulders, fresh water stream, setting sun',
        'the stunning skyline of New York City',
        'fruit hanging from a tree, highly detailed texture, soil, rain, drops, photo realistic, surrealism, highly detailed, 8k macrophotography',
        'the Taj Mahal, stunning wallpaper',
        'A majestic tree rooted in circuits, leaves shimmering with data streams, stands as a beacon where the digital dawn caresses the fog-laden, binary soil—a symphony of pixels and chlorophyll.',
        'A beautiful octopus, with swirling tendrils and a pulsating heart of fiery opal hues, hovers ethereally against a starry void, sculpted through a meticulous flame-working technique.',
        'a stunning image of an aston martin sportscar',
        'the streets of new york city, traffic lights, skyline, setting sun',
        'An ethereal, levitating monolith backlit by a supernova sky, casting iridescent light on the ice-spiked Martian terrain. Neo-futurism, Dali surrealism, wide-angle lens, chiaroscuro lighting.',
        'a portrait of a beautiful young woman',
        'Binary Love: A heart-shaped composition made up of glowing binary code, symbolizing the merging of human emotion and technology, incredible digital art, cyberpunk, neon colors, glitch effects, 3D octane render, HD',
        'A labyrinthine maze representing the search for answers and understanding, Abstract expressionism, muted color palette, heavy brushstrokes, textured surfaces, somber atmosphere, symbolic elements',
        'A solitary tree standing tall amidst a sea of buildings, Urban nature photography, vibrant colors, juxtaposition of natural elements with urban landscapes, play of light and shadow, storytelling through compositions',
    ]
    #validation_prompts = random.sample(validation_prompts, n_imgs)
    # First render is intentionally unconditional (empty prompt):
    validation_prompts[0] = ''
    validation_prompts_raw = validation_prompts

    torch.cuda.empty_cache()
    print(f"Loading inference pipeline from {pretrained_model['path']}...")
    if pretrained_model['path'].endswith('.safetensors'):
        # Single-file checkpoint (e.g. a CivitAI-style .safetensors dump).
        pipeline = StableDiffusionPipeline.from_single_file(
            pretrained_model['path'], torch_dtype=torch.float16, use_safetensors=True)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model['path'], torch_dtype=torch.float16, use_safetensors=True)
    pipeline = pipeline.to(device)

    test_lora_dir = "test_lora"
    pipeline.load_lora_weights(test_lora_dir)
    # Fix: honor the lora_scale argument (was hard-coded to 0.7, silently
    # ignoring the parameter).
    pipeline.fuse_lora(lora_scale=lora_scale)
    #pipeline = patch_pipe_with_lora(pipeline, lora_path)
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

    # Fix: seed the generator with the caller's seed (was hard-coded to 0,
    # making the `seed` parameter ineffective for image generation).
    generator = torch.Generator(device=device).manual_seed(seed)
    pipeline_args = {
        "negative_prompt": "nude, naked, poorly drawn face, ugly, tiling, out of frame, extra limbs, disfigured, deformed body, blurry, blurred, watermark, text, grainy, signature, cut off, draft",
        "num_inference_steps": 35,
        "guidance_scale": 8,
    }
    if is_lora > 0:
        cross_attention_kwargs = {"scale": lora_scale}
    else:
        cross_attention_kwargs = None

    #with torch.cuda.amp.autocast():
    os.makedirs(lora_path, exist_ok=True)
    # Cap n_imgs so we never index past the available prompts:
    for i in range(min(n_imgs, len(validation_prompts))):
        pipeline_args["prompt"] = validation_prompts[i]
        print(f"Rendering validation img with prompt: {validation_prompts[i]}")
        image = pipeline(**pipeline_args, generator=generator, cross_attention_kwargs = cross_attention_kwargs).images[0]
        image.save(os.path.join(lora_path, f"img_{train_step:04d}_{i}.jpg"), format="JPEG", quality=95)

    # create img_grid:
    # NOTE(review): make_validation_img_grid is not imported anywhere in this
    # file — guard so a missing helper doesn't throw away the finished renders.
    try:
        img_grid_path = make_validation_img_grid(lora_path)
    except NameError:
        print("make_validation_img_grid is unavailable; skipping grid creation")

    del pipeline
    gc.collect()
    torch.cuda.empty_cache()
    return validation_prompts_raw
if __name__ == "__main__":
    # Smoke-test entry point: render validation images for a fixed local
    # checkpoint + LoRA pair. Guarded so importing this module no longer
    # triggers a full model load and GPU render as a side effect.
    lora_path = "/data/xander/Projects/cog/GitHub_repos/cog-sdxl/LoRA-3D-Light.safetensors"
    train_step = 0
    seed = 0
    is_lora = 1
    pretrained_model = {"path": "/data/xander/Projects/cog/GitHub_repos/cog-sdxl/models/juggernaut_reborn.safetensors"}
    render_images(lora_path, train_step, seed, is_lora, pretrained_model,
                  lora_scale = 0.7, n_imgs = 4, debug = True, device = "cuda:0")