Spaces:
Runtime error
Runtime error
File size: 1,879 Bytes
0bf81ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
import json
import argparse
import logging
from pathlib import Path
import torch
# Configure the root logger from the LOGLEVEL environment variable (default
# "INFO"). The value is upper-cased because logging only recognises
# upper-case level names; without it, e.g. LOGLEVEL=debug can make
# basicConfig raise ValueError("Unknown level") at import time.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
# Logger for this module, named after the module (__name__).
log = logging.getLogger(__name__)
def get_args(argv=None):
    """Parse the command-line options for the ESPER demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of reading the real
            command line (``sys.argv[1:]``); passing a list lets tests or
            other code build the options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Dashes in option names become underscores (``--init-model`` ->
        ``args.init_model``, ``--response-length`` ->
        ``args.response_length``). Two attributes are post-processed:

        * ``cuda``: set to ``torch.cuda.is_available()``.
        * ``checkpoint``: converted to an absolute, resolved path string.
    """
    parser = argparse.ArgumentParser(description='ESPER')
    parser.add_argument(
        '--init-model', type=str, default='gpt2', help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style', help='checkpoint file path')
    parser.add_argument(
        '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1, help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False, help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False, help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32', help='clip backbone type')
    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--port', type=int, default=None, help="port for the demo server")

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    args = parser.parse_args(argv)
    args.cuda = torch.cuda.is_available()

    if args.use_label_prefix:
        # Plain string: the message has no placeholders.
        log.info('using label prefix')

    if args.checkpoint is not None:
        # Make the checkpoint path absolute so later use does not depend on
        # the current working directory.
        args.checkpoint = str(Path(args.checkpoint).resolve())
    return args
|