import numpy
import torch
import torch.nn.functional as F

from torch_ac.intrinsic_reward_models import compute_forward_dynamics_loss, compute_inverse_dynamics_loss
from sklearn.metrics import f1_score
from torch_ac.algos.base import BaseAlgo


def compute_balance_mask(target, n_classes):
    if target.float().var() == 0:
        # all the same class, don't train at all
        return torch.zeros_like(target).detach()

    # compute the balance mask
    per_class_n = torch.bincount(target, minlength=n_classes)

    # number of times the least common class (that appeared) appeared
    n_for_each_class = per_class_n[torch.nonzero(per_class_n)].min()

    # undersample other classes
    per_class_n = n_for_each_class

    # sample each class that many times
    balanced_indexes_ = []
    for c in range(n_classes):
        c_indexes = torch.where(target == c)[0]

        if len(c_indexes) == 0:
            continue

        # c_sampled_indexes = c_indexes[torch.randint(len(c_indexes), (per_class_n,))]
        c_sampled_indexes = c_indexes[torch.randperm(len(c_indexes))[:per_class_n]]
        balanced_indexes_.append(c_sampled_indexes)

    balanced_indexes = torch.concat(balanced_indexes_)

    balance_mask = torch.zeros_like(target)
    balance_mask[balanced_indexes] = 1.0

    return balance_mask.detach()
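

# Minimal sanity-check sketch for compute_balance_mask (illustrative only; this helper
# is not called anywhere in the training code). Every class that appears in `target`
# keeps as many unmasked positions as the rarest class that appeared.
def _compute_balance_mask_example():
    target = torch.tensor([2, 2, 2, 0, 1, 1])
    mask = compute_balance_mask(target, n_classes=3)
    # class 0 appears once, so one position is kept per present class -> 3 in total
    assert mask.sum().item() == 3
    return mask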


class PPOAlgo(BaseAlgo):
    """The Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))."""

    def __init__(self, envs, acmodel, device=None, num_frames_per_proc=None, discount=0.99, lr=0.001,
                 gae_lambda=0.95, entropy_coef=0.01, value_loss_coef=0.5, max_grad_norm=0.5, recurrence=4,
                 adam_eps=1e-5, clip_eps=0.2, epochs=4, batch_size=256, preprocess_obss=None,
                 reshape_reward=None,
                 exploration_bonus=False,
                 exploration_bonus_params=None,
                 expert_exploration_bonus=False,
                 episodic_exploration_bonus=True,
                 exploration_bonus_type="lang",
                 exploration_bonus_tanh=None,
                 clipped_rewards=False,
                 intrinsic_reward_epochs=0,
                 # default is set to fit RND
                 intrinsic_reward_coef=0.1,
                 intrinsic_reward_learning_rate=0.0001,
                 intrinsic_reward_momentum=0,
                 intrinsic_reward_epsilon=0.01,
                 intrinsic_reward_alpha=0.99,
                 intrinsic_reward_max_grad_norm=40,
                 intrinsic_reward_loss_coef=0.1,
                 intrinsic_reward_forward_loss_coef=10,
                 intrinsic_reward_inverse_loss_coef=0.1,
                 reset_rnd_ride_at_phase=False,
                 balance_moa_training=False,
                 moa_memory_dim=128,
                 schedule_lr=False,
                 lr_schedule_end_frames=0,
                 end_lr=0.0,
                 ):
        num_frames_per_proc = num_frames_per_proc or 128

        # save config
        self.config = locals()

        super().__init__(
            envs=envs,
            acmodel=acmodel,
            device=device,
            num_frames_per_proc=num_frames_per_proc,
            discount=discount,
            lr=lr,
            gae_lambda=gae_lambda,
            entropy_coef=entropy_coef,
            value_loss_coef=value_loss_coef,
            max_grad_norm=max_grad_norm,
            recurrence=recurrence,
            preprocess_obss=preprocess_obss,
            reshape_reward=reshape_reward,
            exploration_bonus=exploration_bonus,
            expert_exploration_bonus=expert_exploration_bonus,
            episodic_exploration_bonus=episodic_exploration_bonus,
            exploration_bonus_params=exploration_bonus_params,
            exploration_bonus_tanh=exploration_bonus_tanh,
            exploration_bonus_type=exploration_bonus_type,
            clipped_rewards=clipped_rewards,
            intrinsic_reward_loss_coef=intrinsic_reward_loss_coef,
            intrinsic_reward_coef=intrinsic_reward_coef,
            intrinsic_reward_learning_rate=intrinsic_reward_learning_rate,
            intrinsic_reward_momentum=intrinsic_reward_momentum,
            intrinsic_reward_epsilon=intrinsic_reward_epsilon,
            intrinsic_reward_alpha=intrinsic_reward_alpha,
            intrinsic_reward_max_grad_norm=intrinsic_reward_max_grad_norm,
            intrinsic_reward_forward_loss_coef=intrinsic_reward_forward_loss_coef,
            intrinsic_reward_inverse_loss_coef=intrinsic_reward_inverse_loss_coef,
            balance_moa_training=balance_moa_training,
            moa_memory_dim=moa_memory_dim,
            reset_rnd_ride_at_phase=reset_rnd_ride_at_phase,
        )

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.intrinsic_reward_epochs = intrinsic_reward_epochs
        self.batch_size = batch_size

        assert self.batch_size % self.recurrence == 0

        if self.exploration_bonus and "soc_inf" in self.exploration_bonus_type:
            adam_params = list(dict.fromkeys(list(self.acmodel.parameters()) + list(self.moa_net.parameters())))
            self.optimizer = torch.optim.Adam(adam_params, lr, eps=adam_eps)
        else:
            self.optimizer = torch.optim.Adam(self.acmodel.parameters(), lr, eps=adam_eps)

        self.schedule_lr = schedule_lr
        self.lr_schedule_end_frames = lr_schedule_end_frames

        assert end_lr <= lr

        def lr_lambda(step):
            if self.lr_schedule_end_frames == 0:
                # no schedule
                return 1

            end_factor = end_lr / lr
            final_diminished_factor = 1 - end_factor

            n_frames = self.step_to_n_frames(step)
            return 1 - (min(n_frames, self.lr_schedule_end_frames) / self.lr_schedule_end_frames) * final_diminished_factor

        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
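
        # The lambda above decays the learning rate linearly from `lr` to `end_lr` over
        # `lr_schedule_end_frames` frames and then holds it constant. For example
        # (illustrative numbers), with lr=1e-3, end_lr=0 and lr_schedule_end_frames=10_000_000,
        # the multiplier is 1.0 at frame 0, 0.5 at 5M frames, and 0.0 from 10M frames onward.
        # With lr_schedule_end_frames=0 the multiplier stays at 1 (no schedule).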

        self.batch_num = 0

    def load_status_dict(self, status):
        super().load_status_dict(status)

        if "optimizer_state" in status:
            self.optimizer.load_state_dict(status["optimizer_state"])

        if "lr_scheduler_state" in status:
            self.lr_scheduler.load_state_dict(status["lr_scheduler_state"])

    def get_status_dict(self):
        status_dict = super().get_status_dict()
        status_dict["optimizer_state"] = self.optimizer.state_dict()
        status_dict["lr_scheduler_state"] = self.lr_scheduler.state_dict()
        return status_dict

    def update_parameters(self, exps):
        # Collect experiences
        self.acmodel.train()
        self.update_epoch += 1

        intr_rew_perf = torch.tensor(0.0)
        intr_rew_perf_ = 0.0

        social_influence = False

        if self.exploration_bonus:
            if "rnd" in self.exploration_bonus_type:
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)
                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc, 1,
                ).transpose(0, 1)

                self.random_target_network.train()
                self.predictor_network.train()

                random_embedding = self.random_target_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)
                predicted_embedding = self.predictor_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)

                intr_rew_loss = self.intrinsic_reward_loss_coef * compute_forward_dynamics_loss(
                    mask * predicted_embedding, mask * random_embedding.detach())

                # update the intr rew models
                self.intrinsic_reward_optimizer.zero_grad()
                intr_rew_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.predictor_network.parameters(), self.intrinsic_reward_max_grad_norm)
                self.intrinsic_reward_optimizer.step()

                intr_rew_perf = intr_rew_loss

            elif "ride" in self.exploration_bonus_type:
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)
                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                # we only take the first (primitive) action
                action = exps.action[:, 0].reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                _mask = mask[:-1]
                _obs = imgs[:-1]
                _actions = action[:-1]
                _next_obs = imgs[1:]

                self.state_embedding_model.train()
                self.forward_dynamics_model.train()
                self.inverse_dynamics_model.train()

                state_emb = self.state_embedding_model(_obs.to(device=self.device))
                next_state_emb = self.state_embedding_model(_next_obs.to(device=self.device))

                pred_next_state_emb = self.forward_dynamics_model(state_emb, _actions.to(device=self.device))
                pred_actions = self.inverse_dynamics_model(state_emb, next_state_emb)

                forward_dynamics_loss = self.intrinsic_reward_forward_loss_coef * \
                    compute_forward_dynamics_loss(_mask[:, :, None] * pred_next_state_emb, _mask[:, :, None] * next_state_emb)
                inverse_dynamics_loss = self.intrinsic_reward_inverse_loss_coef * \
                    compute_inverse_dynamics_loss(_mask[:, :, None] * pred_actions, _mask * _actions)

                # update the intr rew models
                self.state_embedding_optimizer.zero_grad()
                self.forward_dynamics_optimizer.zero_grad()
                self.inverse_dynamics_optimizer.zero_grad()

                intr_rew_loss = forward_dynamics_loss + inverse_dynamics_loss
                intr_rew_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.state_embedding_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.forward_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.inverse_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)

                self.state_embedding_optimizer.step()
                self.forward_dynamics_optimizer.step()
                self.inverse_dynamics_optimizer.step()

                intr_rew_perf = intr_rew_loss

            elif "soc_inf" in self.exploration_bonus_type:
                # trained together with the policy
                social_influence = True
                self.moa_net.train()

                if self.intrinsic_reward_epochs > 0:
                    raise DeprecationWarning(
                        f"Moa must be trained with the agent. intrinsic_reward_epochs must be 0 but is {self.intrinsic_reward_epochs}")

        for _ in range(self.epochs):
            # Initialize log values

            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []
            log_lrs = []

            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # intr reward metrics
                batch_intr_rew_loss = 0
                batch_intr_rew_acc = 0
                batch_intr_rew_f1 = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                if social_influence:
                    # Initialize moa memory
                    moa_memory = exps.moa_memory[inds]
                    prev_npc_prim_action = None

                for i in range(self.recurrence):
                    # Create a sub-batch of experience

                    sb = exps[inds + i]

                    # Compute loss

                    if self.acmodel.recurrent:
                        dist, value, memory, policy_embeddings = self.acmodel(sb.obs, memory * sb.mask, return_embeddings=True)
                    else:
                        dist, value, policy_embeddings = self.acmodel(sb.obs, return_embeddings=True)

                    losses = []
                    for head_i, d in enumerate(dist):
                        action_masks = self.acmodel.calculate_action_gradient_masks(sb.action).type(sb.log_prob.type())

                        entropy = (d.entropy() * action_masks[:, head_i]).mean()

                        ratio = torch.exp(d.log_prob(sb.action[:, head_i]) - sb.log_prob[:, head_i])
                        surr1 = ratio * sb.advantage
                        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * sb.advantage
                        policy_loss = (
                            -torch.min(surr1, surr2) * action_masks[:, head_i]
                        ).mean()

                        value_clipped = sb.value + torch.clamp(value - sb.value, -self.clip_eps, self.clip_eps)
                        surr1 = (value - sb.returnn).pow(2)
                        surr2 = (value_clipped - sb.returnn).pow(2)
                        value_loss = (
                            torch.max(surr1, surr2) * action_masks[:, head_i]
                        ).mean()

                        head_loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss
                        losses.append(head_loss)
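
                    # MOA ("model of other agents") auxiliary loss for the social-influence
                    # bonus: the moa_net predicts the NPC's next primitive action (and, if
                    # utterance_moa_net is set, its utterance) from the policy embeddings,
                    # the agent's action and the NPC's previous action. Its cross-entropy
                    # loss is appended to `losses`, so it is minimized jointly with the PPO
                    # objective by the shared optimizer. The first recurrence step only
                    # caches the NPC action to serve as the "previous" action later on.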
                    if social_influence:
                        # moa loss
                        imgs = sb.obs.image
                        mask = sb.mask.to(torch.int64)

                        # we only take the first (primitive) action
                        agent_action = sb.action.to(torch.int64)

                        infos = numpy.array(sb.infos)
                        npc_prim_action = torch.tensor(
                            numpy.array([self.fn_name_to_npc_prim_act[info["NPC_prim_action"]] for info in infos]))
                        npc_utt_action = torch.tensor(
                            numpy.array([self.npc_utterance_to_id[info["NPC_utterance"]] for info in infos]))

                        assert infos.shape == imgs.shape[:1] == agent_action.shape[:1]  # [bs]

                        if i == 0:
                            prev_npc_prim_action = npc_prim_action
                            prev_npc_utt_action = npc_utt_action

                        else:
                            # compute loss and train moa net
                            if self.utterance_moa_net:
                                # transform to long logits
                                target = npc_prim_action.detach().to(self.device) * self.num_npc_utterance_actions + npc_utt_action.detach().to(self.device)
                            else:
                                target = npc_prim_action.detach().to(self.device)

                            if self.balance_moa_training:
                                balance_mask = compute_balance_mask(target, n_classes=self.num_npc_all_actions)
                            else:
                                balance_mask = torch.ones_like(target)

                            moa_predictions_logs, moa_memory = self.moa_net(
                                embeddings=policy_embeddings,
                                npc_previous_prim_actions=prev_npc_prim_action.detach().to(self.device),
                                npc_previous_utterance_actions=prev_npc_utt_action.detach().to(self.device) if self.utterance_moa_net else None,
                                agent_actions=agent_action.detach().to(self.device),
                                memory=moa_memory * mask,
                            )
                            # moa_predictions_logs = moa_predictions_logs.reshape([*prev_shape, -1])  # is this needed

                            # loss
                            ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
                            intr_rew_loss = (
                                balance_mask * mask * ce_loss(
                                    input=moa_predictions_logs,
                                    target=target,
                                )).mean() * self.intrinsic_reward_loss_coef

                            preds = moa_predictions_logs.detach().argmax(dim=-1)

                            intr_rew_f1 = f1_score(
                                y_pred=preds.detach().cpu().numpy(),
                                y_true=target.detach().cpu().numpy(),
                                average="macro"
                            )

                            intr_rew_acc = (
                                torch.argmax(moa_predictions_logs.to(self.device), dim=-1) == target
                            ).to(float).mean()

                            batch_intr_rew_loss += intr_rew_loss
                            batch_intr_rew_acc += intr_rew_acc.detach().cpu().numpy().mean()
                            batch_intr_rew_f1 += intr_rew_f1

                            losses.append(intr_rew_loss)  # trained with the policy optimizer

                    loss = torch.stack(losses).mean()

                    # Update batch values

                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Update memories for next epoch
                    # assert self.acmodel.recurrent == (self.recurrence > 1)
                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                    if social_influence and i < self.recurrence - 1:
                        exps.moa_memory[inds + i + 1] = moa_memory.detach()

                # Update batch values

                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic

                self.optimizer.zero_grad()
                batch_loss.backward()
                grad_norm = sum(p.grad.data.norm(2).item() ** 2 for p in self.acmodel.parameters()) ** 0.5
                torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.lr_scheduler.step()

                if social_influence:
                    # recurrence-1 because we skipped the first step
                    batch_intr_rew_loss /= self.recurrence - 1
                    batch_intr_rew_acc /= self.recurrence - 1
                    batch_intr_rew_f1 /= self.recurrence - 1

                    intr_rew_perf_ = batch_intr_rew_f1
                    intr_rew_perf = batch_intr_rew_acc

                # Update log values

                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm)
                log_lrs.append(self.optimizer.param_groups[0]['lr'])

        # Log some values

        logs = {
            "entropy": numpy.mean(log_entropies),
            "value": numpy.mean(log_values),
            "policy_loss": numpy.mean(log_policy_losses),
            "value_loss": numpy.mean(log_value_losses),
            "grad_norm": numpy.mean(log_grad_norms),
            "intr_reward_perf": intr_rew_perf,
            "intr_reward_perf_": intr_rew_perf_,
            "lr": numpy.mean(log_lrs),
        }

        return logs
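
    # Illustrative index pattern for _get_batches_starting_indexes (assuming
    # num_frames=16, num_frames_per_proc=8, recurrence=4): even calls draw permuted
    # starts from {0, 4, 8, 12}; odd calls drop the starts whose shifted window would
    # cross a process boundary ({4, 12}) and shift the rest by recurrence//2, giving {2, 10}.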
"value_loss": numpy.mean(log_value_losses), "grad_norm": numpy.mean(log_grad_norms), "intr_reward_perf": intr_rew_perf, "intr_reward_perf_": intr_rew_perf_, "lr": numpy.mean(log_lrs), } return logs def _get_batches_starting_indexes(self): """Gives, for each batch, the indexes of the observations given to the model and the experiences used to compute the loss at first. First, the indexes are the integers from 0 to `self.num_frames` with a step of `self.recurrence`, shifted by `self.recurrence//2` one time in two for having more diverse batches. Then, the indexes are splited into the different batches. Returns ------- batches_starting_indexes : list of list of int the indexes of the experiences to be used at first for each batch """ indexes = numpy.arange(0, self.num_frames, self.recurrence) indexes = numpy.random.permutation(indexes) # Shift starting indexes by self.recurrence//2 half the time if self.batch_num % 2 == 1: indexes = indexes[(indexes + self.recurrence) % self.num_frames_per_proc != 0] indexes += self.recurrence // 2 self.batch_num += 1 num_indexes = self.batch_size // self.recurrence batches_starting_indexes = [indexes[i:i+num_indexes] for i in range(0, len(indexes), num_indexes)] return batches_starting_indexes def get_config_dict(self): del self.config['envs'] del self.config['acmodel'] del self.config['__class__'] del self.config['self'] del self.config['preprocess_obss'] del self.config['device'] return self.config