;;; LIBSVM interface
;;; <http://www.csie.ntu.edu.tw/~cjlin/libsvm/>
;;;
;;; Copyright (C) 2006 by Sam Steingold
;;; This is Free Software, covered by the GNU GPL (v2)
;;; See http://www.gnu.org/copyleft/gpl.html

;; The LIBSVM package.  It is declared `:modern', so mixed-case names such
;; as C_SVC or C are presumably case-sensitive here -- NOTE(review): see the
;; CLISP notes on modern packages to confirm.
;; The defining macros (def-c-enum, def-c-struct, def-call-out, def-c-type,
;; defun) are taken from EXPORTING and shadow the CL/FFI ones -- presumably
;; so that every name defined with them is exported; confirm against the
;; EXPORTING package.
(defpackage "LIBSVM"
  (:modern t) (:use "CL" "FFI")
  (:shadowing-import-from "EXPORTING" #:def-c-enum #:def-c-struct
                          #:def-call-out #:def-c-type #:defun))
(in-package "LIBSVM")
;; attach the string "libsvm" as the SYS::IMPNOTES documentation of the package
(setf (documentation (find-package "LIBSVM") 'sys::impnotes) "libsvm")

;; every foreign definition below that does not say otherwise uses :stdc
(default-foreign-language :stdc)
;; Name of the shared library passed as :library to every call-out:
;; "svm.so" merged with the pathname of the file being loaded.
;; NOTE(review): *load-pathname* is only bound while loading, so this form
;; presumably has to be evaluated at load time -- confirm for compilation.
(defconstant svm-so
  (namestring (merge-pathnames "svm.so" *load-pathname*)))
;;; types and constants

;; one sparse feature entry: an integer index and a double value
;; (declared as a `list' c-struct, i.e., (index value) on the Lisp side)
(def-c-type node (c-struct list (index int) (value double-float)))
;; a training set: L records, their targets Y and their predictors X,
;; where each element of X points to an array of `node's
(def-c-type problem (c-struct list
  (l int)                               ; number of records
  (y (c-array-ptr double-float))        ; of length l (targets)
  (x (c-array-ptr (c-array-ptr node))))); of length l (predictors)
;; enumeration constants; C_SVC and RBF serve as the default values of the
;; int slots `svm_type' and `kernel_type' in `make-parameter'
(def-c-enum svm_type C_SVC NU_SVC ONE_CLASS EPSILON_SVR NU_SVR)
(def-c-enum kernel_type LINEAR POLY RBF SIGMOID PRECOMPUTED)
;; Training/kernel parameters.  The slot order matters: `make-parameter'
;; reads its :list argument positionally in exactly this order (0..14).
;; svm_type and kernel_type are plain ints holding the enum constants above.
(def-c-type parameter (c-struct list
  (svm_type int)
  (kernel_type int)
  (degree int)                  ; for poly
  (gamma double-float)          ; for poly/rbf/sigmoid
  (coef0 double-float)          ; for poly/sigmoid
  ;; these are for training only
  (cache_size double-float)     ; in MB
  (eps double-float)            ; stopping criteria
  (C double-float)              ; for C_SVC, EPSILON_SVR and NU_SVR
  (nr_weight int)               ; for C_SVC
  (weight_label (c-array-ptr int)) ; for C_SVC
  (weight (c-array-ptr double-float)) ; for C_SVC
  (nu double-float)                ; for NU_SVC, ONE_CLASS, and NU_SVR
  (p double-float)                 ; for EPSILON_SVR
  (shrinking int)                  ; use the shrinking heuristics
  (probability int)))              ; do probability estimates

