File size: 170 Bytes
3f96a16
881b143
 
3f96a16
 
881b143
 
3f96a16
881b143
1
2
3
4
5
6
7
8
9
10
from torch import nn

# Registry mapping a backend name to the fully-connected (linear) layer class
# for that backend. "torch" is always present; "te" is registered only when
# the optional ``transformer_engine.pytorch`` module can be imported.
FC_CLASS_REGISTRY = {"torch": nn.Linear}
try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    # transformer_engine is optional, and importing it may fail with errors
    # other than ImportError (e.g. native-library load failures), so this
    # stays best-effort. Unlike the previous bare ``except:``, it no longer
    # swallows KeyboardInterrupt / SystemExit raised during the import.
    pass