# SocialAISchool / torch-ac / torch_ac / intrinsic_reward_models.py
from torch import nn
import torch
from torch.nn import functional as F
def init(module, weight_init, bias_init, gain=1):
    """Initialize a layer's parameters in place and return the layer.

    ``weight_init`` is applied to ``module.weight`` (with ``gain``) and
    ``bias_init`` to ``module.bias``; returning ``module`` allows the usual
    ``init_(nn.Linear(...))`` chaining idiom.
    """
    weight, bias = module.weight, module.bias
    weight_init(weight.data, gain=gain)
    bias_init(bias.data)
    return module
class MinigridInverseDynamicsNet(nn.Module):
    """Inverse-dynamics head: predict the action taken between two
    consecutive state embeddings (ICM-style intrinsic reward component).

    ``forward`` concatenates the two embeddings on dim 2, so inputs are
    expected to carry a leading [T, B, ...] layout with 128-d embeddings
    (the hidden layer takes 2 * 128 inputs) — TODO confirm against callers.
    """

    def __init__(self, num_actions):
        super(MinigridInverseDynamicsNet, self).__init__()
        self.num_actions = num_actions

        # Hidden layer over the concatenated (state, next_state) pair:
        # orthogonal weights with the ReLU gain, zero bias.
        hidden = nn.Linear(2 * 128, 256)
        nn.init.orthogonal_(hidden.weight.data, gain=nn.init.calculate_gain('relu'))
        nn.init.constant_(hidden.bias.data, 0)
        self.inverse_dynamics = nn.Sequential(hidden, nn.ReLU())

        # Output head: orthogonal weights with default gain, zero bias.
        head = nn.Linear(256, self.num_actions)
        nn.init.orthogonal_(head.weight.data, gain=1)
        nn.init.constant_(head.bias.data, 0)
        self.id_out = head

    def forward(self, state_embedding, next_state_embedding):
        """Return unnormalized action logits for each (t, b) pair."""
        paired = torch.cat((state_embedding, next_state_embedding), dim=2)
        return self.id_out(self.inverse_dynamics(paired))
class MinigridForwardDynamicsNet(nn.Module):
    """Forward-dynamics head: predict the next state embedding from the
    current embedding and a one-hot encoded action (curiosity bonus model).

    Action indices are one-hot encoded to ``num_actions`` classes and
    concatenated with the embedding on dim 2, so inputs carry a leading
    [T, B, ...] layout with 128-d embeddings — TODO confirm against callers.
    """

    def __init__(self, num_actions):
        super(MinigridForwardDynamicsNet, self).__init__()
        self.num_actions = num_actions

        # Hidden layer over (embedding, one-hot action): orthogonal weights
        # with the ReLU gain, zero bias.
        hidden = nn.Linear(128 + self.num_actions, 256)
        nn.init.orthogonal_(hidden.weight.data, gain=nn.init.calculate_gain('relu'))
        nn.init.constant_(hidden.bias.data, 0)
        self.forward_dynamics = nn.Sequential(hidden, nn.ReLU())

        # Output head mapping back to the 128-d embedding space.
        head = nn.Linear(256, 128)
        nn.init.orthogonal_(head.weight.data, gain=1)
        nn.init.constant_(head.bias.data, 0)
        self.fd_out = head

    def forward(self, state_embedding, action):
        """Return the predicted next-state embedding."""
        one_hot_action = F.one_hot(action, num_classes=self.num_actions).float()
        combined = torch.cat((state_embedding, one_hot_action), dim=2)
        return self.fd_out(self.forward_dynamics(combined))
class MinigridStateEmbeddingNet(nn.Module):
    """Small convolutional encoder mapping raw observations to embeddings.

    ``forward`` expects [T, B, H, W, C] pixel observations (values scaled by
    255 inside), and returns a [T, B, emb] tensor; the embedding width
    depends on the spatial size of the input.
    """

    def __init__(self, observation_shape):
        super(MinigridStateEmbeddingNet, self).__init__()
        self.observation_shape = observation_shape

        # Three stride-2 conv + ELU stages; every conv gets orthogonal
        # weights (ReLU gain) and a zero bias.
        layers = []
        in_channels = self.observation_shape[2]
        for out_channels in (32, 32, 128):
            conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=(3, 3), stride=2, padding=1)
            nn.init.orthogonal_(conv.weight.data, gain=nn.init.calculate_gain('relu'))
            nn.init.constant_(conv.bias.data, 0)
            layers += [conv, nn.ELU()]
            in_channels = out_channels
        self.feat_extract = nn.Sequential(*layers)

    def forward(self, inputs):
        """Encode [T, B, H, W, C] observations into [T, B, emb] features."""
        T, B = inputs.shape[0], inputs.shape[1]
        # Merge time and batch so the conv stack sees a plain 4-D batch.
        flat = torch.flatten(inputs, 0, 1)
        flat = flat.float() / 255.0
        # Swap dims 1 and 3: [N, H, W, C] -> [N, C, W, H] (channels first;
        # note this also transposes the two spatial axes, as the original did).
        flat = flat.transpose(1, 3)
        features = self.feat_extract(flat)
        return features.view(T, B, -1)
def compute_forward_dynamics_loss(pred_next_emb, next_emb):
    """Forward-dynamics loss: sum over time of the batch-mean L2 distance
    between predicted and actual next-state embeddings ([T, B, emb] inputs).
    """
    per_step_error = (pred_next_emb - next_emb).norm(p=2, dim=2)
    return per_step_error.mean(dim=1).sum()
