grg's picture
Cleaned old git history
be5548b
raw
history blame
3.75 kB
import numpy
import torch
import torch.nn.functional as F
from torch_ac.algos.base import BaseAlgo
class A2CAlgo(BaseAlgo):
    """The Advantage Actor-Critic algorithm.

    NOTE: ``__init__`` currently ends with an unconditional
    ``NotImplementedError``, so instances cannot be constructed until the
    class is refactored to work with "mm actions" (see the error message).
    """

    def __init__(self, envs, acmodel, device=None, num_frames_per_proc=None, discount=0.99, lr=0.01, gae_lambda=0.95,
                 entropy_coef=0.01, value_loss_coef=0.5, max_grad_norm=0.5, recurrence=4,
                 rmsprop_alpha=0.99, rmsprop_eps=1e-8, preprocess_obss=None, reshape_reward=None):
        """Set up the base algorithm and an RMSprop optimizer, then refuse to run.

        Parameters
        ----------
        envs, acmodel, device, discount, gae_lambda, entropy_coef,
        value_loss_coef, max_grad_norm, recurrence, preprocess_obss,
        reshape_reward
            forwarded unchanged to ``BaseAlgo.__init__``
        num_frames_per_proc : int, optional
            forwarded to ``BaseAlgo.__init__``; any falsy value (``None``, 0)
            is replaced by 8
        lr : float
            learning rate, given to both ``BaseAlgo.__init__`` and RMSprop
        rmsprop_alpha : float
            ``alpha`` argument of ``torch.optim.RMSprop``
        rmsprop_eps : float
            ``eps`` argument of ``torch.optim.RMSprop``

        Raises
        ------
        NotImplementedError
            always, after the base class and the optimizer have been set up
        """
        num_frames_per_proc = num_frames_per_proc or 8

        super().__init__(envs, acmodel, device, num_frames_per_proc, discount, lr, gae_lambda, entropy_coef,
                         value_loss_coef, max_grad_norm, recurrence, preprocess_obss, reshape_reward)

        self.optimizer = torch.optim.RMSprop(self.acmodel.parameters(), lr,
                                             alpha=rmsprop_alpha, eps=rmsprop_eps)

        # Deliberate guard: the algorithm is disabled until it is refactored.
        raise NotImplementedError("This needs to be refactored to work with mm actions")

    def update_parameters(self, exps):
        """Run one optimisation step on the actor-critic model.

        The loss is averaged over ``self.recurrence`` consecutive sub-batches
        that start at the indexes given by ``_get_starting_indexes``.

        Parameters
        ----------
        exps
            collected experiences; must support indexing with an integer
            array and expose ``obs``, ``action``, ``advantage`` and
            ``returnn`` on the result (plus ``mask`` and ``memory`` when the
            model is recurrent)

        Returns
        -------
        logs : dict
            ``"entropy"``, ``"value"``, ``"policy_loss"`` and ``"value_loss"``
            (Python floats averaged over the recurrence steps) and
            ``"grad_norm"`` (L2 norm of the gradients before clipping)
        """
        # Compute starting indexes

        inds = self._get_starting_indexes()

        # Initialize update values

        update_entropy = 0
        update_value = 0
        update_policy_loss = 0
        update_value_loss = 0
        update_loss = 0

        # Initialize memory

        if self.acmodel.recurrent:
            memory = exps.memory[inds]

        for i in range(self.recurrence):
            # Create a sub-batch of experience

            sb = exps[inds + i]

            # Compute loss

            if self.acmodel.recurrent:
                # presumably sb.mask zeroes the memory at episode boundaries -- TODO confirm
                dist, value, memory = self.acmodel(sb.obs, memory * sb.mask)
            else:
                dist, value = self.acmodel(sb.obs)

            entropy = dist.entropy().mean()

            policy_loss = -(dist.log_prob(sb.action) * sb.advantage).mean()

            value_loss = (value - sb.returnn).pow(2).mean()

            loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

            # Update batch values

            update_entropy += entropy.item()
            update_value += value.mean().item()
            update_policy_loss += policy_loss.item()
            update_value_loss += value_loss.item()
            # Kept as a tensor (no .item()) so that it can be back-propagated.
            update_loss += loss

        # Update update values

        update_entropy /= self.recurrence
        update_value /= self.recurrence
        update_policy_loss /= self.recurrence
        update_value_loss /= self.recurrence
        update_loss /= self.recurrence

        # Update actor-critic

        self.optimizer.zero_grad()
        update_loss.backward()
        # Parameters that did not take part in the loss have ``grad is None``;
        # skip them, as ``clip_grad_norm_`` does, instead of crashing on
        # ``None.data``. The norm is measured before clipping.
        grads = [p.grad for p in self.acmodel.parameters() if p.grad is not None]
        update_grad_norm = sum(g.data.norm(2) ** 2 for g in grads) ** 0.5
        torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Log some values

        logs = {
            "entropy": update_entropy,
            "value": update_value,
            "policy_loss": update_policy_loss,
            "value_loss": update_value_loss,
            "grad_norm": update_grad_norm
        }

        return logs

    def _get_starting_indexes(self):
        """Gives the indexes of the observations given to the model and the
        experiences used to compute the loss at first.

        The indexes are the integers from 0 (inclusive) to `self.num_frames`
        (exclusive) with a step of `self.recurrence`. When `self.recurrence`
        is 1 (presumably the case for non-recurrent models -- TODO confirm in
        `BaseAlgo`), they are all the integers from 0 to `self.num_frames`.

        Returns
        -------
        starting_indexes : numpy.ndarray of int
            the indexes of the experiences to be used at first
        """
        starting_indexes = numpy.arange(0, self.num_frames, self.recurrence)
        return starting_indexes