import os
import logging
import json
from pathlib import Path

import yaml
import torch

from policy import Policy

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)

# Hyper-parameters copied from a checkpoint's sidecar (.json / .yaml) file
# onto the ``args`` namespace by ``load_model_args``.
_HPARAM_KEYS = (
    'init_model',
    'clip_model_type',
    'use_caption',
    'use_style_reward',
    'use_transformer_mapper',
    'prefix_length',
    'clipcap_num_layers',
    'use_ptuning_v2',
)

# Prefixes under which the policy model's weights may appear in a checkpoint's
# ``state_dict``; the first prefix that matches any key is used.
_WEIGHT_PREFIXES = ('policy.model.', 'model.model.')


def _read_hparams(checkpoint):
    """Return the hyper-parameters stored next to ``checkpoint``.

    ``<checkpoint>.json`` is preferred; ``<checkpoint>.yaml`` is the fallback.

    Raises:
        FileNotFoundError: if neither sidecar file exists (raised by ``open``
            on the ``.yaml`` path).
    """
    json_path = Path(checkpoint + '.json')
    if json_path.is_file():
        with open(json_path, encoding='utf-8') as f:
            return json.load(f)
    yaml_path = Path(checkpoint + '.yaml')
    with open(yaml_path, encoding='utf-8') as f:
        return yaml.safe_load(f)


def load_model_args(args):
    """Populate ``args`` with hyper-parameters saved alongside a checkpoint.

    Args:
        args: namespace with a ``checkpoint`` attribute holding the checkpoint
            path *without* extension; ``<checkpoint>.ckpt`` must exist.

    Returns:
        The same ``args`` object, mutated in place. Every key of
        ``_HPARAM_KEYS`` present in the sidecar file is copied onto it, and
        ``loaded_init_model`` is set to True.

    Raises:
        FileNotFoundError: if ``<checkpoint>.ckpt`` is missing, or if neither
            ``<checkpoint>.json`` nor ``<checkpoint>.yaml`` exists.
    """
    checkpoint = Path(args.checkpoint + '.ckpt')
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not checkpoint.is_file():
        raise FileNotFoundError(f"no checkpoint file: {checkpoint}")

    hparams = _read_hparams(args.checkpoint)
    for key in _HPARAM_KEYS:
        if key in hparams:
            setattr(args, key, hparams[key])
    args.loaded_init_model = True
    return args


def _strip_weight_prefix(weights):
    """Select the policy-model entries of a ``state_dict`` and drop their prefix.

    Uses the first prefix in ``_WEIGHT_PREFIXES`` that matches any key.
    Returns ``{}`` when no prefix matches.
    """
    for prefix in _WEIGHT_PREFIXES:
        if any(k.startswith(prefix) for k in weights):
            return {
                k[len(prefix):]: v
                for k, v in weights.items()
                if k.startswith(prefix)
            }
    return {}


def load_model(args, device, finetune=False):
    """Build a ``Policy`` and load its model weights from ``<args.checkpoint>.ckpt``.

    Args:
        args: namespace providing ``checkpoint``, ``init_model``,
            ``label_path``, ``prefix_length``, ``clipcap_num_layers``,
            ``use_transformer_mapper`` and ``use_label_prefix``.
        device: target device passed to ``Policy`` and to the final ``.to``.
        finetune: unused; kept for backward compatibility with callers.

    Returns:
        The ``Policy`` instance moved to ``device``.
    """
    log.info('loading model')
    # NOTE(review): clipcap_path / model_weight are the *string* 'None',
    # presumably a sentinel understood by Policy; confirm against policy.py.
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None',
                    use_label_prefix=args.use_label_prefix)

    ckpt = args.checkpoint + '.ckpt'
    # Load onto CPU. This lets GPU-saved checkpoints load on CPU-only hosts
    # and avoids a second GPU copy of the weights; load_state_dict copies
    # into the model's existing parameters. torch.load unpickles the file,
    # so only load trusted checkpoints.
    state = torch.load(ckpt, map_location='cpu')

    policy_key = 'policy_model'
    if policy_key in state:
        policy.model.load_state_dict(state[policy_key])
    else:
        weights = _strip_weight_prefix(state['state_dict'])
        if not weights:
            # strict=False would otherwise silently load nothing.
            log.warning('no policy weights found in %s under prefixes %s',
                        ckpt, _WEIGHT_PREFIXES)
        policy.model.load_state_dict(weights, strict=False)

    return policy.to(device)