def compute_inverse_dynamics_loss(pred_actions, true_actions):
    """Inverse-dynamics loss: sum over time of the batch-mean cross-entropy
    between predicted action logits [T, B, A] and true action indices [T, B].
    """
    flat_logits = torch.flatten(pred_actions, 0, 1)
    flat_targets = torch.flatten(true_actions, 0, 1)
    # cross_entropy == nll_loss(log_softmax(...)) on raw logits.
    per_sample = F.cross_entropy(flat_logits, flat_targets, reduction='none')
    per_step = per_sample.view_as(true_actions)
    return per_step.mean(dim=1).sum()
class LSTMMoaNet(nn.Module):
    """Model-of-other-agents (MOA) net: an LSTM that predicts the NPC's next
    action from the observation embedding, the NPC's previous action(s), and
    the agent's own (possibly speaking) action.

    When ``num_npc_utterance_actions`` is given, an extra head predicts NPC
    utterances and ``forward`` returns logits over the joint
    (primitive x utterance) action space, combined via an outer sum.

    Fix vs. original: the output branch in ``forward`` tested
    ``self.num_npc_utterance_actions`` while every other branch (including
    head construction) tested ``self.utterance_moa``; it now uses
    ``self.utterance_moa`` everywhere for consistency.
    """

    def __init__(self, input_size, num_npc_prim_actions, acmodel, num_npc_utterance_actions=None, memory_dim=128):
        super(LSTMMoaNet, self).__init__()
        self.num_npc_prim_actions = num_npc_prim_actions
        self.num_npc_utterance_actions = num_npc_utterance_actions
        # Single source of truth for "do we also model NPC utterances?".
        self.utterance_moa = num_npc_utterance_actions is not None
        self.input_size = input_size
        self.acmodel = acmodel

        self.hidden_size = 128  # 256 in the original paper

        # Input projection: orthogonal weights with the ReLU gain, zero bias.
        hidden = nn.Linear(self.input_size, self.hidden_size)
        nn.init.orthogonal_(hidden.weight.data, gain=nn.init.calculate_gain('relu'))
        nn.init.constant_(hidden.bias.data, 0)
        self.forward_dynamics = nn.Sequential(hidden, nn.ReLU())

        self.memory_dim = memory_dim
        self.memory_rnn = nn.LSTMCell(self.hidden_size, self.memory_dim)
        self.embedding_size = self.semi_memory_size

        # Output heads: orthogonal weights with default gain, zero bias.
        self.fd_out_prim = self._init_head(nn.Linear(self.embedding_size, self.num_npc_prim_actions))
        if self.utterance_moa:
            self.fd_out_utt = self._init_head(nn.Linear(self.embedding_size, self.num_npc_utterance_actions))

    @staticmethod
    def _init_head(layer):
        """Orthogonal weight init (gain 1) and zero bias, in place."""
        nn.init.orthogonal_(layer.weight.data, gain=1)
        nn.init.constant_(layer.bias.data, 0)
        return layer

    @property
    def memory_size(self):
        # LSTM memory is the (h, c) pair concatenated along dim 1.
        return 2 * self.semi_memory_size

    @property
    def semi_memory_size(self):
        return self.memory_dim

    def forward(self, embeddings, npc_previous_prim_actions, agent_actions, memory, npc_previous_utterance_actions=None):
        """One MOA step.

        embeddings: [B, obs_emb] observation features.
        npc_previous_prim_actions: [B] long indices of the NPC's last
            primitive action (plus [B] utterance indices when enabled).
        agent_actions: [B, >=4] raw agent actions; columns 0/2/3 are the
            primitive action, speech template and speech word (column 1 is
            unused here — presumably another action component; confirm
            against ``acmodel.model_raw_action_space``).
        memory: [B, 2 * memory_dim] concatenated LSTM (h, c) state.

        Returns ``(logits, new_memory)`` where logits are over primitive
        actions, or over the joint prim x utterance space when enabled.
        """
        npc_prev_prim_OH = F.one_hot(npc_previous_prim_actions, self.num_npc_prim_actions)
        if self.utterance_moa:
            npc_prev_utt_OH = F.one_hot(
                npc_previous_utterance_actions,
                self.num_npc_utterance_actions,
            )

        # Encode the agent's action: the speech components (template, word)
        # only count when the agent is actually speaking.
        is_agent_speaking = self.acmodel.is_raw_action_speaking(agent_actions)
        prim_action_OH = F.one_hot(agent_actions[:, 0], self.acmodel.model_raw_action_space.nvec[0])
        template_OH = F.one_hot(agent_actions[:, 2], self.acmodel.model_raw_action_space.nvec[2])
        word_OH = F.one_hot(agent_actions[:, 3], self.acmodel.model_raw_action_space.nvec[3])

        # Zero out speech one-hots for non-speaking steps.
        template_OH = template_OH * is_agent_speaking[:, None]
        word_OH = word_OH * is_agent_speaking[:, None]

        # Concatenate observation, NPC history, and agent action features.
        parts = [embeddings, npc_prev_prim_OH]
        if self.utterance_moa:
            parts.append(npc_prev_utt_OH)
        parts += [prim_action_OH, template_OH, word_OH]
        inputs = torch.cat(parts, dim=1).float()

        projected = self.forward_dynamics(inputs)

        # LSTM step: split memory into (h, c), advance, re-concatenate.
        hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
        hidden = self.memory_rnn(projected, hidden)
        embedding = hidden[0]
        memory = torch.cat(hidden, dim=1)

        outs_prim = self.fd_out_prim(embedding)
        if self.utterance_moa:
            outs_utt = self.fd_out_utt(embedding)
            # Joint logits over (prim, utterance) via an outer sum.
            outs = (outs_prim[..., None] + outs_utt[..., None, :]).reshape(
                -1, self.num_npc_prim_actions * self.num_npc_utterance_actions)
        else:
            outs = outs_prim
        return outs, memory