File size: 1,879 Bytes
0bf81ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import json
import argparse
import logging
from pathlib import Path

import torch


# Configure root logging once at import time; verbosity is controlled by the
# LOGLEVEL environment variable (e.g. DEBUG, WARNING), defaulting to INFO.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
# Module-level logger used by the functions in this file.
log = logging.getLogger(__name__)


def get_args(argv=None):
    """Parse command-line options for the ESPER demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, which is argparse's standard
            behavior — so existing callers are unaffected. Passing an
            explicit list makes this function testable.

    Returns:
        argparse.Namespace with the parsed options plus two derived fields:
        ``cuda`` (True when torch sees a CUDA device) and ``checkpoint``
        normalized to an absolute path when set.
    """
    parser = argparse.ArgumentParser(description='ESPER')

    # Model and data locations.
    parser.add_argument(
        '--init-model', type=str, default='gpt2', help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style', help='checkpoint file path')

    # Visual-mapper architecture options.
    parser.add_argument(
        '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1, help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False, help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False, help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32', help='clip backbone type')

    # Inference / serving options.
    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--port', type=int, default=None, help="port for the demo server")

    args = parser.parse_args(argv)
    # Derived flag: whether a CUDA device is visible to torch.
    args.cuda = torch.cuda.is_available()

    if args.use_label_prefix:
        # Plain string: the previous f-string had no placeholders (F541).
        log.info('using label prefix')

    # Normalize to an absolute path so downstream code is independent of the
    # current working directory.
    if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())
    return args