grg's picture
Cleaned old git history
be5548b
raw
history blame
21.6 kB
import numpy
import torch
import torch.nn.functional as F
from torch_ac.intrinsic_reward_models import compute_forward_dynamics_loss, compute_inverse_dynamics_loss
from sklearn.metrics import f1_score
from torch_ac.algos.base import BaseAlgo
def compute_balance_mask(target, n_classes):
    """Build a 0/1 mask that undersamples `target` to a class-balanced subset.

    Every class that occurs in `target` is randomly subsampled (without
    replacement) down to the number of occurrences of the least frequent
    class that is present. Selected positions are set to 1 in the mask.

    Parameters
    ----------
    target : torch.Tensor
        1-D tensor of non-negative integer class ids (as required by
        `torch.bincount`).
    n_classes : int
        Number of classes; only ids in `range(n_classes)` can be selected.

    Returns
    -------
    torch.Tensor
        Detached tensor with the same shape, dtype and device as `target`.
        It is all zeros when fewer than two distinct classes are present
        (including an empty or single-element `target`), because such a
        batch cannot be balanced and should not be trained on.
    """
    balance_mask = torch.zeros_like(target)

    # Occurrences of each class; classes that never appear get a count of 0.
    per_class_n = torch.bincount(target, minlength=n_classes)
    present_counts = per_class_n[per_class_n > 0]

    # Counting the classes that are present (rather than testing the variance
    # of `target`) also covers the empty and single-element cases, for which
    # the unbiased variance is NaN and never compares equal to zero.
    if present_counts.numel() <= 1:
        # zero or one class present, don't train at all
        return balance_mask.detach()

    # number of times the least common class (that appeared) appeared;
    # every present class is undersampled to this many samples
    n_for_each_class = int(present_counts.min())

    balanced_indexes_ = []
    for c in range(n_classes):
        c_indexes = torch.where(target == c)[0]
        if len(c_indexes) == 0:
            continue

        # random subset without replacement
        c_sampled_indexes = c_indexes[torch.randperm(len(c_indexes))[:n_for_each_class]]
        balanced_indexes_.append(c_sampled_indexes)

    # Only ids >= n_classes were present: nothing can be selected.
    if not balanced_indexes_:
        return balance_mask.detach()

    balanced_indexes = torch.cat(balanced_indexes_)
    balance_mask[balanced_indexes] = 1

    return balance_mask.detach()
