# grg's picture
# Cleaned old git history
# be5548b
# raw
# history blame
# 38.7 kB
import argparse
import random
import warnings
import numpy as np
import time
import datetime
import torch
import gym_minigrid.social_ai_envs
import torch_ac
import sys
import json
import utils
from pathlib import Path
from distutils.dir_util import copy_tree
from utils.env import env_args_str_to_dict
from models import *
# Parse arguments
parser = argparse.ArgumentParser()

## General parameters
parser.add_argument("--algo", required=True,
                    help="algorithm to use: ppo (REQUIRED)")
parser.add_argument("--env", required=True,
                    help="name of the environment to train on (REQUIRED)")
parser.add_argument("--model", default=None,
                    help="name of the model (default: {ENV}_{ALGO}_seed{SEED}_{TIME})")
parser.add_argument("--seed", type=int, default=1,
                    help="random seed, -1 draws a random seed (default: 1)")
parser.add_argument("--log-interval", type=int, default=10,
                    help="number of updates between two logs (default: 10)")
parser.add_argument("--save-interval", type=int, default=10,
                    help="number of updates between two saves (default: 10, 0 means no periodic saving)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=10**7,
                    help="number of frames of training (default: 1e7)")
parser.add_argument("--continue-train", default=None, type=str,
                    help="status file or experiment dir to continue training from, "
                         "or 'auto' to resume from this run's model dir")
parser.add_argument("--finetune-train", default=None, type=str,
                    help="status file or experiment dir of the model to finetune")
parser.add_argument("--compact-save", "-cs", action="store_true", default=False,
                    help="deprecated, has no effect")
parser.add_argument("--lr-schedule-end-frames", type=int, default=0,
                    help="Learning rate will be diminished from --lr to --lr-end linearly over the period of "
                         "--lr-schedule-end-frames (default: 0 - no diminishing)")
parser.add_argument("--lr-end", type=float, default=0,
                    help="the final lr that will be reached at 'lr-schedule-end-frames' (default = 0)")

## Periodic test parameters
parser.add_argument("--test-set-name", required=False, default="SocialAITestSet",
                    help="name of the test set to evaluate on; an empty string tests on the train env "
                         "(default: SocialAITestSet)")
parser.add_argument("--test-seed", type=int, default=0,
                    help="random seed for testing (default: 0)")
parser.add_argument("--test-episodes", type=int, default=50,
                    help="number of episodes to test (default: 50)")
parser.add_argument("--test-interval", type=int, default=-1,
                    help="number of updates between two tests (default: -1, no testing)")
parser.add_argument("--test-env-args", nargs='*', default="like_train_no_acl",
                    help="env args of the test envs: 'like_train_no_acl' (default), 'like_train', "
                         "or key/value pairs in the same format as --env-args")

## Parameters for main algorithm
parser.add_argument("--acl", action="store_true", default=False,
                    help="use acl")
parser.add_argument("--acl-type", type=str, default=None,
                    help="acl type")
parser.add_argument("--acl-thresholds", nargs="+", type=float, default=(0.75, 0.75),
                    help="per phase thresholds for expert CL")
parser.add_argument("--acl-minimum-episodes", type=int, default=1000,
                    help="Never go to second phase before this.")
parser.add_argument("--acl-average-interval", type=int, default=500,
                    help="Average the performance estimate over this many last episodes")
parser.add_argument("--epochs", type=int, default=4,
                    help="number of epochs for PPO (default: 4)")
parser.add_argument("--exploration-bonus", action="store_true", default=False,
                    help="Use a count based exploration bonus")
parser.add_argument("--exploration-bonus-type", nargs="+", default=["lang"],
                    help="modality on which to use the bonus (lang/grid)")
parser.add_argument("--exploration-bonus-params", nargs="+", type=float, default=(30., 50.),
                    help="parameters for a count based exploration bonus (C, M)")
parser.add_argument("--exploration-bonus-tanh", nargs="+", type=float, default=None,
                    help="tanh expl bonus scale, None means no tanh")
parser.add_argument("--expert-exploration-bonus", action="store_true", default=False,
                    help="Use an expert exploration bonus")
parser.add_argument("--episodic-exploration-bonus", action="store_true", default=False,
                    help="Use the exploration bonus in an episodic setting")
parser.add_argument("--batch-size", type=int, default=256,
                    help="batch size for PPO (default: 256)")
parser.add_argument("--frames-per-proc", type=int, default=None,
                    help="number of frames per process before update (default: 5 for A2C and 128 for PPO)")
parser.add_argument("--discount", type=float, default=0.99,
                    help="discount factor (default: 0.99)")
parser.add_argument("--lr", type=float, default=0.001,
                    help="learning rate (default: 0.001)")
parser.add_argument("--gae-lambda", type=float, default=0.99,
                    help="lambda coefficient in GAE formula (default: 0.99, 1 means no gae)")