;; the defaults are the same as those in svm-train (see README)
;; Allocate a foreign `parameter' object and fill all its slots.
;; Arguments:
;;  :list - optional list of 15 values in `parameter' slot order; when it is
;;    supplied, it provides the default of every other argument (a list
;;    that is too short yields NIL for the missing positions via NTH);
;;  svm_type, kernel_type, ..., probability - one argument per slot.
;;    Because of the ((name var) default) lambda list syntax these are named
;;    by the symbols themselves (e.g., 'svm_type), not by :keywords.
;; Signals a correctable error (ASSERT) unless nr_weight equals the length
;; of both weight_label and weight.
;; Returns the foreign object; release it with `destroy-parameter'.
(defun make-parameter (&key (list () list-p)
                       ((svm_type svm_type) (if list-p (nth 0 list) C_SVC))
                       ((kernel_type kernel_type) (if list-p (nth 1 list) RBF))
                       ((degree degree) (if list-p (nth 2 list) 3))
                       ((gamma gamma) (if list-p (nth 3 list) 0d0)) ; 1/maxindex
                       ((coef0 coef0) (if list-p (nth 4 list) 0d0))
                       ((cache_size cache_size) (if list-p (nth 5 list) 1d2))
                       ((eps eps) (if list-p (nth 6 list) 1d-3))
                       ((C C) (if list-p (nth 7 list) 1d0))
                       ((nr_weight nr_weight) (if list-p (nth 8 list) 0))
                       ((weight_label weight_label) (if list-p (nth 9 list) ()))
                       ((weight weight) (if list-p (nth 10 list) ()))
                       ((nu nu) (if list-p (nth 11 list) 5d-1))
                       ((p p) (if list-p (nth 12 list) 1d-1))
                       ((shrinking shrinking) (if list-p (nth 13 list) 1))
                       ((probability probability) (if list-p (nth 14 list) 0)))
  ;; the two weight lists must have exactly nr_weight elements each
  (assert (= nr_weight (length weight_label)) (nr_weight weight_label)
          "~S: nr_weight=~:D /= ~:D=(length weight_label)"
          'make-parameter nr_weight (length weight_label))
  (assert (= nr_weight (length weight)) (nr_weight weight)
          "~S: nr_weight=~:D /= ~:D=(length weight)"
          'make-parameter nr_weight (length weight))
  (let ((ret (allocate-shallow 'parameter)))
    (setf (slot (foreign-value ret) 'svm_type) svm_type
          (slot (foreign-value ret) 'kernel_type) kernel_type
          (slot (foreign-value ret) 'degree) degree
          (slot (foreign-value ret) 'gamma) gamma
          (slot (foreign-value ret) 'coef0) coef0
          (slot (foreign-value ret) 'cache_size) cache_size
          (slot (foreign-value ret) 'eps) eps
          (slot (foreign-value ret) 'C) C
          (slot (foreign-value ret) 'nr_weight) nr_weight
          ;; with no weights, AND yields NIL for both pointer slots.
          ;; NOTE(review): this slot is declared (c-array-ptr int), yet the
          ;; labels go through `list-to-vector', which allocates
          ;; double-floats, and are cast as a double-float array; this looks
          ;; like the reason for the FIXME -- confirm before relying on
          ;; class weights.
          (slot (foreign-value ret) 'weight_label)
          (and (plusp nr_weight)
               (offset (foreign-value (list-to-vector weight_label)) 0 ; FIXME
                       `(c-array double-float ,nr_weight)))
          (slot (foreign-value ret) 'weight)
          (and (plusp nr_weight)
               (offset (foreign-value (list-to-vector weight)) 0 ; FIXME
                       `(c-array double-float ,nr_weight)))
          (slot (foreign-value ret) 'nu) nu
          (slot (foreign-value ret) 'p) p
          (slot (foreign-value ret) 'shrinking) shrinking
          (slot (foreign-value ret) 'probability) probability)
    ;; you must call (destroy-parameter ret) yourself!
    ret))

;; Release a foreign `parameter' object, e.g., one made by `make-parameter'.
;; The order matters: first `destroy-param' (C svm_destroy_param) gets the
;; still valid object, then the object itself is freed, and finally the
;; Lisp handle is marked invalid so that it cannot be used again.
;; Returns NIL.
(defun destroy-parameter (parameter)
  (destroy-param parameter)
  (foreign-free parameter)
  (setf (ffi:validp parameter) nil))

;; a trained model is handled as an opaque C pointer: it is returned by
;; `train' and `load-model' and only ever passed back to the library
(def-c-type model c-pointer)

;; Lisp TRAIN --> C svm_train():
;; takes a `problem' and a `parameter', returns a new `model' pointer
(def-call-out train
  (:name "svm_train")
  (:library svm-so)
  (:arguments (problem (c-ptr problem))
              (param (c-ptr parameter)))
  (:return-type model))

;; raw binding (defined with FFI:DEF-CALL-OUT, not the EXPORTING variant):
;; the C function writes its results into the caller-supplied TARGET buffer
(ffi:def-call-out svm_cross_validation
  (:library svm-so)
  (:arguments (problem (c-ptr problem))
              (param (c-ptr parameter))
              (nr_fold int)
              (target c-pointer))
  (:return-type nil))
;; Run svm_cross_validation with NR_FOLD folds on PROBLEM using PARAM.
;; A temporary array with one double-float per record of PROBLEM (its `l'
;; slot) receives the results and is returned converted to a Lisp value.
(defun cross-validation (problem param nr_fold)
  (let ((size (slot (foreign-value problem) 'l)))
    (with-foreign-object (target (list 'c-array 'double-float size))
      (svm_cross_validation problem param nr_fold target)
      (foreign-value target))))

;; Lisp SAVE-MODEL --> C svm_save_model():
;; takes a file name and a `model', returns the C int result
(def-call-out save-model
  (:name "svm_save_model")
  (:library svm-so)
  (:arguments (model_file_name c-string)
              (model model))
  (:return-type int))
;; Lisp LOAD-MODEL --> C svm_load_model():
;; takes a file name, returns a `model' pointer
(def-call-out load-model
  (:name "svm_load_model")
  (:library svm-so)
  (:arguments (model_file_name c-string))
  (:return-type model))

;; Lisp GET-SVM-TYPE --> C svm_get_svm_type():
;; returns the svm type of MODEL as an int
(def-call-out get-svm-type
  (:name "svm_get_svm_type")
  (:library svm-so)
  (:arguments (model model))
  (:return-type int))
;; Lisp GET-NR-CLASS --> C svm_get_nr_class():
;; returns the number of classes of MODEL as an int
(def-call-out get-nr-class
  (:name "svm_get_nr_class")
  (:library svm-so)
  (:arguments (model model))
  (:return-type int))
;; raw binding (defined with FFI:DEF-CALL-OUT, not the EXPORTING variant):
;; the C function fills the caller-supplied LABEL buffer
(ffi:def-call-out svm_get_labels
  (:library svm-so)
  (:arguments (model model)
              (label c-pointer))
  (:return-type nil))
;; Return the labels of MODEL: a temporary int array with
;; (get-nr-class model) elements is filled by svm_get_labels
;; and converted to a Lisp value.
(defun get-labels (model)
  (let ((nr-class (get-nr-class model)))
    (with-foreign-object (label (list 'c-array 'int nr-class))
      (svm_get_labels model label)
      (foreign-value label))))
;; Lisp GET-SVR-PROBABILITY --> C svm_get_svr_probability():
;; takes a `model', returns a double-float
(def-call-out get-svr-probability
  (:name "svm_get_svr_probability")
  (:library svm-so)
  (:arguments (model model))
  (:return-type double-float))

;; The C function svm_predict_values is bound twice because the size of its
;; output depends on the model:
;;  svm_predict_values1 - DEC_VALUES is a single double-float :out argument,
;;    so the Lisp function returns that number;
;;  svm_predict_values2 - DEC_VALUES is a caller-supplied buffer.
(ffi:def-call-out svm_predict_values1 (:name "svm_predict_values")
  (:arguments (model model) (x (c-array-ptr node))
              (dec_values (c-ptr double-float) :out))
  (:return-type nil) (:library svm-so))
(ffi:def-call-out svm_predict_values2 (:name "svm_predict_values")
  (:arguments (model model) (x (c-array-ptr node))
              (dec_values c-pointer))
  (:return-type nil) (:library svm-so))
;; Return the decision value(s) of MODEL for the node array X:
;;  - a single double-float for ONE_CLASS, EPSILON_SVR and NU_SVR models;
;;  - otherwise an array of nr_class*(nr_class-1)/2 double-floats.
(defun predict-values (model x)
  ;; `get-svm-type' returns an int, so it has to be compared with the
  ;; _values_ of the enum constants: the clauses of CASE are not evaluated,
  ;; i.e., (case type ((ONE_CLASS ...) ...)) would compare the int with the
  ;; symbols themselves and never select the single-value branch.
  (if (member (get-svm-type model) (list ONE_CLASS EPSILON_SVR NU_SVR))
      (svm_predict_values1 model x)
      (let* ((nr-class (get-nr-class model))
             ;; one decision value per pair of classes
             (len (/ (* nr-class (1- nr-class)) 2)))
        (with-foreign-object (dec-values `(c-array double-float ,len))
          (svm_predict_values2 model x dec-values)
          (foreign-value dec-values)))))

;; Lisp PREDICT --> C svm_predict():
;; takes a `model' and an array of `node's, returns a double-float.
;; X is declared as (c-array-ptr node), like in svm_predict_values and
;; svm_predict_probability: the C side receives a node _array_, while
;; (c-ptr node) would pass just a single node.
;; NOTE(review): the library presumably expects the array to end with
;; a node whose index is -1 (cf. `alist-to-nodes') -- confirm with README.
(def-call-out predict (:library svm-so) (:name "svm_predict")
  (:arguments (model model) (x (c-array-ptr node)))
  (:return-type double-float))
;; raw binding (defined with FFI:DEF-CALL-OUT, not the EXPORTING variant):
;; the C function fills the caller-supplied PROB_ESTIMATES buffer
;; and returns a double-float
(ffi:def-call-out svm_predict_probability (:library svm-so)
  (:arguments (model model) (x (c-array-ptr node))
              (prob_estimates c-pointer))
  (:return-type double-float))
;; Compute the probability estimates of MODEL for the node array X.
;; Returns 2 values:
;;  1. the array of (get-nr-class model) double-float estimates;
;;  2. the double-float returned by svm_predict_probability
;;     (presumably the predicted label -- see the libsvm README).
(defun predict-probability (model x)
  ;; the estimates are C doubles, so the buffer must be an array of
  ;; double-float: an int array is too small for nr_class doubles and
  ;; would be decoded as garbage integers
  (with-foreign-object (prob_estimates
                        `(c-array double-float ,(get-nr-class model)))
    (let ((label (svm_predict_probability model x prob_estimates)))
      ;; the estimates stay the primary value for existing callers
      (values (foreign-value prob_estimates) label))))

