Spaces: Running on Zero
Update OmniGen/model.py
Browse files
- OmniGen/model.py +7 -2
OmniGen/model.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Dict
|
|
9 |
from diffusers.loaders import PeftAdapterMixin
|
10 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
11 |
from huggingface_hub import snapshot_download
|
|
|
12 |
|
13 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
14 |
|
@@ -187,14 +188,18 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
187 |
|
188 |
@classmethod
|
189 |
def from_pretrained(cls, model_name):
|
190 |
-
if not os.path.exists(
|
191 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
192 |
model_name = snapshot_download(repo_id=model_name,
|
193 |
cache_dir=cache_folder,
|
194 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
195 |
config = Phi3Config.from_pretrained(model_name)
|
196 |
model = cls(config)
|
197 |
-
|
|
|
|
|
|
|
|
|
198 |
model.load_state_dict(ckpt)
|
199 |
return model
|
200 |
|
|
|
9 |
from diffusers.loaders import PeftAdapterMixin
|
10 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
11 |
from huggingface_hub import snapshot_download
|
12 |
+
from safetensors.torch import load_file
|
13 |
|
14 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
15 |
|
|
|
188 |
|
189 |
@classmethod
|
190 |
def from_pretrained(cls, model_name):
|
191 |
+
if not os.path.exists(model_name):
|
192 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
193 |
model_name = snapshot_download(repo_id=model_name,
|
194 |
cache_dir=cache_folder,
|
195 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
196 |
config = Phi3Config.from_pretrained(model_name)
|
197 |
model = cls(config)
|
198 |
+
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
199 |
+
print("Loading safetensors")
|
200 |
+
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
201 |
+
else:
|
202 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
203 |
model.load_state_dict(ckpt)
|
204 |
return model
|
205 |
|