-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Description
I try define a model based on pretrained EfficientNet as below. But I get a NotImplementedError: when use 'forward' function. However, when I use other pretrained CNN e.g., resnet18 from torchvision, there is no such problem. Can anyone help me? Thanks a lot
'Model definition'
class EfficientNet_scene(nn.Module):
def __init__(self,model_name='efficientnet-b0',class_num=45,initfc_type='normal',gain=0.2):
super(EfficientNet_scene, self).__init__()
model = EfficientNet.from_pretrained(model_name)
aul = [*model.children()]
self.features = nn.Sequential(*aul[:-1])
self.fc = nn.Linear(aul[-1].in_features,class_num)
if hasattr(self.fc, 'bias') and self.fc.bias is not None:
nn.init.constant_(self.fc.bias.data, 0.0)
if initfc_type == 'normal':
nn.init.normal_(self.fc.weight.data, 0.0, gain)
elif initfc_type == 'xavier':
nn.init.xavier_normal_(self.fc.weight.data, gain=gain)
elif initfc_type == 'kaiming':
nn.init.kaiming_normal_(self.fc.weight.data, a=0, mode='fan_in')
elif initfc_type == 'orthogonal':
nn.init.orthogonal_(self.fc.weight.data, gain=gain)
def forward(self,x):
x = self.features(x)
x = self.fc(x)
return x
net = EfficientNet_scene()
image = torch.randn(1,3,224,224)
b = net(image)
Error information:
NotImplementedError Traceback (most recent call last)
in ()
37 print(net)
38 image = torch.randn(1,3,224,224)
---> 39 b = net(image)
40 print(b)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
in forward(self, x)
30
31 def forward(self,x):
---> 32 x = self.features(x)
33 x = x.reshape(x.size(0), -1)
34 x = self.fc(x)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
---> 92 input = module(input)
93 return input
94
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in forward(self, *input)
86 registered hooks while the latter silently ignores them.
87 """
---> 88 raise NotImplementedError
89
90 def register_buffer(self, name, tensor):
NotImplementedError: