# Convert a fine-tuned 7-output ResNet-50 checkpoint to Core ML, save a
# full-precision copy and an 8-bit weight-quantized copy, then run one
# prediction with the saved full-precision model.
# NOTE(review): `class_labels` and `img` are not defined in this section; they
# are assumed to be defined earlier in the file — confirm.

# Rebuild the architecture with a 7-way head so the saved state_dict loads.
resnet50 = models.resnet50(pretrained=False)
fc_in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(
    in_features=fc_in_features,
    out_features=7)
resnet50.load_state_dict(torch.load("./Models/resnet50.pt", map_location=torch.device('cpu')))

# Append a sigmoid so the exported model emits per-class scores in [0, 1]
# rather than raw logits.
torch_model = torch.nn.Sequential(
    resnet50,
    nn.Sigmoid())
# Switch to inference mode (batch-norm uses running stats) before tracing.
torch_model.eval()

input_shape = (3, 224, 224)
example_input = torch.rand(1, *input_shape)
traced_model = torch.jit.trace(torch_model, example_input)

# Core ML preprocesses the image as  y = scale * pixel + bias  with one scalar
# `scale` for all channels. The per-channel std (0.229, 0.224, 0.225) is
# therefore approximated by their mean, 0.226, in `scale`, while `bias` uses the
# exact per-channel -mean/std. These constants appear to be the usual ImageNet
# normalization values.
image_input = ct.ImageType(
    color_layout='RGB',
    scale=1.0/255.0/0.226,
    bias=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
    shape=example_input.shape)
cml_model = ct.convert(
    traced_model,
    inputs=[image_input],
    classifier_config=ct.ClassifierConfig(class_labels))
# NOTE(review): saving to a .mlmodel file and quantization_utils both need the
# "neuralnetwork" model type. Newer coremltools releases default to
# "mlprogram" — confirm the installed version, or pass
# convert_to="neuralnetwork" if it is supported.
cml_model.save('Models/Resnet50.mlmodel')

model_8bit = quantization_utils.quantize_weights(
    cml_model,
    nbits=8,
    quantization_mode="linear")
model_8bit.save('Models/Resnet50_8bit.mlmodel')

# Reload the saved model and predict. The ImageType above was given no `name`,
# so the converter chose the input feature name; read it from the spec instead
# of hard-coding it.
# NOTE(review): MLModel.predict only runs on macOS. `img` is presumably a
# 224x224 RGB PIL image — confirm.
model = ct.models.MLModel('Models/Resnet50.mlmodel')
input_name = model.get_spec().description.input[0].name
prediction = model.predict({input_name: img})