Poprawa klasyfikacji cukrzycy SVM


10

Używam SVM do przewidywania cukrzycy. Używam do tego zestawu danych BRFSS . Zestaw danych ma wymiary i jest przekrzywiony. Procent s w zmiennej docelowej wynosi 11 %, podczas gdy s stanowią pozostałe 89 % .432607×136Y11%N89%

Korzystam tylko 15z 136niezależnych zmiennych z zestawu danych. Jednym z powodów zmniejszenia zbioru danych było zwiększenie liczby próbek treningowych, gdy NApominięto wiersze zawierające s.

Te 15zmienne zostały wybrane po uruchomieniu metody statystyczne, takie jak przypadkowych drzew, regresji logistycznej i dowiedzieć się, które zmienne są istotne z otrzymanych modeli. Na przykład po przeprowadzeniu regresji logistycznej posortowaliśmy p-valuenajbardziej znaczące zmienne.

Czy moja metoda dokonywania wyboru zmiennych jest poprawna? Wszelkie sugestie są mile widziane.

Oto moja Rimplementacja.

library(e1071) # Support Vector Machines

#--------------------------------------------------------------------
# read brfss file (huge 135 MB file)
#--------------------------------------------------------------------
y <- read.csv("http://www.hofroe.net/stat579/brfss%2009/brfss-2009-clean.csv")
indicator <- c("DIABETE2", "GENHLTH", "PERSDOC2", "SEX", "FLUSHOT3", "PNEUVAC3", 
    "X_RFHYPE5", "X_RFCHOL", "RACE2", "X_SMOKER3", "X_AGE_G", "X_BMI4CAT", 
    "X_INCOMG", "X_RFDRHV3", "X_RFDRHV3", "X_STATE");
target <- "DIABETE2";
diabetes <- y[, indicator];

#--------------------------------------------------------------------
# recode DIABETE2
#--------------------------------------------------------------------
x <- diabetes$DIABETE2;
x[x > 1]  <- 'N';
x[x != 'N']  <- 'Y';
diabetes$DIABETE2 <- x; 
rm(x);

#--------------------------------------------------------------------
# remove NA
#--------------------------------------------------------------------
x <- na.omit(diabetes);
diabetes <- x;
rm(x);

#--------------------------------------------------------------------
# reproducible research 
#--------------------------------------------------------------------
set.seed(1612);
nsamples <- 1000; 
sample.diabetes <- diabetes[sample(nrow(diabetes), nsamples), ]; 

#--------------------------------------------------------------------
# split the dataset into training and test
#--------------------------------------------------------------------
ratio <- 0.7;
train.samples <- ratio*nsamples;
train.rows <- c(sample(nrow(sample.diabetes), trunc(train.samples)));

train.set  <- sample.diabetes[train.rows, ];
test.set   <- sample.diabetes[-train.rows, ];

train.result <- train.set[ , which(names(train.set) == target)];
test.result  <- test.set[ , which(names(test.set) == target)];

#--------------------------------------------------------------------
# SVM 
#--------------------------------------------------------------------
formula <- as.formula(factor(DIABETE2) ~ . );
svm.tune <- tune.svm(formula, data = train.set, 
    gamma = 10^(-3:0), cost = 10^(-1:1));
svm.model <- svm(formula, data = train.set, 
    kernel = "linear", 
    gamma = svm.tune$best.parameters$gamma, 
    cost  = svm.tune$best.parameters$cost);

#--------------------------------------------------------------------
# Confusion matrix
#--------------------------------------------------------------------
train.pred <- predict(svm.model, train.set);
test.pred  <- predict(svm.model, test.set);
svm.table <- table(pred = test.pred, true = test.result);
print(svm.table);

Uruchomiłem z (trening = 700 i test = 300 ) próbek, ponieważ w moim laptopie jest szybszy. Macierz pomyłek dla danych testowych ( 300 próbek), które otrzymuję, jest dość zła.1000700300300

    true
pred   N   Y
   N 262  38
   Y   0   0

Muszę poprawić swoje przewidywania dla Yklasy. W rzeczywistości muszę być tak dokładny, jak to możliwe, Ynawet jeśli źle sobie radzę N. Wszelkie sugestie dotyczące poprawy dokładności klasyfikacji byłyby bardzo mile widziane.


Myślę, że twój SVM w ogóle nie działa, ale nie wiem dlaczego! może być też lepiej znormalizować dane ...
user4581,

Y 90%

Rozpocznij od normalizacji danych. Zacznij od tego. Możesz także spróbować przeszukać także nieliniowe jądro, które może dać lepszy wynik. (To zależy od twojego przewidywania granic, być może normalizacja powinna wystarczyć)
404Dreamer_ML

