[Processing] SVM으로 스팸문자 판별하기

import psvm.*;

// Trained SVM classifier; created and trained in setup().
SVM model;

// Training corpus. Entry i is paired with labels[i] (1 = spam, 0 = OK).
String[] trainingDocuments = {
  "FREE NATIONAL TREASURE",                              //  1. spam
  "FREE TV for EVERY visitor",                           //  2. spam
  "Peter and Stewie are hilarious",                      //  3. OK
  "AS SEEN ON NATIONAL TV",                              //  4. spam
  "Buy Viagra Now!",                                     //  5. spam
  "New episode rocks, Peter and Stewie are hilarious",   //  6. OK
  "Peter is my fav!",                                    //  7. OK
  "Buy Viagra",                                          //  8. spam
  "Free viagra for you.",                                //  9. spam
  "Hello, Doctor",                                       // 10. OK
  "If she follow me",                                    // 11. OK
  "delete this email, please!",                          // 12. OK
  "50% sale, Buy Now~",                                  // 13. spam
  "Buy this now",                                        // 14. spam
  "Buy this cool phone"                                  // 15. spam
};

// Class label for each training document, index-aligned with
// trainingDocuments (1 = spam, 0 = OK), five per row.
int[] labels = {
  1, 1, 0, 1, 1,   //  1-5
  0, 0, 1, 1, 0,   //  6-10
  0, 0, 1, 1, 1    // 11-15
};

// Unlabeled documents classified by the trained model and drawn on screen.
String[] testDocuments = {
  "FREE lotterry for the NATIONAL TREASURE!!!",   // expected: spam
  "Stewie is hilarious",                          // expected: OK
  // NOTE(review): "...nhilarious" looks like a "\n" that lost its backslash
  // — confirm. Splitting on ' ' yields one merged token either way.
  "Poor Peter...nhilarious",                      // expected: OK
  "I love this show",
  "Free gold just click HERE.",
  "Best episode ever",
  "Buy Viagra",
  "I love golden cats",
  "FREE GOLD"
};

// Vocabulary: the word at index i is feature i in every vector
// produced by buildVector(). Filled by buildGlobalDictionary().
ArrayList<String> globalDictionary;

// Make Dict from train & test data func
// Build the vocabulary from every training and test document. Each distinct
// normalized word is appended to globalDictionary in first-seen order, and its
// index becomes its feature index in buildVector(). The result is printed to
// the console.
//
// Normalization: strip non-word characters (regex \W), then lower-case with
// Locale.ROOT. This is the same strip-then-lowercase order buildVector() uses.
// Locale.ROOT keeps a non-English default locale (e.g. Turkish, where "I"
// lower-cases to a dotless "ı") from producing different words than the
// vectorizer. Tokens that normalize to "" (punctuation-only words, or the gap
// between two consecutive spaces) are skipped so they never become a feature.
void buildGlobalDictionary() {
  globalDictionary = new ArrayList<String>();
  addWordsToDictionary(trainingDocuments);
  addWordsToDictionary(testDocuments);
  println(globalDictionary);
}

// Append each not-yet-present normalized word of `documents` to globalDictionary.
void addWordsToDictionary(String[] documents) {
  for (String doc : documents) {
    for (String token : split(doc, ' ')) {
      String word = token.replaceAll("\\W", "").toLowerCase(java.util.Locale.ROOT);
      if (word.length() > 0 && !globalDictionary.contains(word)) {
        globalDictionary.add(word);
      }
    }
  }
}

