Spaces:
Runtime error
Runtime error
import os | |
import json | |
import argparse | |
import logging | |
from pathlib import Path | |
import torch | |
# Configure root logging at import time. The level is taken from the LOGLEVEL
# environment variable (default "INFO"). It is upper-cased because logging only
# recognizes upper-case level names, so e.g. LOGLEVEL=debug would otherwise
# raise ValueError on import.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
# Module-level logger shared by the functions in this file.
log = logging.getLogger(__name__)
def get_args(argv=None):
    """Parse command-line options for the ESPER demo.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse reads
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace holding every option defined below, plus:
          * ``cuda``: result of ``torch.cuda.is_available()``.
          * ``num_gpus``: the requested count clamped to
            ``torch.cuda.device_count()``, or that device count when the flag
            was omitted.
          * ``checkpoint``: rewritten as an absolute, resolved path string.

    Raises:
        SystemExit: raised by argparse for unparsable options, or by
            ``parser.error`` when ``--num-gpus`` is negative.
    """
    parser = argparse.ArgumentParser(description='ESPER')
    parser.add_argument(
        '--init-model', type=str, default='gpt2',
        help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json',
        help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style',
        help='checkpoint file path')
    parser.add_argument(
        '--prefix_length', type=int, default=10,
        help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1,
        help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False,
        help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False,
        help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32',
        help='clip backbone type')
    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2,
        help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20,
        help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--num-gpus', type=int, default=None,
        help='number of gpus. use all available if none')
    parser.add_argument(
        '--port', type=int, default=None,
        help="port for the demo server")
    args = parser.parse_args(argv)

    args.cuda = torch.cuda.is_available()
    if args.use_label_prefix:
        log.info('using label prefix')

    # A negative GPU count is meaningless. Without this check, min() below
    # would silently propagate it, so report it as a usage error instead.
    if args.num_gpus is not None and args.num_gpus < 0:
        parser.error(f'--num-gpus must be non-negative, got {args.num_gpus}')

    # Never hand out more GPUs than torch can see; omitted means "use all".
    available_gpus = torch.cuda.device_count()
    if args.num_gpus is None:
        args.num_gpus = available_gpus
    else:
        args.num_gpus = min(available_gpus, args.num_gpus)

    # Store an absolute path so it no longer depends on the current
    # working directory of whatever code consumes it later.
    if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())
    return args