[PyTorch] Test Sample Accuracy 계산하기

http://andersonjo.github.io/pytorch/2017/04/01/PyTorch-Getting-Started/

댓글

  1. import torch
    import torch.nn as nn
    #from __future__ import print_function
    import argparse
    from PIL import Image
    import torchvision.models as models
    import skimage.io
    from torch.autograd import Variable as V
    from torch.nn import functional as f
    from torchvision import transforms as trn


    # Single-image inference: preprocess one sample, run the saved model on it,
    # and convert the raw scores into class probabilities.

    # Preprocessing pipeline for the array returned by skimage: convert to PIL,
    # scale to 256, centre-crop to 224, convert to a tensor, then normalise
    # each of the three channels.
    # NOTE(review): assumes the PNG decodes to an H x W x 3 array; a grayscale
    # image would not match the 3-channel Normalize -- confirm.
    # NOTE(review): the mean/std values look like the usual ImageNet statistics;
    # confirm they match what the model was trained with.
    # NOTE(review): trn.Scale is deprecated in newer torchvision in favour of
    # trn.Resize; left as is to stay compatible with the version in use.
    centre_crop = trn.Compose([
        trn.ToPILImage(),
        trn.Scale(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    filename=r'/home/scrapmetal/Documents/mw_data/val/malware/VirusShare_0c8ab85240a6bfcd12bdc4fae2437ed91.png'
    img = skimage.io.imread(filename)

    # unsqueeze(0) adds the leading batch dimension the model expects.
    # NOTE(review): volatile=True only disables autograd bookkeeping on old
    # PyTorch releases; newer ones ignore it and expect torch.no_grad().
    x = V(centre_crop(img).unsqueeze(0), volatile=True)

    # The file is loaded as a complete model object, so no architecture is
    # built beforehand (the previous freshly constructed ResNet-18 was
    # immediately overwritten by this assignment and never used).
    # NOTE(review): this assumes the file holds the whole pickled model, not a
    # checkpoint dict / state_dict -- otherwise the call below will fail.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load('mw_model0831.pth')

    # Switch to evaluation mode before predicting. Without this, BatchNorm
    # layers normalise with the statistics of this single-image batch (and
    # update their running statistics) and Dropout stays active, which skews
    # the predicted scores.
    model.eval()

    logit = model(x)

    print(logit)
    print(len(logit))  # number of rows in the output, i.e. the batch size

    # Convert scores to probabilities and drop the batch dimension.
    # NOTE(review): softmax is called without an explicit dim; for 2-D logits
    # the implicit choice is the class dimension, but newer PyTorch warns and
    # expects dim=1.
    h_x = f.softmax(logit).data.squeeze()

    답글삭제

댓글 쓰기

이 블로그의 인기 게시물

파이썬으로 Homomorphic Filtering 하기

파이썬으로 2D FFT/iFFT 하기: numpy 버전