// Make Feature Vector func
// Convert a document into a binary bag-of-words feature vector over
// globalDictionary. Element i is 1 if dictionary word i occurs in the input,
// and 0 otherwise. The length always equals globalDictionary.size(), so
// buildGlobalDictionary() must have run first.
//
// Input words are split on single spaces, stripped of non-word characters
// (regex \W), and lower-cased with Locale.ROOT so the result does not depend
// on the JVM's default locale. Empty tokens are ignored. A null input yields
// an all-zero vector.
int[] buildVector(String input) {
  ArrayList<String> normalizedWords = new ArrayList<String>();
  if (input != null) {
    for (String token : split(input, ' ')) {
      String word = token.replaceAll("\\W", "").toLowerCase(java.util.Locale.ROOT);
      if (word.length() > 0) {
        normalizedWords.add(word);
      }
    }
  }

  // One slot per dictionary word, in dictionary order.
  int[] result = new int[globalDictionary.size()];
  for (int i = 0; i < result.length; i++) {
    result[i] = normalizedWords.contains(globalDictionary.get(i)) ? 1 : 0;
  }
  return result;
}

// -------------------------------------------------------
// Processing entry point, run once. It opens the window, builds the
// vocabulary, vectorizes the training set, trains the SVM, and logs the
// model's result for each test document to the console.
void setup() {
  size(500, 500);
  buildGlobalDictionary();

  // Each training document needs exactly one label.
  if (labels.length != trainingDocuments.length) {
    throw new IllegalStateException("labels.length (" + labels.length
      + ") != trainingDocuments.length (" + trainingDocuments.length + ")");
  }

  int[][] trainingVectors = buildVectors(trainingDocuments);

  model = new SVM(this);
  SVMProblem problem = new SVMProblem();
  // The feature count must describe the vectors actually supplied: one
  // feature per dictionary word, not a fixed constant.
  problem.setNumFeatures(globalDictionary.size());
  problem.setSampleData(labels, trainingVectors);
  model.train(problem);

  // Log each test document's prediction once at startup.
  int[][] testVectors = buildVectors(testDocuments);
  for (int i = 0; i < testDocuments.length; i++) {
    double score = model.test(testVectors[i]);
    println("testing: " + testDocuments[i] + " -> result: " + score);
  }
}

// Vectorize each document with buildVector(). Row i corresponds to documents[i].
int[][] buildVectors(String[] documents) {
  // Rows are assigned below, so inner arrays are not pre-allocated.
  int[][] vectors = new int[documents.length][];
  for (int i = 0; i < documents.length; i++) {
    vectors[i] = buildVector(documents[i]);
  }
  return vectors;
}

//--------------------------------------------------
// Model scores for testDocuments, computed on the first frame. In this sketch
// the model is only trained in setup(), so the scores cannot change between
// frames and need not be recomputed each time.
double[] cachedTestScores;

// Processing render loop. It draws each test document on a black background,
// in green when the model's score is 0 (not spam) and in red otherwise (spam),
// followed by the raw score.
void draw() {
  background(0);

  if (cachedTestScores == null) {
    cachedTestScores = new double[testDocuments.length];
    for (int i = 0; i < testDocuments.length; i++) {
      cachedTestScores[i] = model.test(buildVector(testDocuments[i]));
    }
  }

  for (int i = 0; i < testDocuments.length; i++) {
    double score = cachedTestScores[i];
    if (score == 0) {
      fill(0, 255, 0);   // not spam: green text
    } else {
      fill(255, 0, 0);   // spam: red text
    }
    text(testDocuments[i] + " [score: " + score + "]", 10, 20 * i + 20);
  }
}

// Save the model and dictionary when a key is pressed ---------------------------------
// Write the vocabulary, one word per line, to "dictionary.txt" inside the
// sketch's data folder.
void saveGlobalDictionary() {
  String[] lines = globalDictionary.toArray(new String[globalDictionary.size()]);
  saveStrings(dataPath("dictionary.txt"), lines);
}

// On any key press, save the vocabulary first and then the trained model
// (to "model.txt").
void keyPressed() {
  saveGlobalDictionary();
  model.saveModel("model.txt");
}

댓글

이 블로그의 인기 게시물

파이썬으로 Homomorphic Filtering 하기

파이썬으로 2D FFT/iFFT 하기: numpy 버전