class PPOAlgo(BaseAlgo):
    """The Proximal Policy Optimization algorithm
    ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)).

    On top of the clipped PPO update, `update_parameters` trains the
    intrinsic-reward models selected by `exploration_bonus_type`
    ("rnd", "ride" or "soc_inf") when `exploration_bonus` is enabled.
    """

    def __init__(self, envs, acmodel, device=None, num_frames_per_proc=None, discount=0.99, lr=0.001, gae_lambda=0.95,
                 entropy_coef=0.01, value_loss_coef=0.5, max_grad_norm=0.5, recurrence=4,
                 adam_eps=1e-5, clip_eps=0.2, epochs=4, batch_size=256, preprocess_obss=None,
                 reshape_reward=None, exploration_bonus=False, exploration_bonus_params=None,
                 expert_exploration_bonus=False, episodic_exploration_bonus=True, exploration_bonus_type="lang",
                 exploration_bonus_tanh=None, clipped_rewards=False, intrinsic_reward_epochs=0,
                 # default is set to fit RND
                 intrinsic_reward_coef=0.1,
                 intrinsic_reward_learning_rate=0.0001,
                 intrinsic_reward_momentum=0,
                 intrinsic_reward_epsilon=0.01,
                 intrinsic_reward_alpha=0.99,
                 intrinsic_reward_max_grad_norm=40,
                 intrinsic_reward_loss_coef=0.1,
                 intrinsic_reward_forward_loss_coef=10,
                 intrinsic_reward_inverse_loss_coef=0.1,
                 reset_rnd_ride_at_phase=False,
                 balance_moa_training=False,
                 moa_memory_dim=128,
                 schedule_lr=False,
                 lr_schedule_end_frames=0,
                 end_lr=0.0,
                 ):
        """Set up the PPO optimizer and learning-rate schedule.

        Most arguments are forwarded unchanged to `BaseAlgo.__init__`. The
        ones consumed here are:

        Parameters
        ----------
        num_frames_per_proc : int or None
            Falls back to 128 when falsy.
        lr, adam_eps : float
            Learning rate and epsilon of the Adam optimizer.
        clip_eps : float
            Clipping range used for both the policy ratio and the value.
        epochs : int
            Number of passes over the collected experiences per update.
        batch_size : int
            Number of frames per batch; must be a multiple of `recurrence`.
        intrinsic_reward_epochs : int
            Stored on the instance; must be 0 when training with "soc_inf".
        schedule_lr : bool
            Stored on the instance. NOTE(review): the schedule below is
            driven only by `lr_schedule_end_frames`, not by this flag.
        lr_schedule_end_frames : int
            Number of frames over which the learning rate is annealed
            linearly from `lr` to `end_lr`; 0 disables the schedule.
        end_lr : float
            Final learning rate; must not exceed `lr`.
        """
        num_frames_per_proc = num_frames_per_proc or 128

        # save config
        # (snapshot of all constructor arguments; `get_config_dict` strips
        # the entries that are not plain configuration values)
        self.config = locals()

        super().__init__(
            envs=envs,
            acmodel=acmodel,
            device=device,
            num_frames_per_proc=num_frames_per_proc,
            discount=discount,
            lr=lr,
            gae_lambda=gae_lambda,
            entropy_coef=entropy_coef,
            value_loss_coef=value_loss_coef,
            max_grad_norm=max_grad_norm,
            recurrence=recurrence,
            preprocess_obss=preprocess_obss,
            reshape_reward=reshape_reward,
            exploration_bonus=exploration_bonus,
            expert_exploration_bonus=expert_exploration_bonus,
            episodic_exploration_bonus=episodic_exploration_bonus,
            exploration_bonus_params=exploration_bonus_params,
            exploration_bonus_tanh=exploration_bonus_tanh,
            exploration_bonus_type=exploration_bonus_type,
            clipped_rewards=clipped_rewards,
            intrinsic_reward_loss_coef=intrinsic_reward_loss_coef,
            intrinsic_reward_coef=intrinsic_reward_coef,
            intrinsic_reward_learning_rate=intrinsic_reward_learning_rate,
            intrinsic_reward_momentum=intrinsic_reward_momentum,
            intrinsic_reward_epsilon=intrinsic_reward_epsilon,
            intrinsic_reward_alpha=intrinsic_reward_alpha,
            intrinsic_reward_max_grad_norm=intrinsic_reward_max_grad_norm,
            intrinsic_reward_forward_loss_coef=intrinsic_reward_forward_loss_coef,
            intrinsic_reward_inverse_loss_coef=intrinsic_reward_inverse_loss_coef,
            balance_moa_training=balance_moa_training,
            moa_memory_dim=moa_memory_dim,
            reset_rnd_ride_at_phase=reset_rnd_ride_at_phase,
        )

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.intrinsic_reward_epochs = intrinsic_reward_epochs
        self.batch_size = batch_size

        assert self.batch_size % self.recurrence == 0

        if self.exploration_bonus and "soc_inf" in self.exploration_bonus_type:
            # The MOA net is trained by the policy optimizer. dict.fromkeys
            # removes parameters shared by both modules while keeping order.
            adam_params = list(dict.fromkeys(list(self.acmodel.parameters()) + list(self.moa_net.parameters())))
            self.optimizer = torch.optim.Adam(adam_params, lr, eps=adam_eps)
        else:
            self.optimizer = torch.optim.Adam(self.acmodel.parameters(), lr, eps=adam_eps)

        self.schedule_lr = schedule_lr
        self.lr_schedule_end_frames = lr_schedule_end_frames

        assert end_lr <= lr

        def lr_lambda(step):
            """Return the multiplicative LR factor for optimizer step `step`."""
            if self.lr_schedule_end_frames == 0:
                # no schedule
                return 1

            # Linear interpolation of the factor from 1 down to end_lr / lr,
            # reached after `lr_schedule_end_frames` frames and held afterwards.
            end_factor = end_lr/lr
            final_diminished_factor = 1-end_factor
            n_frames = self.step_to_n_frames(step)
            return 1 - (min(n_frames, self.lr_schedule_end_frames) / self.lr_schedule_end_frames) * final_diminished_factor

        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        # Counts calls to `_get_batches_starting_indexes` (used to alternate
        # the shifted / unshifted batch layout).
        self.batch_num = 0

    def load_status_dict(self, status):
        """Restore the state saved by `get_status_dict`.

        The optimizer and LR-scheduler states are only loaded when their
        keys are present in `status`.
        """
        super().load_status_dict(status)

        if "optimizer_state" in status:
            self.optimizer.load_state_dict(status["optimizer_state"])

        if "lr_scheduler_state" in status:
            self.lr_scheduler.load_state_dict(status["lr_scheduler_state"])

    def get_status_dict(self):
        """Return the parent status dict extended with the optimizer and
        LR-scheduler state dicts."""
        status_dict = super().get_status_dict()

        status_dict["optimizer_state"] = self.optimizer.state_dict()
        status_dict["lr_scheduler_state"] = self.lr_scheduler.state_dict()

        return status_dict

    def update_parameters(self, exps):
        """Run one PPO update on the collected experiences.

        When an exploration bonus is enabled, the RND or RIDE models are
        first trained once on the whole rollout with their own optimizers;
        the "soc_inf" MOA net is instead trained jointly with the policy
        inside the PPO loop.

        Parameters
        ----------
        exps
            Collected experiences, indexable by arrays of frame indexes and
            exposing `obs`, `mask`, `action`, `log_prob`, `advantage`,
            `value`, `returnn` and, depending on the configuration,
            `memory`, `moa_memory` and `infos`.

        Returns
        -------
        dict
            Means over the batches of the last epoch of "entropy", "value",
            "policy_loss", "value_loss", "grad_norm" and "lr", plus
            "intr_reward_perf" (detached RND/RIDE loss, or the MOA accuracy
            of the last batch) and "intr_reward_perf_" (MOA macro F1 of the
            last batch, 0.0 otherwise).
        """
        # Collect experiences

        self.acmodel.train()

        self.update_epoch += 1

        intr_rew_perf = torch.tensor(0.0)
        intr_rew_perf_ = 0.0

        social_influence = False

        if self.exploration_bonus:
            if "rnd" in self.exploration_bonus_type:

                # (num_procs * frames, ...) -> (frames, num_procs, ...)
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)
                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc, 1,
                ).transpose(0, 1)

                self.random_target_network.train()
                self.predictor_network.train()

                # both networks are reshaped to a 128-dimensional embedding
                random_embedding = self.random_target_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)
                predicted_embedding = self.predictor_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)
                intr_rew_loss = self.intrinsic_reward_loss_coef * compute_forward_dynamics_loss(
                    mask*predicted_embedding, mask*random_embedding.detach()
                )

                # update the intr rew models
                self.intrinsic_reward_optimizer.zero_grad()
                intr_rew_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.predictor_network.parameters(), self.intrinsic_reward_max_grad_norm)
                self.intrinsic_reward_optimizer.step()

                # detached: the value is only logged, no need to keep the graph
                intr_rew_perf = intr_rew_loss.detach()

            elif "ride" in self.exploration_bonus_type:
                # (num_procs * frames, ...) -> (frames, num_procs, ...)
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)
                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                # we only take the first (primitive) action
                action = exps.action[:, 0].reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                # (state, action, next state) transitions: drop the last frame
                # on the "current" side and the first one on the "next" side
                _mask = mask[:-1]
                _obs = imgs[:-1]
                _actions = action[:-1]
                _next_obs = imgs[1:]

                self.state_embedding_model.train()
                self.forward_dynamics_model.train()
                self.inverse_dynamics_model.train()

                state_emb = self.state_embedding_model(_obs.to(device=self.device))
                next_state_emb = self.state_embedding_model(_next_obs.to(device=self.device))

                pred_next_state_emb = self.forward_dynamics_model(state_emb, _actions.to(device=self.device))
                pred_actions = self.inverse_dynamics_model(state_emb, next_state_emb)

                forward_dynamics_loss = self.intrinsic_reward_forward_loss_coef * \
                    compute_forward_dynamics_loss(_mask[:, :, None]*pred_next_state_emb, _mask[:, :, None]*next_state_emb)

                inverse_dynamics_loss = self.intrinsic_reward_inverse_loss_coef * \
                    compute_inverse_dynamics_loss(_mask[:, :, None]*pred_actions, _mask*_actions)

                # update the intr rew models
                self.state_embedding_optimizer.zero_grad()
                self.forward_dynamics_optimizer.zero_grad()
                self.inverse_dynamics_optimizer.zero_grad()

                intr_rew_loss = forward_dynamics_loss + inverse_dynamics_loss
                intr_rew_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.state_embedding_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.forward_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.inverse_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)

                self.state_embedding_optimizer.step()
                self.forward_dynamics_optimizer.step()
                self.inverse_dynamics_optimizer.step()

                # detached: the value is only logged, no need to keep the graph
                intr_rew_perf = intr_rew_loss.detach()

            elif "soc_inf" in self.exploration_bonus_type:
                # trained together with the policy
                social_influence = True
                self.moa_net.train()

                if self.intrinsic_reward_epochs > 0:
                    raise DeprecationWarning(f"Moa must be trained with the agent. intrinsic_reward_epochs must be 0 but is {self.intrinsic_reward_epochs}")

        for _ in range(self.epochs):
            # Initialize log values
            # (reset every epoch, so the returned logs describe the last epoch)

            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []
            log_lrs = []

            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # intr reward metrics
                batch_intr_rew_loss = 0
                batch_intr_rew_acc = 0
                batch_intr_rew_f1 = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                if social_influence:
                    # Initialize moa memory
                    moa_memory = exps.moa_memory[inds]
                    prev_npc_prim_action = None

                for i in range(self.recurrence):
                    # Create a sub-batch of experience

                    sb = exps[inds + i]

                    # Compute loss

                    if self.acmodel.recurrent:
                        dist, value, memory, policy_embeddings = self.acmodel(sb.obs, memory * sb.mask, return_embeddings=True)
                    else:
                        dist, value, policy_embeddings = self.acmodel(sb.obs, return_embeddings=True)

                    # Per-head masks selecting which action heads receive
                    # gradient; they depend only on the sub-batch actions, so
                    # they are computed once for all heads.
                    action_masks = self.acmodel.calculate_action_gradient_masks(sb.action).type(sb.log_prob.type())

                    losses = []

                    for head_i, d in enumerate(dist):
                        head_mask = action_masks[:, head_i]

                        entropy = (d.entropy() * head_mask).mean()

                        # clipped surrogate policy objective
                        ratio = torch.exp(d.log_prob(sb.action[:, head_i]) - sb.log_prob[:, head_i])
                        surr1 = ratio * sb.advantage
                        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * sb.advantage
                        policy_loss = (
                            -torch.min(surr1, surr2) * head_mask
                        ).mean()

                        # clipped value loss
                        value_clipped = sb.value + torch.clamp(value - sb.value, -self.clip_eps, self.clip_eps)
                        surr1 = (value - sb.returnn).pow(2)
                        surr2 = (value_clipped - sb.returnn).pow(2)
                        value_loss = (
                            torch.max(surr1, surr2) * head_mask
                        ).mean()

                        head_loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss
                        losses.append(head_loss)

                    if social_influence:
                        # moa loss
                        imgs = sb.obs.image
                        mask = sb.mask.to(torch.int64)

                        # the agent action of every head is given to the MOA net
                        agent_action = sb.action.to(torch.int64)

                        infos = numpy.array(sb.infos)
                        npc_prim_action = torch.tensor(
                            numpy.array([self.fn_name_to_npc_prim_act[info["NPC_prim_action"]] for info in infos]))
                        npc_utt_action = torch.tensor(
                            numpy.array([self.npc_utterance_to_id[info["NPC_utterance"]] for info in infos]))

                        assert infos.shape == imgs.shape[:1] == agent_action.shape[:1]  # [bs]

                        if i == 0:
                            # The first step only provides the "previous" NPC
                            # action; there is nothing to predict from yet.
                            # NOTE(review): these are never advanced on later
                            # steps, so for i > 1 the MOA net still receives the
                            # step-0 NPC action as "previous" - confirm intended.
                            prev_npc_prim_action = npc_prim_action
                            prev_npc_utt_action = npc_utt_action

                        else:
                            # compute loss and train moa net
                            if self.utterance_moa_net:
                                # transform to long logits
                                # (joint class id over primitive x utterance actions)
                                target = npc_prim_action.detach().to(self.device) * self.num_npc_utterance_actions + npc_utt_action.detach().to(self.device)
                            else:
                                target = npc_prim_action.detach().to(self.device)

                            if self.balance_moa_training:
                                balance_mask = compute_balance_mask(target, n_classes=self.num_npc_all_actions)
                            else:
                                balance_mask = torch.ones_like(target)

                            moa_predictions_logs, moa_memory = self.moa_net(
                                embeddings=policy_embeddings,
                                npc_previous_prim_actions=prev_npc_prim_action.detach().to(self.device),
                                npc_previous_utterance_actions=prev_npc_utt_action.detach().to(self.device) if self.utterance_moa_net else None,
                                agent_actions=agent_action.detach().to(self.device),
                                memory=moa_memory * mask,
                            )

                            # moa_predictions_logs = moa_predictions_logs.reshape([*prev_shape, -1])  # is this needed

                            # loss
                            # `mask` is multiplied into the 2-D memories above,
                            # i.e. it carries a trailing singleton dimension.
                            # Flatten it so that it masks the per-sample loss
                            # element-wise instead of broadcasting the product
                            # to a (batch, batch) matrix.
                            sample_mask = mask.reshape(-1)
                            per_sample_ce = F.cross_entropy(
                                input=moa_predictions_logs,
                                target=target,
                                reduction='none',
                            )
                            intr_rew_loss = (
                                balance_mask * sample_mask * per_sample_ce
                            ).mean() * self.intrinsic_reward_loss_coef

                            preds = moa_predictions_logs.detach().argmax(dim=-1)
                            intr_rew_f1 = f1_score(
                                y_pred=preds.detach().cpu().numpy(),
                                y_true=target.detach().cpu().numpy(),
                                average="macro"
                            )

                            intr_rew_acc = (
                                torch.argmax(moa_predictions_logs.to(self.device), dim=-1) == target
                            ).to(float).mean()

                            # metrics only: keep plain numbers, not graph-holding tensors
                            batch_intr_rew_loss += intr_rew_loss.item()
                            batch_intr_rew_acc += intr_rew_acc.detach().cpu().numpy().mean()
                            batch_intr_rew_f1 += intr_rew_f1

                            losses.append(intr_rew_loss)  # trained with the policy optimizer

                    loss = torch.stack(losses).mean()

                    # Update batch values
                    # (entropy / policy_loss / value_loss are those of the last head)
                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Update memories for next epoch

                    # assert self.acmodel.recurrent == (self.recurrence > 1)
                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                    if social_influence and i < self.recurrence - 1:
                        exps.moa_memory[inds + i + 1] = moa_memory.detach()

                # Update batch values

                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic

                self.optimizer.zero_grad()
                batch_loss.backward()
                # Parameters that did not take part in the loss have no
                # gradient; skip them instead of failing on `None.data`.
                grad_norm = sum(
                    p.grad.data.norm(2).item() ** 2
                    for p in self.acmodel.parameters()
                    if p.grad is not None
                ) ** 0.5
                # NOTE(review): with "soc_inf" the optimizer also holds the
                # moa_net parameters, which are not clipped here.
                torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.lr_scheduler.step()

                if social_influence:
                    # recurrence-1 because we skipped the first step
                    # (clamped to 1 so that recurrence == 1 does not divide by zero)
                    n_moa_steps = max(self.recurrence - 1, 1)
                    batch_intr_rew_loss /= n_moa_steps
                    batch_intr_rew_acc /= n_moa_steps
                    batch_intr_rew_f1 /= n_moa_steps

                    intr_rew_perf_ = batch_intr_rew_f1
                    intr_rew_perf = batch_intr_rew_acc

                # Update log values

                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm)
                log_lrs.append(self.optimizer.param_groups[0]['lr'])

        # Log some values

        logs = {
            "entropy": numpy.mean(log_entropies),
            "value": numpy.mean(log_values),
            "policy_loss": numpy.mean(log_policy_losses),
            "value_loss": numpy.mean(log_value_losses),
            "grad_norm": numpy.mean(log_grad_norms),
            "intr_reward_perf": intr_rew_perf,
            "intr_reward_perf_": intr_rew_perf_,
            "lr": numpy.mean(log_lrs),
        }

        return logs

    def _get_batches_starting_indexes(self):
        """Gives, for each batch, the indexes of the observations given to
        the model and the experiences used to compute the loss at first.

        First, the indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`, shifted by `self.recurrence//2` one time in two for having
        more diverse batches. Then, the indexes are split into the different batches.

        Returns
        -------
        batches_starting_indexes : list of list of int
            the indexes of the experiences to be used at first for each batch
        """

        indexes = numpy.arange(0, self.num_frames, self.recurrence)
        indexes = numpy.random.permutation(indexes)

        # Shift starting indexes by self.recurrence//2 half the time
        if self.batch_num % 2 == 1:
            # drop the last start of every process so that the shifted
            # sequence does not run past that process's frames
            indexes = indexes[(indexes + self.recurrence) % self.num_frames_per_proc != 0]
            indexes += self.recurrence // 2
        self.batch_num += 1

        num_indexes = self.batch_size // self.recurrence
        batches_starting_indexes = [indexes[i:i+num_indexes] for i in range(0, len(indexes), num_indexes)]

        return batches_starting_indexes

    def get_config_dict(self):
        """Return the constructor configuration without the entries that are
        objects rather than settings.

        The entries are removed from `self.config` in place and that same
        dict is returned. Missing keys are ignored, so the method can be
        called more than once.
        """
        for key in ('envs', 'acmodel', '__class__', 'self', 'preprocess_obss', 'device'):
            self.config.pop(key, None)

        return self.config