GenSim / cliport /models /streams /two_stream_attention.py
LeroyWaa's picture
add gensim code
8fc2b4e
raw
history blame
No virus
1.66 kB
import cliport.models as models
import cliport.models.core.fusion as fusion
from cliport.models.core.attention import Attention
class TwoStreamAttention(Attention):
    """Attention ("pick") module that fuses the outputs of two stream networks.

    ``stream_fcn`` is unpacked into two names, each looked up in
    ``models.names`` to obtain a stream constructor. The fusion operator is
    looked up in ``fusion.names`` under the key stored at
    ``cfg['train']['attn_stream_fusion_type']``.
    """

    def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device):
        """Record the fusion type, then defer to the ``Attention`` constructor.

        All six arguments are forwarded unchanged to ``Attention.__init__``.
        """
        # Set before calling the base constructor: `_build_nets` reads
        # `self.fusion_type`, and the base constructor presumably invokes
        # `_build_nets` -- TODO confirm against `Attention.__init__`.
        self.fusion_type = cfg['train']['attn_stream_fusion_type']
        super().__init__(stream_fcn, in_shape, n_rotations, preprocess, cfg, device)

    def _build_nets(self):
        """Instantiate both stream networks and the fusion operator."""
        stream_one_fcn, stream_two_fcn = self.stream_fcn

        # Resolve both constructors before building anything, so an unknown
        # stream name fails before either network is created.
        first_stream_cls = models.names[stream_one_fcn]
        second_stream_cls = models.names[stream_two_fcn]

        # Both streams receive the same positional arguments; the literal 1
        # is presumably the output dimension -- TODO confirm against the
        # stream constructors.
        stream_args = (self.in_shape, 1, self.cfg, self.device, self.preprocess)
        self.attn_stream_one = first_stream_cls(*stream_args)
        self.attn_stream_two = second_stream_cls(*stream_args)

        fusion_cls = fusion.names[self.fusion_type]
        self.fusion = fusion_cls(input_dim=1)

        print(f"Attn FCN - Stream One: {stream_one_fcn}, Stream Two: {stream_two_fcn}, Stream Fusion: {self.fusion_type}")

    def attend(self, x):
        """Run ``x`` through both streams and return their fused output."""
        out_one = self.attn_stream_one(x)
        out_two = self.attn_stream_two(x)
        return self.fusion(out_one, out_two)
class TwoStreamAttentionLat(TwoStreamAttention):
    """Two-stream attention ("pick") module with lateral connections.

    Differs from ``TwoStreamAttention`` only in ``attend``: stream one
    returns a pair whose second element is passed to stream two as an
    extra argument.
    """

    def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device):
        """Forward every argument unchanged to ``TwoStreamAttention.__init__``."""
        super().__init__(stream_fcn, in_shape, n_rotations, preprocess, cfg, device)

    def attend(self, x):
        """Fuse both streams, feeding stream one's lateral output to stream two."""
        # Stream one yields (output, lateral); the lateral value is handed to
        # stream two together with the same input.
        out_one, lateral = self.attn_stream_one(x)
        out_two = self.attn_stream_two(x, lateral)
        return self.fusion(out_one, out_two)