Możesz także spróbować kernlabzamiast e1071- automatycznie normalizuje się i ma pewne cechy heurystyczne, które ułatwiają uruchomienie pierwszego modelu.

Odpowiedzi:


9

Mam 4 sugestie:

  1. Jak wybierasz zmienne, które chcesz uwzględnić w swoim modelu? Być może brakuje niektórych kluczowych wskaźników z większego zestawu danych.
  2. Prawie wszystkie używane wskaźniki (takie jak płeć, palacz itp.) Należy traktować jako czynniki. Traktowanie tych zmiennych jako liczbowych jest błędne i prawdopodobnie przyczynia się do błędu w twoim modelu.
  3. Dlaczego używasz SVM? Czy wypróbowałeś jakieś prostsze metody, takie jak liniowa analiza dyskryminacyjna lub nawet regresja liniowa? Być może proste podejście do większego zestawu danych da lepszy wynik.
  4. Wypróbuj pakiet Caret . Pomoże Ci w sprawdzeniu dokładności modelu, jest równoległy, co pozwoli Ci pracować szybciej i ułatwi przeglądanie różnych typów modeli.

Oto przykładowy kod dla karetki:

library(caret)

#Parallize
library(doSMP)
w <- startWorkers()
registerDoSMP(w)

#Build model
X <- train.set[,-1]
Y <- factor(train.set[,1],levels=c('N','Y'))
model <- train(X,Y,method='lda')

#Evaluate model on test set
print(model)
predY <- predict(model,test.set[,-1])
confusionMatrix(predY,test.set[,1])
stopWorkers(w)

Ten model LDA bije SVM, a ja nawet nie naprawiłem twoich czynników. Jestem pewien, że jeśli przekodujesz płeć, palacza itp. Jako czynniki, uzyskasz lepsze wyniki.


Pojawia się następujący błąd task 1 failed - "could not find function "predictionFunction"". Wiem, że to nie jest forum, ale jeśli masz jakieś uwagi, daj mi znać.
Anand

1
@Anand: Otwórz nową sesję R jako administrator (lub uruchom sudo R na Mac / Linux). Uruchom update.packages.Po zakończeniu zamknij R i ponownie otwórz normalną (nieadministracyjną) sesję. Uruchom kod, z wyjątkiem sekcji „SVM” i „Macierz konfuzji”. Następnie uruchom mój kod. Jeśli nadal występuje błąd, opublikuj wiersz, który zwrócił błąd, wraz z dokładnym błędem.
Zach.

1
@Anand: Upewnij się również, że korzystasz z najnowszej wersji R (2.14) i używasz najnowszej wersji caret. Możesz zaktualizować karetkę, uruchamiając install.packages('caret')ponownie.
Zach.

@Anand: Świetnie! Możesz wprowadzić różne metody dla trainfunkcji, takie jak nb(naiwne bayes), glm(regresja logistyczna) svmLineari svmRadial. Dopasowywanie svm zajmie dużo czasu.
Zach.

3

Jeśli używasz jądra liniowego, możliwe jest, że wybór funkcji jest złym pomysłem, a regularyzacja może zapobiec nadmiernemu dopasowaniu bardziej skutecznie niż wybór funkcji. Należy pamiętać, że granice wydajności, które w przybliżeniu implementuje SVM, są niezależne od wymiaru przestrzeni funkcji, który był jednym z punktów sprzedaży SVM.


2

Ostatnio miałem ten problem i znalazłem kilka rzeczy, które pomagają. Najpierw wypróbuj model Naive Bayesa (pakiet klaR), który czasem daje lepsze wyniki, gdy klasa mniejszościowa w problemie klasyfikacji jest niewielka. Ponadto, jeśli zdecydujesz się pozostać przy SVM, możesz spróbować nadpróbkowania klasy mniejszości. Zasadniczo będziesz chciał podać więcej przykładów klasy mniejszości lub syntetycznie utworzyć przypadki dla klasy mniejszości

Ten dokument: http: //www.it.iitb.ac.in/~kamlesh/Page/Reports/highlySkewed.pdf

Przeprowadzono dyskusję i przykłady tych technik zaimplementowanych w Weka, ale możliwe jest także ich wszczepienie w R.


Dziękuję za pomocne komentarze. Pozwól mi wypróbować twoje sugestie.
Anand

1

Oprócz tego, co już wspomniano, naprawiasz swój najlepszy model, aby używać jądra liniowego. Powinieneś przewidzieć użycie najlepszego modelu, który został dostrojony, w tym tego samego jądra, które zostało użyte / znalezione na etapie dostrajania (zakładam, że jest to RBF, ponieważ tunujesz gamma).

Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.