;; raw binding (defined with FFI:DEF-CALL-OUT, not the EXPORTING variant)
(ffi:def-call-out svm_destroy_model (:library svm-so)
  (:arguments (model model)) (:return-type nil))
;; Release MODEL: hand it to svm_destroy_model while it is still valid,
;; then mark the Lisp pointer invalid so that it cannot be used again.
;; Returns NIL.
(defun destroy-model (model)
  (svm_destroy_model model)
  (setf (ffi:validp model) nil))
;; Lisp DESTROY-PARAM --> C svm_destroy_param().
;; The argument is passed as a raw c-pointer (the (c-ptr parameter)
;; declaration is commented out); `destroy-parameter' calls this with the
;; foreign object before freeing it.
(def-call-out destroy-param (:library svm-so) (:name "svm_destroy_param")
  (:arguments (param c-pointer #|(c-ptr parameter)|#)) (:return-type nil))

;; Lisp CHECK-PARAMETER --> C svm_check_parameter():
;; takes a `problem' and a `parameter', returns a C string
;; (presumably a description of what is wrong -- see the libsvm README)
(def-call-out check-parameter
  (:name "svm_check_parameter")
  (:library svm-so)
  (:arguments (problem (c-ptr problem))
              (param (c-ptr parameter)))
  (:return-type c-string))
