
[AI] Comparing the performance of torchvision AlexNet, VGG16, ResNet18, and DenseNet121

by @__100.s 2021. 12. 16.
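  • Using PyTorch's torchvision, compare AlexNet, VGG16, ResNet18, and DenseNet121, all pretrained on ImageNet, by evaluating them on the ImageNet validation dataset (2012).
  • Dataset: ImageNet Validation dataset (2012)
  • Upload the downloaded ImageNet validation dataset (2012) to Google Drive so Colab can read it (the code below uses /content/gdrive/MyDrive/data as the root folder).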
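Before building the dataset, it can help to confirm that the files are where torchvision looks for them. As far as I know, `torchvision.datasets.ImageNet(split='val')` expects `ILSVRC2012_img_val.tar` and `ILSVRC2012_devkit_t12.tar.gz` in `root` and extracts them on first use (into `meta.bin` and a `val/` folder). The helper name `missing_imagenet_val_files` is hypothetical; this is a sketch, not part of torchvision.

```python
import os

# Archive names from the ILSVRC2012 release that torchvision's ImageNet class looks for.
VAL_ARCHIVE = "ILSVRC2012_img_val.tar"
DEVKIT_ARCHIVE = "ILSVRC2012_devkit_t12.tar.gz"


def missing_imagenet_val_files(root):
    """Return the required archive names that are absent from `root`.

    Hypothetical helper: an archive counts as present if it exists, or if
    torchvision has already extracted it (meta.bin for the devkit, a val/
    folder for the validation images).
    """
    missing = []
    has_devkit = (os.path.isfile(os.path.join(root, DEVKIT_ARCHIVE))
                  or os.path.isfile(os.path.join(root, "meta.bin")))
    has_val = (os.path.isfile(os.path.join(root, VAL_ARCHIVE))
               or os.path.isdir(os.path.join(root, "val")))
    if not has_devkit:
        missing.append(DEVKIT_ARCHIVE)
    if not has_val:
        missing.append(VAL_ARCHIVE)
    return missing


if __name__ == "__main__":
    root = "/content/gdrive/MyDrive/data"  # same root folder as in this post
    missing = missing_imagenet_val_files(root)
    print("missing:", missing if missing else "none")
```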

from google.colab import drive

drive.mount('/content/gdrive')
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import alexnet, vgg16, resnet18, densenet121
model_names = ["alexnet", "vgg16", "resnet18", "densenet121"]

models = {
    "alexnet":alexnet,
    "vgg16":vgg16,
    "resnet18": resnet18,
    "densenet121": densenet121,
}

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         normalize,
         ])
test_set = torchvision.datasets.ImageNet(root='/content/gdrive/MyDrive/data', split='val', transform=transform)
# Order does not affect accuracy, so shuffling is unnecessary for evaluation
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=4)
for name in model_names:
    # Load ImageNet-pretrained weights
    # (torchvision >= 0.13 deprecates pretrained=True in favor of weights="DEFAULT")
    model = models[name](pretrained=True).to(device)
    model.eval()  # inference mode: dropout off, BatchNorm uses running statistics

    top1_correct = 0
    top5_correct = 0
    total = 0

    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data = data.to(device)
            target = target.to(device)
            outputs = model(data)
            total += target.size(0)

            # top-1: the highest-scoring class equals the label
            top1 = outputs.argmax(dim=1)
            top1_correct += (top1 == target).sum().item()

            # top-5: the label is among the 5 highest-scoring classes
            _, top5 = outputs.topk(5, dim=1, largest=True, sorted=True)
            top5_correct += (top5 == target.view(-1, 1)).any(dim=1).sum().item()

            if (i + 1) % 100 == 0 or (i + 1) == len(test_loader):
                print("[{}] (step) {} / {}".format(name, i + 1, len(test_loader)))

    print("[{}] (top1_accuracy) {:0.2f}%".format(name, 100 * top1_correct / total))
    print("[{}] (top5_accuracy) {:0.2f}%".format(name, 100 * top5_correct / total))
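The preprocessing above (Resize(256), CenterCrop(224), ToTensor, Normalize) can be traced by hand. The sketch below re-implements only the arithmetic in plain Python, assuming torchvision's behavior as I understand it: Resize with an int scales the shorter side to 256 and truncates the longer side to an int, CenterCrop rounds the offsets, and Normalize applies (x - mean) / std per channel to values that ToTensor has already scaled to [0, 1]. The function names are hypothetical; this is not torchvision's code.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means used in the post
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations used in the post


def resize_shorter_side(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Apply (x - mean) / std per channel to an RGB pixel scaled to [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


if __name__ == "__main__":
    w, h = resize_shorter_side(576, 384)
    print(w, h, center_crop_box(w, h))  # 384 256 (80, 16, 304, 240)
```

A pixel equal to the channel means maps to zero in every channel, which is why the normalized inputs are roughly centered around 0.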

Results

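  • Top-1 accuracy: the percentage of images whose highest-scoring predicted class is the correct class
  • Top-5 accuracy: the percentage of images whose correct class is among the five highest-scoring predictions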
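The two definitions can be made concrete with a small pure-Python function on made-up scores (not real model outputs). It counts the same thing as the `topk`-based code in the evaluation loop, just without PyTorch; the function name `topk_accuracy` is hypothetical.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose label is among the k highest-scoring classes.

    scores: one list of per-class scores per sample
    labels: integer class index per sample
    """
    correct = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by score, highest first; keep the first k
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if label in topk:
            correct += 1
    return 100.0 * correct / len(labels)


# Toy example: 3 samples, 6 classes (made-up scores)
scores = [
    [0.10, 0.50, 0.20, 0.05, 0.12, 0.03],  # label 1 is ranked 1st
    [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],  # label 2 is ranked 3rd
    [0.01, 0.04, 0.10, 0.15, 0.20, 0.50],  # label 0 is ranked 6th
]
labels = [1, 2, 0]

if __name__ == "__main__":
    print("top-1: {:.2f}%".format(topk_accuracy(scores, labels, 1)))  # top-1: 33.33%
    print("top-5: {:.2f}%".format(topk_accuracy(scores, labels, 5)))  # top-5: 66.67%
```

Top-5 accuracy can never be lower than top-1 accuracy, since the top-1 class is always among the top 5.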
| Model | Top-1 accuracy | Top-5 accuracy |
|---|---|---|
| AlexNet | 56% | 79% |
| VGG16 | 72% | 91% |
| ResNet18 | 69% | 88% |
| DenseNet121 | 75% | 91% |
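To read the table in terms of images: the ILSVRC2012 validation set has 50,000 images, so each percentage point is about 500 images, and the gap between top-5 and top-1 is the share of images whose correct class was ranked 2nd to 5th. The sketch below uses the rounded numbers from the table, so the image counts are approximate; `summarize` is a hypothetical helper.

```python
# Rounded accuracies (%) from the table in this post: (top-1, top-5)
results = {
    "AlexNet": (56, 79),
    "VGG16": (72, 91),
    "ResNet18": (69, 88),
    "DenseNet121": (75, 91),
}
NUM_VAL_IMAGES = 50_000  # size of the ILSVRC2012 validation set


def summarize(results, n=NUM_VAL_IMAGES):
    """Rank models by top-1 and add the top-5 gap and approximate top-1 image count."""
    rows = []
    for name, (top1, top5) in sorted(results.items(), key=lambda kv: kv[1][0], reverse=True):
        rows.append((name, top1, top5, top5 - top1, n * top1 // 100))
    return rows


if __name__ == "__main__":
    for name, top1, top5, gap, n_correct in summarize(results):
        print(f"{name:<12} top-1 {top1}%  top-5 {top5}%  gap {gap}pt  ~{n_correct} images (top-1)")
```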

 
