// MATLAB MEX gateway for the liblinear `train` command.
//
// Calling form (taken from the help text printed by exit_with_help):
//   model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');
//
// Flow of mexFunction as far as it is visible below:
//   - srand(1) fixes the random seed so repeated runs (cross validation) agree.
//   - More than one requested output (nlhs > 1), or an argument count outside
//     2..4, prints the help text and calls fake_answer(nlhs, plhs).
//   - Labels and instances must both be double, and the instance matrix must be
//     sparse; otherwise an error is printed and fake_answer is called.
//   - parse_command_line fills the global `param` (and cross_validation_flag,
//     nr_fold, bias); read_problem_sparse fills the global `prob` and `x_space`;
//     check_parameter then validates the pair.
//   - With option -v the single output is the scalar returned by
//     do_cross_validation(); otherwise train() runs and the model is converted
//     by model_to_matlab_structure() before being freed.
//   - Every exit path after option parsing calls destroy_param; the paths taken
//     after the problem has been read also free prob.y, prob.x and x_space.
//
// Tolerance defaults: if param.eps still equals INF after option parsing, it is
// set per solver: 0.01 for L2R_LR, L2R_L2LOSS_SVC, L1R_L2LOSS_SVC and L1R_LR;
// 0.001 for L2R_L2LOSS_SVR; 0.1 for L2R_L2LOSS_SVC_DUAL, L2R_L1LOSS_SVC_DUAL,
// MCSVM_CS, L2R_LR_DUAL, L2R_L1LOSS_SVR_DUAL and L2R_L2LOSS_SVR_DUAL.
//
// NOTE(review): this copy of the file looks damaged and will not compile as is.
//   - The line structure is collapsed, so each `//` comment below swallows the
//     code that follows it on the same physical line.
//   - The standard #include directives have lost their header names.
//   - Text appears to be missing wherever a `<` was followed later by a `>`:
//     the body of do_cross_validation, the head of parse_command_line, the body
//     of fake_answer and the head of read_problem_sparse are not visible.
//   - `&param` appears as a pilcrow followed by `m`.
//   Restore from a pristine copy before making code changes.
//
// NOTE(review): items to revisit once the file is restored.
//   - print_string_matlab passes s to mexPrintf as the format string, so a '%'
//     in s would be treated as a conversion; mexPrintf("%s", s) avoids that.
//   - mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1) sizes the copy from the
//     input; the declaration of cmd is not visible here. If it is a CMD_LEN
//     byte buffer, a longer option string would overflow it. TODO confirm.
//   - The realloc results for param.weight_label and param.weight are assigned
//     straight back and never checked for NULL.
#include #include #include #include #include #include "../linear.h" #include "mex.h" #include "linear_model_matlab.h" #ifdef MX_API_VER #if MX_API_VER < 0x07030000 typedef int mwIndex; #endif #endif #define CMD_LEN 2048 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) #define INF HUGE_VAL void print_null(const char *s) {} void print_string_matlab(const char *s) {mexPrintf(s);} void exit_with_help() { mexPrintf( "Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n" "liblinear_options:\n" "-s type : set type of solver (default 1)\n" " for multi-class classification\n" " 0 -- L2-regularized logistic regression (primal)\n" " 1 -- L2-regularized L2-loss support vector classification (dual)\n" " 2 -- L2-regularized L2-loss support vector classification (primal)\n" " 3 -- L2-regularized L1-loss support vector classification (dual)\n" " 4 -- support vector classification by Crammer and Singer\n" " 5 -- L1-regularized L2-loss support vector classification\n" " 6 -- L1-regularized logistic regression\n" " 7 -- L2-regularized logistic regression (dual)\n" " for regression\n" " 11 -- L2-regularized L2-loss support vector regression (primal)\n" " 12 -- L2-regularized L2-loss support vector regression (dual)\n" " 13 -- L2-regularized L1-loss support vector regression (dual)\n" "-c cost : set the parameter C (default 1)\n" "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n" "-e epsilon : set tolerance of termination criterion\n" " -s 0 and 2\n" " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n" " where f is the primal function and pos/neg are # of\n" " positive/negative data (default 0.01)\n" " -s 11\n" " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n" " -s 1, 3, 4 and 7\n" " Dual maximal violation <= eps; similar to libsvm (default 0.1)\n" " -s 5 and 6\n" " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n" " where f is the primal function (default 0.01)\n" " -s 12 and 13\n" " |f'(alpha)|_1 <= eps 
|f'(alpha0)|,\n" " where f is the dual function (default 0.1)\n" "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n" "-wi weight: weights adjust the parameter C of different classes (see README for details)\n" "-v n: n-fold cross validation mode\n" "-q : quiet mode (no outputs)\n" "col:\n" " if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n" ); } // liblinear arguments struct parameter param; // set by parse_command_line struct problem prob; // set by read_problem struct model *model_; struct feature_node *x_space; int cross_validation_flag; int col_format_flag; int nr_fold; double bias; double do_cross_validation() { int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double *target = Malloc(double, prob.l); double retval = 0.0; cross_validation(&prob,¶m,nr_fold,target); if(param.solver_type == L2R_L2LOSS_SVR || param.solver_type == L2R_L1LOSS_SVR_DUAL || param.solver_type == L2R_L2LOSS_SVR_DUAL) { for(i=0;i 2) { mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1); if((argv[argc] = strtok(cmd, " ")) != NULL) while((argv[++argc] = strtok(NULL, " ")) != NULL) ; } // parse options for(i=1;i=argc && argv[i-1][1] != 'q') // since option -q has no parameter return 1; switch(argv[i-1][1]) { case 's': param.solver_type = atoi(argv[i]); break; case 'c': param.C = atof(argv[i]); break; case 'p': param.p = atof(argv[i]); break; case 'e': param.eps = atof(argv[i]); break; case 'B': bias = atof(argv[i]); break; case 'v': cross_validation_flag = 1; nr_fold = atoi(argv[i]); if(nr_fold < 2) { mexPrintf("n-fold cross validation: n must >= 2\n"); return 1; } break; case 'w': ++param.nr_weight; param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight); param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight); param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); 
param.weight[param.nr_weight-1] = atof(argv[i]); break; case 'q': print_func = &print_null; i--; break; default: mexPrintf("unknown option\n"); return 1; } } set_print_string_function(print_func); if(param.eps == INF) { switch(param.solver_type) { case L2R_LR: case L2R_L2LOSS_SVC: param.eps = 0.01; break; case L2R_L2LOSS_SVR: param.eps = 0.001; break; case L2R_L2LOSS_SVC_DUAL: case L2R_L1LOSS_SVC_DUAL: case MCSVM_CS: case L2R_LR_DUAL: param.eps = 0.1; break; case L1R_L2LOSS_SVC: case L1R_LR: param.eps = 0.01; break; case L2R_L1LOSS_SVR_DUAL: case L2R_L2LOSS_SVR_DUAL: param.eps = 0.1; break; } } return 0; } static void fake_answer(int nlhs, mxArray *plhs[]) { int i; for(i=0;i=0) { x_space[j].index = max_index+1; x_space[j].value = prob.bias; j++; } x_space[j++].index = -1; } if(prob.bias>=0) prob.n = max_index+1; else prob.n = max_index; return 0; } // Interface function of matlab // now assume prhs[0]: label prhs[1]: features void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { const char *error_msg; // fix random seed to have same results for each run // (for cross validation) srand(1); if(nlhs > 1) { exit_with_help(); fake_answer(nlhs, plhs); return; } // Transform the input Matrix to libsvm format if(nrhs > 1 && nrhs < 5) { int err=0; if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { mexPrintf("Error: label vector and instance matrix must be double\n"); fake_answer(nlhs, plhs); return; } if(parse_command_line(nrhs, prhs, NULL)) { exit_with_help(); destroy_param(¶m); fake_answer(nlhs, plhs); return; } if(mxIsSparse(prhs[1])) err = read_problem_sparse(prhs[0], prhs[1]); else { mexPrintf("Training_instance_matrix must be sparse; " "use sparse(Training_instance_matrix) first\n"); destroy_param(¶m); fake_answer(nlhs, plhs); return; } // train's original code error_msg = check_parameter(&prob, ¶m); if(err || error_msg) { if (error_msg != NULL) mexPrintf("Error: %s\n", error_msg); destroy_param(¶m); free(prob.y); free(prob.x); 
free(x_space); fake_answer(nlhs, plhs); return; } if(cross_validation_flag) { double *ptr; plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); ptr = mxGetPr(plhs[0]); ptr[0] = do_cross_validation(); } else { const char *error_msg; model_ = train(&prob, ¶m); error_msg = model_to_matlab_structure(plhs, model_); if(error_msg) mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg); free_and_destroy_model(&model_); } destroy_param(¶m); free(prob.y); free(prob.x); free(x_space); } else { exit_with_help(); fake_answer(nlhs, plhs); return; } }