;; Lisp CHECK-PROBABILITY-MODEL --> C svm_check_probability_model():
;; takes a `model', returns an int
(def-call-out check-probability-model
  (:name "svm_check_probability_model")
  (:library svm-so)
  (:arguments (model model))
  (:return-type int))

;;; high-level helpers
(defun alist-to-nodes (alist)
  ;; Convert the Lisp list ((index . value) ...) to a freshly allocated
  ;; C array of `node's [index;value]... followed by the terminator [-1;*]
  ;; (only the index of the terminator is set here).
  ;; The caller owns the result and must FOREIGN-FREE it.
  (let* ((size (length alist))
         ;; one extra node for the terminator
         (nodes (allocate-shallow 'node :count (1+ size)))
         (pos 0))
    (with-c-place (place nodes)
      (dolist (pair alist)
        (setf (slot (element place pos) 'index) (car pair)
              (slot (element place pos) 'value) (cdr pair))
        (incf pos))
      (setf (slot (element place size) 'index) -1))
    nodes))

(defun list-to-vector (list)
  ;; Copy the elements of the Lisp LIST into a freshly allocated C array of
  ;; double-floats of the same length.
  ;; The caller owns the result and must FOREIGN-FREE it.
  (let ((buffer (allocate-shallow 'double-float :count (length list)))
        (pos 0))
    (with-c-place (cells buffer)
      (dolist (item list)
        (setf (element cells pos) item)
        (incf pos)))
    buffer))

;; Allocate a foreign `problem' object.
;; Arguments:
;;  :l - the number of records (default 0);
;;  :y - list of L targets, copied with `list-to-vector';
;;  :x - list of L predictors, converted with `alist-to-nodes'.
;; Signals a correctable error (ASSERT) unless both lists have length L.
;; Returns the foreign object; release it with `destroy-problem'.
(defun make-problem (&key (l 0) y x)
  (assert (= l (length y)) (l y)
          "~S: l=~:D /= ~:D=(length y)" 'make-problem l (length y))
  (assert (= l (length x)) (l x)
          "~S: l=~:D /= ~:D=(length x)" 'make-problem l (length x))
  (let ((ret (allocate-shallow 'problem)))
    (setf (slot (foreign-value ret) 'l) l
          (slot (foreign-value ret) 'y)
          (offset (foreign-value (list-to-vector y)) 0 ; FIXME
                  `(c-array double-float ,l))
          ;; NOTE(review): `alist-to-nodes' is applied to X as a whole and
          ;; treats each element as one (index . value) pair, whereas
          ;; `load-problem' passes a list with one alist per record and the
          ;; slot is an array of pointers to node arrays; this looks like
          ;; the reason for the FIXME -- confirm before relying on it.
          (slot (foreign-value ret) 'x)
          (offset (foreign-value (alist-to-nodes x)) 0 ; FIXME
                  `(c-array (c-ptr node) ,l)))
    ;; you must call (destroy-problem ret) yourself!
    ;; -- but remember that `model' returned by `train' uses `problem',
    ;;    so you cannot free `problem' until you free all `model's
    ret))

;; Release a foreign `problem' object, e.g., one made by `make-problem':
;; free what its `y' and `x' slots refer to, then the object itself, and
;; finally mark the Lisp handle invalid.  Returns NIL.
;; NOTE(review): only the top-level `x' is passed to FOREIGN-FREE; whether
;; the per-record node arrays need to be freed separately depends on how
;; `make-problem' ends up allocating them -- confirm.
(defun destroy-problem (problem)
  (foreign-free (slot (foreign-value problem) 'y))
  (foreign-free (slot (foreign-value problem) 'x))
  (foreign-free problem)
  (setf (ffi:validp problem) nil))

(defun load-problem (file &key (log *standard-output*))
  ;; load the `problem' object from a standard libsvm/svmlight problem file,
  ;; one record per line:
  ;; target index1:value1 index2:value2 ...
  ;; Empty lines and lines starting with #\# are skipped.
  ;; Arguments:
  ;;  FILE - the file to read;
  ;;  :LOG - the stream for progress messages, or NIL for no messages.
  ;; Returns 2 values:
  ;;  1. the foreign object made by `make-problem' from the records
  ;;     in file order (release it with `destroy-problem');
  ;;  2. the largest feature index seen in the file (0 if there is none).
  (let ((len 0) y x (maxindex 0)
        ;; read "0.1" directly as a double-float: reading it in another
        ;; default format and widening it afterwards would keep the
        ;; rounding error of the narrower format
        (*read-default-float-format* 'double-float)
        ;; the file is external data: do not let "#." evaluate code
        (*read-eval* nil))
    (with-open-file (in file)
      (when log
        (format log "~&;; ~S(~S): ~:D byte~:P..."
                'load-problem file (file-length in))
        (force-output log))
      (loop :for line = (read-line in nil nil) :while line
        :unless (or (zerop (length line)) (char= #\# (aref line 0))) :do
        (incf len)
        ;; the target comes first; POS is where the predictors start
        (multiple-value-bind (target pos) (read-from-string line)
          (push (coerce target 'double-float) y)
          (push
           ;; collect the (index . value) pairs of this record:
           ;; the integer index is between POS and the next colon,
           ;; the value is read right after the colon
           (loop :with index :and value
             :for colon = (position #\: line :start pos) :while colon :do
             (multiple-value-setq (index pos)
               (parse-integer line :start pos :end colon))
             (multiple-value-setq (value pos)
               (read-from-string line t nil :start (1+ colon)))
             (setq maxindex (max maxindex index))
             :collect (cons (coerce index 'integer)
                            (coerce value 'double-float)))
           x)))
      (when log (format log "~:D record~:P~%" len)))
    ;; the records were pushed, so they are in reverse file order
    (values (make-problem :l len :y (nreverse y) :x (nreverse x))
            maxindex)))

;; Write the foreign PROBLEM to FILE, one record per line:
;; the target followed by " index:value" for each node of the record,
;; stopping at the first node whose index is -1.
;; Arguments:
;;  FILE - the file to create (no :if-exists is passed, so an existing
;;    file is handled according to the implementation default);
;;  PROBLEM - a foreign `problem' object;
;;  :LOG - the stream for progress messages, or NIL for no messages.
;; SIZE (&aux) is the number of records, taken from the `l' slot.
(defun save-problem (file problem &key (log *standard-output*)
                     &aux (size (slot (foreign-value problem) 'l)))
  (with-open-file (out file :direction :output)
    (when log
      (format log "~&;; ~S(~S): ~:D record~:P..." 'save-problem file size)
      (force-output log))
    ;; NOTE(review): `alist-to-nodes' and `list-to-vector' use WITH-C-PLACE
    ;; for this kind of access; confirm that WITH-C-VAR accepts a place
    ;; form such as (deref ...) in this position.
    (with-c-var (y (deref (slot (foreign-value problem) 'y)))
      (with-c-var (x (deref (slot (foreign-value problem) 'x)))
        (dotimes (i size)
          (format out "~G" (element y i))
          (with-c-var (nodes (element x i))
            ;; walk the node array up to the -1 terminator
            (loop :for j :upfrom 0 :for node = (element nodes j)
              :for index = (slot (foreign-value node) 'index)
              :for value = (slot (foreign-value node) 'value)
              :while (/= index -1) :do
              (format out " ~D:~G" index value)))
          (terpri out))))
    ;; report the size of the file just written
    (when log (format log "~:D byte~:P~%" (file-length out)))))

;; announce the module: feature :libsvm and module name "libsvm"
(pushnew :libsvm *features*)
(provide "libsvm")
;; register LIBSVM in CUSTOM:*SYSTEM-PACKAGE-LIST* and lock the package;
;; this must stay last because the definitions above modify the package
(pushnew "LIBSVM" custom:*system-package-list* :test #'string=)
(setf (ext:package-lock "LIBSVM") t)
