import numpy
import torch
import torch.nn.functional as F

from torch_ac.algos.base import BaseAlgo


class A2CAlgo(BaseAlgo):
    """The Advantage Actor-Critic algorithm."""

    def __init__(self, envs, acmodel, device=None, num_frames_per_proc=None, discount=0.99, lr=0.01,
                 gae_lambda=0.95, entropy_coef=0.01, value_loss_coef=0.5, max_grad_norm=0.5, recurrence=4,
                 rmsprop_alpha=0.99, rmsprop_eps=1e-8, preprocess_obss=None, reshape_reward=None):
        num_frames_per_proc = num_frames_per_proc or 8

        super().__init__(envs, acmodel, device, num_frames_per_proc, discount, lr, gae_lambda, entropy_coef,
                         value_loss_coef, max_grad_norm, recurrence, preprocess_obss, reshape_reward)

        self.optimizer = torch.optim.RMSprop(self.acmodel.parameters(), lr,
                                             alpha=rmsprop_alpha, eps=rmsprop_eps)

        raise NotImplementedError("This needs to be refactored to work with mm actions")

    def update_parameters(self, exps):
        # Compute starting indexes

        inds = self._get_starting_indexes()

        # Initialize update values

        update_entropy = 0
        update_value = 0
        update_policy_loss = 0
        update_value_loss = 0
        update_loss = 0

        # Initialize memory

        if self.acmodel.recurrent:
            memory = exps.memory[inds]

        for i in range(self.recurrence):
            # Create a sub-batch of experience

            sb = exps[inds + i]

            # Compute loss

            if self.acmodel.recurrent:
                # The mask zeroes the memory at the start of new episodes so that
                # hidden state does not leak across episode boundaries.
                dist, value, memory = self.acmodel(sb.obs, memory * sb.mask)
            else:
                dist, value = self.acmodel(sb.obs)

            entropy = dist.entropy().mean()

            policy_loss = -(dist.log_prob(sb.action) * sb.advantage).mean()

            value_loss = (value - sb.returnn).pow(2).mean()

            loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

            # Update batch values

            update_entropy += entropy.item()
            update_value += value.mean().item()
            update_policy_loss += policy_loss.item()
            update_value_loss += value_loss.item()
            update_loss += loss

        # Average the update values over the recurrence steps

        update_entropy /= self.recurrence
        update_value /= self.recurrence
        update_policy_loss /= self.recurrence
        update_value_loss /= self.recurrence
        update_loss /= self.recurrence

        # Update actor-critic

        self.optimizer.zero_grad()
        update_loss.backward()
        # Gradient norm is computed before clipping, for logging purposes.
        update_grad_norm = sum(p.grad.data.norm(2) ** 2 for p in self.acmodel.parameters()) ** 0.5
        torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Log some values

        logs = {
            "entropy": update_entropy,
            "value": update_value,
            "policy_loss": update_policy_loss,
            "value_loss": update_value_loss,
            "grad_norm": update_grad_norm
        }

        return logs

    def _get_starting_indexes(self):
        """Gives the indexes of the observations given to the model at the first
        recurrence step, i.e. the experiences used to compute the first loss term.

        The indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`. If the model is not recurrent, they are all the
        integers from 0 to `self.num_frames`.

        Returns
        -------
        starting_indexes : numpy.ndarray of int
            the indexes of the experiences to be used at first
        """

        starting_indexes = numpy.arange(0, self.num_frames, self.recurrence)
        return starting_indexes
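

# A minimal illustration of the sub-batch indexing above, assuming purely
# illustrative values num_frames = 16 and recurrence = 4 (neither is fixed by
# this module):
#
#     inds = numpy.arange(0, 16, 4)   # -> array([ 0,  4,  8, 12])
#     inds + 1                        # -> array([ 1,  5,  9, 13])
#     inds + 3                        # -> array([ 3,  7, 11, 15])
#
# Each recurrence step i in `update_parameters` advances every sub-batch by one
# frame, so gradients flow through at most `recurrence` consecutive frames of
# each chunk (truncated backpropagation through time).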