fpramunno committed
Commit f4dd292
1 Parent(s): 2abfe88

Update app.py

Files changed (1)
  1. app.py +1 -31
app.py CHANGED
@@ -6,36 +6,6 @@ import torch
 import torchvision.transforms as transforms
 import requests
 
-# Function to download the model from Google Drive
-def download_file_from_google_drive(id, destination):
-    URL = "https://drive.google.com/uc?export=download"
-    session = requests.Session()
-    response = session.get(URL, params={'id': id}, stream=True)
-    token = get_confirm_token(response)
-
-    if token:
-        params = {'id': id, 'confirm': token}
-        response = session.get(URL, params=params, stream=True)
-
-    save_response_content(response, destination)
-
-def get_confirm_token(response):
-    for key, value in response.cookies.items():
-        if key.startswith('download_warning'):
-            return value
-    return None
-
-def save_response_content(response, destination):
-    CHUNK_SIZE = 32768
-    with open(destination, "wb") as f:
-        for chunk in response.iter_content(CHUNK_SIZE):
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-
-# Replace 'YOUR_FILE_ID' with your actual file ID from Google Drive
-file_id = '1WJ33nys02XpPDsMO5uIZFiLqTuAT_iuV'
-destination = 'ema_ckpt_cond.pt'
-download_file_from_google_drive(file_id, destination)
 
 # Preprocessing
 from modules import PaletteModelV2
@@ -44,7 +14,7 @@ from diffusion import Diffusion_cond
 device = 'cuda'
 
 model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_size=64).to(device)
-ckpt = torch.load(destination, map_location=device)
+ckpt = torch.load('ema_ckpt_cond.pt')
 model.load_state_dict(ckpt)
 
 diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
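
The updated loading path assumes ema_ckpt_cond.pt ships alongside app.py (e.g. committed to the Space repository). If the checkpoint were hosted on the Hugging Face Hub instead, a minimal sketch of the startup download might look like the code below; the repo_id is hypothetical, not taken from this commit:

    # Sketch only: pull the checkpoint from the Hugging Face Hub at startup.
    # 'fpramunno/solar-palette' is an assumed repo_id, not from this commit.
    import torch
    from huggingface_hub import hf_hub_download

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_path = hf_hub_download(repo_id='fpramunno/solar-palette',
                                filename='ema_ckpt_cond.pt')
    # map_location remaps GPU-saved tensors so the load also works on CPU
    ckpt = torch.load(ckpt_path, map_location=device)

Keeping map_location, as the removed line did, avoids a hard failure when a checkpoint saved on a GPU is loaded on a CPU-only machine.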
 