parser.add_argument("--entropy-coef", type=float, default=0.01,
                    help="entropy term coefficient (default: 0.01)")
parser.add_argument("--value-loss-coef", type=float, default=0.5,
                    help="value loss term coefficient (default: 0.5)")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
                    help="maximum norm of gradient (default: 0.5)")
parser.add_argument("--optim-eps", type=float, default=1e-8,
                    help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
parser.add_argument("--optim-alpha", type=float, default=0.99,
                    help="RMSprop optimizer alpha (default: 0.99)")
parser.add_argument("--clip-eps", type=float, default=0.2,
                    help="clipping epsilon for PPO (default: 0.2)")
parser.add_argument("--recurrence", type=int, default=1,
                    help="number of time-steps gradient is backpropagated (default: 1). "
                         "If > 1, a LSTM is added to the model to have memory.")
parser.add_argument("--text", action="store_true", default=False,
                    help="add a GRU to the model to handle text input")
parser.add_argument("--dialogue", action="store_true", default=False,
                    help="add a GRU to the model to handle the history of dialogue input")
parser.add_argument("--current-dialogue-only", action="store_true", default=False,
                    help="add a GRU to the model to handle only the current dialogue input")
parser.add_argument("--multi-headed-agent", action="store_true", default=False,
                    help="add a talking head")
# The underscore spelling is kept for backward compatibility; the dashed alias matches the other flags.
parser.add_argument("--babyai11_agent", "--babyai11-agent", dest="babyai11_agent",
                    action="store_true", default=False,
                    help="use the babyAI 1.1 agent architecture")
parser.add_argument("--multi-headed-babyai11-agent", action="store_true", default=False,
                    help="use the multi headed babyAI 1.1 agent architecture")

# Hyperparameter presets (mutually exclusive); they override the corresponding CLI values.
parser.add_argument("--custom-ppo", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters")
parser.add_argument("--custom-ppo-2", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters but with smaller memory")
parser.add_argument("--custom-ppo-3", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters but with no memory")
parser.add_argument("--custom-ppo-rnd", action="store_true", default=False,
                    help="use the RND reconstruction PPO hyperparameters")
parser.add_argument("--custom-ppo-rnd-reference", action="store_true", default=False,
                    help="use the RND reference PPO hyperparameters")
parser.add_argument("--custom-ppo-ride", action="store_true", default=False,
                    help="use the RIDE reconstruction PPO hyperparameters")
parser.add_argument("--custom-ppo-ride-reference", action="store_true", default=False,
                    help="use the RIDE reference PPO hyperparameters")
parser.add_argument("--ppo-hp-tuning", action="store_true", default=False,
                    help="use PPO hyperparameters selected from our HP tuning")

parser.add_argument("--multi-modal-babyai11-agent", action="store_true", default=False,
                    help="use the multi-modal babyAI 1.1 agent architecture")

# ride ref
parser.add_argument("--ride-ref-agent", action="store_true", default=False,
                    help="Model from the ride paper")
parser.add_argument("--ride-ref-preprocessor", action="store_true", default=False,
                    help="use ride reference preprocessor (3D images)")
parser.add_argument("--bAI-lang-model", help="lang model type for babyAI models", default="gru")
parser.add_argument("--memory-dim", type=int, help="memory dim (128 is small, 2048 is big)", default=128)
parser.add_argument("--clipped-rewards", action="store_true", default=False,
                    help="use clipped rewards (forwarded to torch_ac.PPOAlgo as clipped_rewards)")
parser.add_argument("--intrinsic-reward-epochs", type=int, default=0,
                    help="")
parser.add_argument("--balance-moa-training", action="store_true", default=False,
                    help="balance moa training to handle class imbalance.")
parser.add_argument("--moa-memory-dim", type=int, help="moa memory dim (default: 128)", default=128)

# rnd + ride
parser.add_argument("--intrinsic-reward-coef", type=float, default=0.1,
                    help="")
parser.add_argument("--intrinsic-reward-learning-rate", type=float, default=0.0001,
                    help="")
parser.add_argument("--intrinsic-reward-momentum", type=float, default=0,
                    help="")
parser.add_argument("--intrinsic-reward-epsilon", type=float, default=0.01,
                    help="")
parser.add_argument("--intrinsic-reward-alpha", type=float, default=0.99,
                    help="")
parser.add_argument("--intrinsic-reward-max-grad-norm", type=float, default=40,
                    help="")

# rnd + soc_inf
parser.add_argument("--intrinsic-reward-loss-coef", type=float, default=0.1,
                    help="")

# ride
parser.add_argument("--intrinsic-reward-forward-loss-coef", type=float, default=10,
                    help="")
parser.add_argument("--intrinsic-reward-inverse-loss-coef", type=float, default=0.1,
                    help="")
parser.add_argument("--reset-rnd-ride-at-phase", action="store_true", default=False,
                    help="expert knowledge resets rnd ride at acl phase change")

# babyAI1.1 related
parser.add_argument("--arch", default="original_endpool_res",
                    help="image embedding architecture")
parser.add_argument("--num-films", type=int, default=2,
                    help="")

# Put all env related arguments after --env-args, e.g. --env-args nb_foo 1 is_bar True
parser.add_argument("--env-args", nargs='*', default=None,
                    help="env related arguments as key/value pairs, e.g. --env-args nb_foo 1 is_bar True")

args = parser.parse_args()
if args.compact_save:
    print("Compact save is deprecated. Don't use it. It doesn't do anything now.")

if args.save_interval != args.log_interval:
    print(f"save_interval ({args.save_interval}) and log_interval ({args.log_interval}) are not the same. This is not ideal for train continuation.")

if args.seed == -1:
    # -1 requests a random seed; it is drawn from numpy's global RNG before utils.seed() runs.
    args.seed = np.random.randint(424242)

# PPO hyperparameter presets: args attribute of the CLI flag -> (message printed when selected, overrides).
# The overrides replace whatever values were passed on the command line.
_PPO_PRESETS = {
    "custom_ppo": ("babyAI's ppo config", {
        "frames_per_proc": 40, "lr": 1e-4, "gae_lambda": 0.99, "recurrence": 20,
        "optim_eps": 1e-05, "clip_eps": 0.2, "batch_size": 1280,
    }),
    "custom_ppo_2": ("babyAI's ppo config with smaller memory", {
        "frames_per_proc": 40, "lr": 1e-4, "gae_lambda": 0.99, "recurrence": 10,
        "optim_eps": 1e-05, "clip_eps": 0.2, "batch_size": 1280,
    }),
    "custom_ppo_3": ("babyAI's ppo config with no memory", {
        "frames_per_proc": 40, "lr": 1e-4, "gae_lambda": 0.99, "recurrence": 1,
        "optim_eps": 1e-05, "clip_eps": 0.2, "batch_size": 1280,
    }),
    # For SocialAI envs, a recurrence of 5 was noted as an alternative for the two reconstruct presets.
    "custom_ppo_rnd": ("RND reconstruct", {
        "frames_per_proc": 40, "lr": 1e-4, "recurrence": 1, "batch_size": 640,
        "epochs": 4, "clipped_rewards": True,
    }),
    "custom_ppo_ride": ("RIDE reconstruct", {
        "frames_per_proc": 40, "lr": 1e-4, "recurrence": 1, "batch_size": 640,
        "epochs": 4, "clipped_rewards": True,
    }),
    "custom_ppo_rnd_reference": ("RND reference", {
        "frames_per_proc": 128, "lr": 1e-4, "recurrence": 64, "gae_lambda": 0.99,
        "batch_size": 1280, "epochs": 4, "optim_eps": 1e-05, "clip_eps": 0.2,
        "entropy_coef": 0.0001, "clipped_rewards": True,
    }),
    "custom_ppo_ride_reference": ("RIDE reference", {
        "frames_per_proc": 128, "lr": 1e-4, "recurrence": 64, "gae_lambda": 0.99,
        "batch_size": 1280, "epochs": 4, "optim_eps": 1e-05, "clip_eps": 0.2,
        "entropy_coef": 0.0005, "clipped_rewards": True,
    }),
    "ppo_hp_tuning": (None, {
        "frames_per_proc": 40, "lr": 1e-4, "recurrence": 5, "batch_size": 640, "epochs": 4,
    }),
}

# Presets are mutually exclusive: combining them would silently let one override the other.
_selected_presets = [flag for flag in _PPO_PRESETS if getattr(args, flag)]
if len(_selected_presets) > 1:
    raise ValueError(f"PPO presets are mutually exclusive, got: {_selected_presets}")

for _preset_flag in _selected_presets:
    _preset_message, _preset_overrides = _PPO_PRESETS[_preset_flag]
    if _preset_message is not None:
        print(_preset_message)
    for _arg_name, _arg_value in _preset_overrides.items():
        setattr(args, _arg_name, _arg_value)

# Envs for which a recurrence <= 1 does not trigger the warning below.
_NO_RECURRENCE_WARNING_ENVS = (
    "MiniGrid-KeyCorridorS3R3-v0",
    "MiniGrid-MultiRoom-N2-S4-v0",
    "MiniGrid-MultiRoom-N4-S5-v0",
    "MiniGrid-MultiRoom-N7-S4-v0",
    "MiniGrid-MultiRoomNoisyTV-N7-S4-v0",
)
if args.env not in _NO_RECURRENCE_WARNING_ENVS and args.recurrence <= 1:
    print("You are using recurrence {} with {} env. This is probably unintentional.".format(args.recurrence, args.env))

# A memory (recurrent) model is used whenever gradients are backpropagated over more than one step.
args.mem = args.recurrence > 1
# Set run dir: {env}_{algo}_seed{seed}_{timestamp} unless --model names it explicitly.
date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
default_model_name = f"{args.env}_{args.algo}_seed{args.seed}_{date}"
model_name = args.model if args.model else default_model_name
model_dir = utils.get_model_dir(model_name)

# Abort if the run dir already exists, unless training is being continued.
if args.continue_train is None and Path(model_dir).exists():
    raise ValueError(f"Dir {model_dir} already exists and continue train is None.")

# Load text and CSV loggers
txt_logger = utils.get_txt_logger(model_dir)
csv_file, csv_logger = utils.get_csv_logger(model_dir)

# Log the command line and the fully-resolved arguments
command_line = " ".join(sys.argv)
txt_logger.info(f"{command_line}\n")
txt_logger.info(f"{args}\n")

# Set seed for all randomness sources
utils.seed(args.seed)

# Set device: GPU when available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
txt_logger.info(f"Device: {device}\n")
# Create env_args dict
env_args = env_args_str_to_dict(args.env_args)

if args.acl:
    expert_acl = args.acl_type
    print(f"Using curriculum: {expert_acl}.")
else:
    expert_acl = None

# Snapshot of the user-provided env args, taken before the curriculum keys are added.
env_args_no_acl = env_args.copy()
env_args["curriculum"] = expert_acl
env_args["expert_curriculum_thresholds"] = args.acl_thresholds
env_args["expert_curriculum_average_interval"] = args.acl_average_interval
env_args["expert_curriculum_minimum_episodes"] = args.acl_minimum_episodes
# NOTE(review): this is set after the snapshot above, so "like_train_no_acl" test envs do not
# receive egocentric_observation=True — confirm that this is intended.
env_args["egocentric_observation"] = True

# test env args
# --test-env-args uses nargs='*': its default is the bare string "like_train_no_acl", but a value
# given on the command line arrives as a list (e.g. ["like_train"]). Unwrap single keyword values
# so both forms select the same behaviour.
test_env_args_spec = args.test_env_args
if (isinstance(test_env_args_spec, list) and len(test_env_args_spec) == 1
        and test_env_args_spec[0] in ("like_train_no_acl", "like_train")):
    test_env_args_spec = test_env_args_spec[0]

if not test_env_args_spec:
    test_env_args = {}
elif test_env_args_spec == "like_train_no_acl":
    test_env_args = env_args_no_acl
elif test_env_args_spec == "like_train":
    test_env_args = env_args
else:
    test_env_args = env_args_str_to_dict(test_env_args_spec)

# Only SocialAI envs take these env args; other envs are built with none.
if "SocialAI-" not in args.env:
    env_args = {}
    test_env_args = {}

print("train_env_args:", env_args)
print("test_env_args:", test_env_args)
# Load train environments: one per process, with seeds spaced 10000 apart
envs = [
    utils.make_env(args.env, args.seed + 10000 * proc_idx, env_args=env_args)
    for proc_idx in range(args.procs)
]
txt_logger.info("Environments loaded\n")
if args.continue_train and args.finetune_train:
    raise ValueError(f"Continue path ({args.continue_train}) and finetune path ({args.finetune_train}) can't both be set.")

# Load training status
if args.continue_train:
    if args.continue_train == "auto":
        # "auto" resumes from this run's own model dir
        status_continue_path = Path(model_dir)
        args.continue_train = status_continue_path  # just in case
    else:
        status_continue_path = Path(args.continue_train)

    if status_continue_path.is_dir():
        # A directory is treated as an experiment dir; utils.get_status_path resolves the status file in it.
        status_continue_path = utils.get_status_path(status_continue_path)
    else:
        if not status_continue_path.is_file():
            raise ValueError(f"{status_continue_path} is not a file")
        if "status" not in status_continue_path.name:
            raise UserWarning(f"{status_continue_path} does not contain status, is this the correct file? ")

    status = utils.load_status(status_continue_path)
    txt_logger.info("Training status loaded\n")
    txt_logger.info(f"{model_name} continued from {status_continue_path}")
    assert Path(status_continue_path).is_file()

elif args.finetune_train:
    status_finetune_path = Path(args.finetune_train)
    if status_finetune_path.is_dir():
        # An experiment dir: prefer its per-seed subfolder when one exists, otherwise use the dir itself.
        status_finetune_seed_path = status_finetune_path / str(args.seed)
        if status_finetune_seed_path.exists():
            status_finetune_path = utils.get_status_path(status_finetune_seed_path)
        else:
            status_finetune_path = utils.get_status_path(status_finetune_path)
    else:
        if not status_finetune_path.is_file():
            raise ValueError(f"{status_finetune_path} is not dir or a file")
        if "status" not in status_finetune_path.name:
            raise UserWarning(f"{status_finetune_path} does not contain status, is this the correct file? ")

    status = utils.load_status(status_finetune_path)
    txt_logger.info("Training status loaded\n")
    txt_logger.info(f"{model_name} finetuning from {status_finetune_path}")
    assert Path(status_finetune_path).is_file()

    # Finetuning restarts the counters and drops the optimizer / lr-scheduler state and the source
    # run's env args. pop(..., None) tolerates status files that lack some of these keys.
    status["num_frames"] = 0
    status["update"] = 0
    for stale_key in ("optimizer_state", "lr_scheduler_state", "env_args"):
        status.pop(stale_key, None)

else:
    status = {"num_frames": 0, "update": 0}
# Parameter sanity checks
if args.dialogue and args.current_dialogue_only:
    raise ValueError("Either use dialogue or current-dialogue-only")
if not (args.dialogue or args.current_dialogue_only):
    warnings.warn("Not using dialogue")
if args.text:
    raise ValueError("Text should not be used. Use dialogue instead.")

# Load observations preprocessor (optionally with the RIDE reference image preprocessing)
use_ride_ref_preprocessor = args.ride_ref_preprocessor
obs_space, preprocess_obss = utils.get_obss_preprocessor(
    obs_space=envs[0].observation_space,
    text=args.text,
    dialogue_current=args.current_dialogue_only,
    dialogue_history=args.dialogue,
    custom_image_preprocessor=(
        utils.ride_ref_image_preprocessor if use_ride_ref_preprocessor else None
    ),
    custom_image_space_preprocessor=(
        utils.ride_ref_image_space_preprocessor if use_ride_ref_preprocessor else None
    ),
)

# When resuming or finetuning, reuse the vocab stored in the loaded status.
resuming_from_status = args.continue_train is not None or args.finetune_train is not None
if resuming_from_status:
    assert "vocab" in status
    preprocess_obss.vocab.load_vocab(status["vocab"])
txt_logger.info("Observations preprocessor loaded")

if args.exploration_bonus and args.expert_exploration_bonus:
    warnings.warn("You are using expert exploration bonus.")
# Load model
# At most one agent architecture flag may be set. Only multi_modal_babyai11_agent and ride_ref_agent
# have dedicated models below; the other flags fall back to ACModel (a warning is emitted).
_selected_agents = [
    flag for flag in (
        "multi_modal_babyai11_agent",
        "multi_headed_babyai11_agent",
        "babyai11_agent",
        "multi_headed_agent",
        "ride_ref_agent",
    )
    if getattr(args, flag)
]
if len(_selected_agents) > 1:
    raise ValueError(f"Agent architecture flags are mutually exclusive, got: {_selected_agents}")

if args.multi_modal_babyai11_agent:
    acmodel = MultiModalBaby11ACModel(
        obs_space=obs_space,
        action_space=envs[0].action_space,
        arch=args.arch,
        use_text=args.text,
        use_dialogue=args.dialogue,
        use_current_dialogue_only=args.current_dialogue_only,
        use_memory=args.mem,
        lang_model=args.bAI_lang_model,
        memory_dim=args.memory_dim,
        num_films=args.num_films
    )
else:
    # RefACModel and ACModel are not used with current-dialogue-only input; fail before building a model.
    if args.current_dialogue_only:
        raise NotImplementedError("current dialogue only")

    if args.ride_ref_agent:
        assert args.mem
        assert not args.text
        assert not args.dialogue
        acmodel = RefACModel(
            obs_space=obs_space,
            action_space=envs[0].action_space,
            use_memory=args.mem,
            use_text=args.text,
            use_dialogue=args.dialogue,
            input_size=obs_space['image'][-1],
        )
    else:
        if _selected_agents:
            warnings.warn(f"--{_selected_agents[0]} has no dedicated model here; falling back to ACModel.")
        acmodel = ACModel(
            obs_space=obs_space,
            action_space=envs[0].action_space,
            use_memory=args.mem,
            use_text=args.text,
            use_dialogue=args.dialogue,
            input_size=obs_space['image'][-1],
        )

acmodel.to(device)
txt_logger.info("Model loaded\n")
txt_logger.info("{}\n".format(acmodel))
# Load algo
assert args.algo == "ppo"

# PPOAlgo keyword arguments forwarded verbatim from the identically-named CLI args.
_PPO_SAME_NAME_ARGS = (
    "discount", "lr", "gae_lambda", "entropy_coef", "value_loss_coef", "max_grad_norm",
    "recurrence", "clip_eps", "epochs", "batch_size",
    # exploration bonus
    "exploration_bonus", "exploration_bonus_tanh", "exploration_bonus_type",
    "exploration_bonus_params", "expert_exploration_bonus", "episodic_exploration_bonus",
    "clipped_rewards",
    # for rnd, ride, and social influence
    "intrinsic_reward_coef",
    # for rnd and ride
    "intrinsic_reward_epochs", "intrinsic_reward_learning_rate", "intrinsic_reward_momentum",
    "intrinsic_reward_epsilon", "intrinsic_reward_alpha", "intrinsic_reward_max_grad_norm",
    # for rnd and social influence
    "intrinsic_reward_loss_coef",
    # for ride
    "intrinsic_reward_forward_loss_coef", "intrinsic_reward_inverse_loss_coef",
    # for social influence
    "balance_moa_training", "moa_memory_dim",
    # lr schedule and acl-phase resets
    "lr_schedule_end_frames", "reset_rnd_ride_at_phase",
)
ppo_kwargs = {name: getattr(args, name) for name in _PPO_SAME_NAME_ARGS}
# CLI args whose PPOAlgo keyword has a different name.
ppo_kwargs.update(
    num_frames_per_proc=args.frames_per_proc,
    adam_eps=args.optim_eps,
    end_lr=args.lr_end,
)

algo = torch_ac.PPOAlgo(
    envs=envs,
    acmodel=acmodel,
    device=device,
    preprocess_obss=preprocess_obss,
    **ppo_kwargs,
)

# Restore model + algo state from the loaded status when resuming or finetuning.
if args.continue_train or args.finetune_train:
    algo.load_status_dict(status)
    loaded_from = status_continue_path if args.continue_train else status_finetune_path
    txt_logger.info(f"Model + Algo loaded from {loaded_from} \n")
# Set and load test environment
# Test sets whose env names are listed inline.
_INLINE_TEST_SETS = {
    "SocialAITestSet": [
        "SocialAI-TestLanguageColorBoxesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackBoxesInformationSeekingEnv-v1",
        "SocialAI-TestPointingBoxesInformationSeekingEnv-v1",
        "SocialAI-TestEmulationBoxesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestPointingSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestEmulationSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorMarbleInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackMarbleInformationSeekingEnv-v1",
        "SocialAI-TestPointingMarbleInformationSeekingEnv-v1",
        "SocialAI-TestEmulationMarbleInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestPointingGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestEmulationGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorLeversInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackLeversInformationSeekingEnv-v1",
        "SocialAI-TestPointingLeversInformationSeekingEnv-v1",
        "SocialAI-TestEmulationLeversInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorDoorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackDoorsInformationSeekingEnv-v1",
        "SocialAI-TestPointingDoorsInformationSeekingEnv-v1",
        "SocialAI-TestEmulationDoorsInformationSeekingEnv-v1",
        "SocialAI-TestLeverDoorCollaborationEnv-v1",
        "SocialAI-TestMarblePushCollaborationEnv-v1",
        "SocialAI-TestMarblePassCollaborationEnv-v1",
        "SocialAI-TestAppleStealingPerspectiveTakingEnv-v1",
    ],
    "SocialAIGSTestSet": [
        "SocialAI-GridSearchParamEnv-v1",
        "SocialAI-GridSearchPointingParamEnv-v1",
        "SocialAI-GridSearchLangColorParamEnv-v1",
        "SocialAI-GridSearchLangFeedbackParamEnv-v1",
    ],
    "SocialAICuesGSTestSet": [
        "SocialAI-CuesGridSearchParamEnv-v1",
        "SocialAI-CuesGridSearchPointingParamEnv-v1",
        "SocialAI-CuesGridSearchLangColorParamEnv-v1",
        "SocialAI-CuesGridSearchLangFeedbackParamEnv-v1",
    ],
    "BoxesPointingTestSet": [
        "SocialAI-TestPointingBoxesInformationSeekingParamEnv-v1",
    ],
}

# Test sets defined in gym_minigrid.social_ai_envs, looked up by attribute name only when selected.
_MODULE_TEST_SET_ATTRS = {
    "PointingTestSet": "pointing_test_set",
    "LangColorTestSet": "language_color_test_set",
    "LangFeedbackTestSet": "language_feedback_test_set",
    # joint attention
    "JAPointingTestSet": "ja_pointing_test_set",
    "JALangColorTestSet": "ja_language_color_test_set",
    "JALangFeedbackTestSet": "ja_language_feedback_test_set",
    # emulation
    "DistrEmulationTestSet": "distr_emulation_test_set",
    "NoDistrEmulationTestSet": "no_distr_emulation_test_set",
    # formats
    "NFormatsTestSet": "N_formats_test_set",
    "EFormatsTestSet": "E_formats_test_set",
    "AFormatsTestSet": "A_formats_test_set",
    "AEFormatsTestSet": "AE_formats_test_set",
    "RoleReversalTestSet": "role_reversal_test_set",
}

if not args.test_set_name:
    # No test set name: test on the training env itself.
    test_env_names = [args.env]
elif args.test_set_name in _INLINE_TEST_SETS:
    test_env_names = _INLINE_TEST_SETS[args.test_set_name]
elif args.test_set_name in _MODULE_TEST_SET_ATTRS:
    test_env_names = getattr(gym_minigrid.social_ai_envs, _MODULE_TEST_SET_ATTRS[args.test_set_name])
else:
    raise ValueError("Undefined test set name.")

# One Tester per test env; none are built when periodic testing is disabled.
testers = []
if args.test_interval > 0:
    testers = [
        utils.Tester(
            {"env_key": env_name, "seed": args.test_seed, "env_args": test_env_args},
            args.test_seed, args.test_episodes, model_dir, acmodel, preprocess_obss, device,
        )
        for env_name in test_env_names
    ]

if args.continue_train:
    for tester in testers:
        tester.load()
# Save config (the "curriculum" entry of the env args is stored as its repr())
env_args_ = {key: (repr(value) if key == "curriculum" else value) for key, value in env_args.items()}
test_env_args_ = {key: (repr(value) if key == "curriculum" else value) for key, value in test_env_args.items()}

config_dict = dict(
    seed=args.seed,
    env=args.env,
    env_args=env_args_,
    test_seed=args.test_seed,
    test_env=args.test_set_name,
    test_env_args=test_env_args_,
)
# Algo and model contribute their own config entries.
config_dict.update(algo.get_config_dict())
config_dict.update(acmodel.get_config_dict())

with open(Path(model_dir) / "config.json", "w") as fp:
    json.dump(config_dict, fp)
# Train model
# Counters start from the loaded status (0 for a fresh run or for a finetune).
num_frames = status["num_frames"]
update = status["update"]
start_time = time.time()
# The CSV header row is written once: on a fresh/finetune run, or when continuing from frame 0.
log_add_headers = num_frames == 0 or not args.continue_train
# Every long_term_save_interval frames, an extra save passes num_frames to
# utils.save_status / utils.save_model (presumably kept as a separate checkpoint — see "Save status").
long_term_save_interval = 5000000
if args.continue_train:
    # set next long term save interval
    # (the next multiple of long_term_save_interval strictly above the resumed frame count)
    next_long_term_save = (1 + num_frames // long_term_save_interval) * long_term_save_interval
else:
    next_long_term_save = 0 # for long term logging; 0 makes the first update trigger a long-term save

while num_frames < args.frames:
    # Update model parameters
    update_start_time = time.time()
    # print("current_seed_pre_train:", np.random.get_state()[1][0])
    exps, logs1 = algo.collect_experiences()
    logs2 = algo.update_parameters(exps)
    # On key collisions the update_parameters() logs take precedence.
    logs = {**logs1, **logs2}
    update_end_time = time.time()

    num_frames += logs["num_frames"]
    update += 1

    NPC_intro = np.mean(logs["NPC_introduced_to"])

    # Print logs
    # NOTE: args.log_interval must be > 0; the modulo below raises ZeroDivisionError for 0.
    if update % args.log_interval == 0:
        # FPS of this single update (collection + optimisation), not a running average.
        fps = logs["num_frames"]/(update_end_time - update_start_time)
        duration = int(time.time() - start_time)

        return_per_episode = utils.synthesize(logs["return_per_episode"])
        extrinsic_return_per_episode = utils.synthesize(logs["extrinsic_return_per_episode"])
        exploration_bonus_per_episode = utils.synthesize(logs["exploration_bonus_per_episode"])
        success_rate = utils.synthesize(logs["success_rate_per_episode"])
        curriculum_max_success_rate = utils.synthesize(logs["curriculum_max_mean_perf_per_episode"])
        curriculum_param = utils.synthesize(logs["curriculum_param_per_episode"])

        rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
        num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

        # intrinsic_reward_perf = utils.synthesize(logs["intr_reward_perf"])
        # intrinsic_reward_perf_ = utils.synthesize(logs["intr_reward_perf_"])
        intrinsic_reward_perf = logs["intr_reward_perf"]
        intrinsic_reward_perf_ = logs["intr_reward_perf_"]

        lr_ = logs["lr"]

        # Wall-clock time packed as an integer in ddmmYYYYHHMMSS order.
        time_now = int(datetime.datetime.now().strftime("%d%m%Y%H%M%S"))

        # `header` and `data` are written to the CSV and must stay aligned column by column.
        # `data_to_print` must match, in order and count, the placeholders of the format string below.
        # Each synthesize() result is assumed to hold 4 values (mean, std, min, max, per the "μσmM"
        # labels) — TODO confirm against utils.synthesize.
        header = ["update", "frames", "FPS", "duration", "time"]
        data = [update, num_frames, fps, duration, time_now]
        data_to_print = [update, num_frames, fps, duration, time_now]

        header += ["success_rate_" + key for key in success_rate.keys()]
        data += success_rate.values()
        data_to_print += success_rate.values()

        # Curriculum columns always go to the CSV, but are printed only when acl is on
        # (the matching format-string segments are also conditional on args.acl).
        header += ["curriculum_max_success_rate_" + key for key in curriculum_max_success_rate.keys()]
        data += curriculum_max_success_rate.values()
        if args.acl:
            data_to_print += curriculum_max_success_rate.values()

        header += ["curriculum_param_" + key for key in curriculum_param.keys()]
        data += curriculum_param.values()
        if args.acl:
            data_to_print += curriculum_param.values()

        header += ["extrinsic_return_" + key for key in extrinsic_return_per_episode.keys()]
        data += extrinsic_return_per_episode.values()
        data_to_print += extrinsic_return_per_episode.values()

        # turn on
        header += ["exploration_bonus_" + key for key in exploration_bonus_per_episode.keys()]
        data += exploration_bonus_per_episode.values()
        data_to_print += exploration_bonus_per_episode.values()

        header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
        data += rreturn_per_episode.values()
        data_to_print += rreturn_per_episode.values()

        # Intrinsic reward performance is logged to the CSV only (not printed).
        header += ["intrinsic_reward_perf_"]
        data += [intrinsic_reward_perf]
        # data_to_print += [intrinsic_reward_perf]

        header += ["intrinsic_reward_perf2_"]
        data += [intrinsic_reward_perf_]
        # data_to_print += [intrinsic_reward_perf_]

        # header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
        # data += num_frames_per_episode.values()

        header += ["NPC_intro"]
        data += [NPC_intro]
        data_to_print += [NPC_intro]

        header += ["lr"]
        data += [lr_]
        data_to_print += [lr_]

        # header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
        # data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]

        # curr_history_len = len(algo.env.envs[0].curriculum.performance_history)
        # header += ["curr_history_len"]
        # data += [curr_history_len]

        txt_logger.info("".join([
            "U {} | F {:06} | FPS {:04.0f} | D {} | T {} ",
            "| SR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} ",
            "| CurMaxSR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} " if args.acl else "",
            "| CurPhase:μσmM {:.2f} {:.1f} {:.1f} {:.1f} " if args.acl else "",
            "| ExR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} ",
            "| InR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} ",
            "| rR:μσmM {:.6f} {:.1f} {:.1f} {:.1f} ",
            # "| irp:μσmM {:.6f} {:.2f} {:.2f} {:.2f} ",
            # "| irp_:μσmM {:.6f} {:.2f} {:.2f} {:.2f} ",
            # "| F:μσmM {:.1f} {:.1f} {} {} ",
            "| NPC_intro: {:.3f}",
            "| lr: {:.5f}",
            # "| cur_his_len: {:.5f}" if args.acl else "",
            # "| H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
        ]).format(*data_to_print))

        # Total returns go to the CSV only; they are appended after the printed columns.
        header += ["return_" + key for key in return_per_episode.keys()]
        data += return_per_episode.values()

        if log_add_headers:
            csv_logger.writerow(header)
            log_add_headers = False
        csv_logger.writerow(data)
        csv_file.flush()

    # Save status
    long_term_save = False
    if num_frames >= next_long_term_save:
        next_long_term_save += long_term_save_interval
        long_term_save = True

    if (args.save_interval > 0 and update % args.save_interval == 0) or long_term_save:
        # continuing train works best when save_interval == log_interval, the csv is cleaner wo redundancies
        status = {"num_frames": num_frames, "update": update}
        algo_status = algo.get_status_dict()
        status = {**status, **algo_status}

        if hasattr(preprocess_obss, "vocab"):
            status["vocab"] = preprocess_obss.vocab.vocab
        status["env_args"] = env_args

        if long_term_save:
            # Long-term saves pass the frame count to the save helpers; regular saves do not.
            utils.save_status(status, model_dir, num_frames=num_frames)
            utils.save_model(acmodel, model_dir, num_frames=num_frames)
            txt_logger.info("Status and Model saved for {} frames".format(num_frames))
        else:
            utils.save_status(status, model_dir)
            utils.save_model(acmodel, model_dir)
            txt_logger.info("Status and Model saved")

    # Periodic testing; also runs right after the very first update.
    if args.test_interval > 0 and (update % args.test_interval == 0 or update == 1):
        txt_logger.info(f"Testing at update {update}.")
        test_success_rates = []
        for tester in testers:
            mean_success_rate, mean_rewards = tester.test_agent(num_frames)
            test_success_rates.append(mean_success_rate)
            txt_logger.info(f"\t{tester.envs[0].spec.id} -> {mean_success_rate} (SR)")
            tester.dump()

        # Guarded so that an empty tester list does not average an empty array.
        if len(testers):
            txt_logger.info(f"Test set SR: {np.array(test_success_rates).mean()}")
# save at the end: same status layout as the periodic saves in the training loop
final_status = dict(num_frames=num_frames, update=update)
final_status.update(algo.get_status_dict())
if hasattr(preprocess_obss, "vocab"):
    final_status["vocab"] = preprocess_obss.vocab.vocab
final_status["env_args"] = env_args
status = final_status

utils.save_status(status, model_dir)
utils.save_model(acmodel, model_dir)
txt_logger.info("Status and Model saved at the end")