kasper-boy commited on
Commit
062f53d
1 Parent(s): 208fcb6

Update main_app.py

Browse files
Files changed (1) hide show
  1. main_app.py +19 -2
main_app.py CHANGED
@@ -75,6 +75,7 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
75
  return p.replace("{prompt}", positive), n + negative
76
 
77
  DESCRIPTIONs = """ㅤㅤㅤ """
 
78
  DESCRIPTION = """ㅤㅤㅤ """
79
 
80
  if not torch.cuda.is_available():
@@ -109,14 +110,30 @@ if torch.cuda.is_available():
109
  pipe.enable_model_cpu_offload()
110
  pipe2.enable_model_cpu_offload()
111
  else:
112
- pipe.to(device)
113
- pipe2.to(device)
114
  print("Loaded on Device!")
115
 
116
  if USE_TORCH_COMPILE:
117
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
118
  pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
119
  print("Model Compiled!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def save_image(img):
122
  unique_name = str(uuid.uuid4()) + ".png"
 
75
  return p.replace("{prompt}", positive), n + negative
76
 
77
  DESCRIPTIONs = """ㅤㅤㅤ """
78
+
79
  DESCRIPTION = """ㅤㅤㅤ """
80
 
81
  if not torch.cuda.is_available():
 
110
  pipe.enable_model_cpu_offload()
111
  pipe2.enable_model_cpu_offload()
112
  else:
113
+ pipe.to(device)
114
+ pipe2.to(device)
115
  print("Loaded on Device!")
116
 
117
  if USE_TORCH_COMPILE:
118
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
119
  pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
120
  print("Model Compiled!")
121
+ else:
122
+ pipe = DiffusionPipeline.from_pretrained(
123
+ "SG161222/RealVisXL_V4.0",
124
+ torch_dtype=torch.float32,
125
+ use_safetensors=True,
126
+ add_watermarker=False
127
+ )
128
+ pipe2 = DiffusionPipeline.from_pretrained(
129
+ "SG161222/RealVisXL_V3.0",
130
+ torch_dtype=torch.float32,
131
+ use_safetensors=True,
132
+ add_watermarker=False
133
+ )
134
+ pipe.to(device)
135
+ pipe2.to(device)
136
+ print("Loaded on Device!")
137
 
138
  def save_image(img):
139
  unique_name = str(uuid.uuid4()) + ".png"