# makeROCData.py

#!/usr/bin/python
import os
import re
from scipy import ndimage, misc

import torch
import torch.nn as nn
#from __future__ import print_function
import argparse
from PIL import Image
import torchvision.models as models
import skimage.io
from torch.autograd import Variable as V
from torch.nn import functional as f
from torchvision import transforms as trn

# Image preprocessing applied to every image before it is fed to the model:
# ndarray (from skimage.io.imread) -> PIL image -> resize so the smaller edge
# is 256 px -> 224x224 centre crop -> tensor of shape (C, H, W) (values in
# [0, 1] for 8-bit images) -> per-channel normalisation with the ImageNet
# mean/std commonly used with torchvision's pretrained models.
centre_crop = trn.Compose([
    trn.ToPILImage(),
    # Resize replaces the deprecated (and, in newer torchvision, removed)
    # Scale transform; with an int argument it likewise scales the smaller edge.
    trn.Resize(256),
    trn.CenterCrop(224),
    trn.ToTensor(),
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Score every test image with the fine-tuned classifier and append one
# "<label>,<score>" row per image to rocdata.csv.

MODEL_PATH = 'modelFT_BM60.pth'
ROC_CSV = 'rocdata.csv'

# Image files are selected by extension. A raw string is used so that "\." is a
# regex escape rather than an (invalid) Python string escape.
IMAGE_RE = re.compile(r"\.(jpg|jpeg|png)$")

# Index of the class whose softmax probability is written as the ROC score.
# An ROC curve needs the SAME positive-class score on every row. The original
# wrote P(class 0) for benign rows but P(class 1) for malware rows, which
# makes the curve meaningless.
# NOTE(review): assumes class 0 is "benign". torchvision's ImageFolder sorts
# class folders alphabetically, so "benign" < "malware". Confirm against the
# training setup.
POSITIVE_CLASS = 0

# (directory, label) pairs: benign images are labelled 1 (positive) and
# malware images 0 (negative), as in the original script.
TEST_SETS = [
    ("/home/kerb/Documents/data_bm_0913/test/benign/", "1"),
    ("/home/kerb/Documents/data_bm_0913/test/malware", "0"),
]


def _score_image(model, filepath):
    """Return the softmax class probabilities for one image file as a 1-D tensor."""
    img = skimage.io.imread(filepath)
    # NOTE(review): a grayscale file yields a 2-D array (1 channel) and an RGBA
    # PNG yields 4 channels. The 3-channel Normalize / ResNet input may reject
    # these. Confirm that the dataset images are 3-channel RGB.
    x = centre_crop(img).unsqueeze(0).cuda()  # add batch dimension of size 1
    # Inference only: no_grad replaces the removed Variable(volatile=True).
    with torch.no_grad():
        logit = model(x)
    # Softmax over the class dimension, then drop the batch dimension.
    return f.softmax(logit, dim=1).squeeze(0)


def _write_scores(model, directory, label, out):
    """Walk *directory*, score each image and append "label,score" rows to *out*."""
    for root, _dirnames, filenames in os.walk(directory):
        for filename in filenames:
            if not IMAGE_RE.search(filename):
                continue
            filepath = os.path.join(root, filename)
            print(filepath)
            probs = _score_image(model, filepath)
            # .item() gives a plain Python float. str() of a tensor element
            # would write "tensor(...)" into the CSV on PyTorch >= 0.4.
            out.write(label + "," + str(probs[POSITIVE_CLASS].item()) + "\n")


# Load the fine-tuned model ONCE. The original reloaded it for every image
# after building a throwaway resnet34 that was immediately overwritten.
# torch.load unpickles arbitrary objects, so MODEL_PATH must be a trusted file.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
# which refuses a fully pickled nn.Module. Pass weights_only=False there.
model = torch.load(MODEL_PATH)
# Put BatchNorm/Dropout in inference mode whatever mode the model was saved in.
model.eval()

# A single append-mode handle for all rows. The with-block closes (and
# flushes) the file even if scoring fails part-way through.
with open(ROC_CSV, 'a') as out:
    for directory, label in TEST_SETS:
        _write_scores(model, directory, label, out)

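# 댓글 (Comments)

# 이 블로그의 인기 게시물 (Popular posts from this blog)

# 파이썬으로 Homomorphic Filtering 하기 (Homomorphic filtering with Python)

# 파이썬으로 2D FFT/iFFT 하기: numpy 버전 (2D FFT/iFFT with Python: NumPy version)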