package it.uniroma2.sag.kelp.examples.demo.mutag;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.label.StringLabel;
import it.uniroma2.sag.kelp.data.manipulator.WLSubtreeMapper;
import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache;
import it.uniroma2.sag.kelp.kernel.graph.ShortestPathKernel;
import it.uniroma2.sag.kelp.kernel.standard.LinearKernelCombination;
import it.uniroma2.sag.kelp.kernel.vector.LinearKernel;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.BinaryCSvmClassification;
import it.uniroma2.sag.kelp.utils.ExperimentUtils;
import it.uniroma2.sag.kelp.utils.evaluation.BinaryClassificationEvaluator;
import java.util.List;

/**
 * Demo: binary classification of the MUTAG graph dataset with KeLP.
 *
 * <p>Each example's graph (representation {@value #GRAPH_REPRESENTATION_NAME}) is additionally
 * mapped to a vector (representation {@value #VECTORIAL_LINEARIZATION_NAME}) by
 * {@link WLSubtreeMapper}. A C-SVM is trained on the sum of a {@link LinearKernel} over that
 * vector and a {@link ShortestPathKernel} over the graph, and is evaluated with
 * {@value #N_FOLDS}-fold cross validation, printing per-fold and mean accuracy.
 */
public class MutagClassification {
    /** Name of the graph representation read from the dataset file. */
    private static final String GRAPH_REPRESENTATION_NAME = "inline";
    /** Name of the vector representation produced by the WL subtree mapper. */
    private static final String VECTORIAL_LINEARIZATION_NAME = "wl";
    /** Dataset file used when no path is passed on the command line. */
    private static final String DEFAULT_DATASET_PATH = "src/main/resources/mutag/mutag.txt";
    /** Label value treated as the positive class. */
    private static final String POSITIVE_LABEL = "1";
    /** Number of cross-validation folds. */
    private static final int N_FOLDS = 10;
    /** Integer parameter passed to WLSubtreeMapper (presumably the max WL iteration depth — confirm). */
    private static final int WL_MAPPER_DEPTH = 4;

    /**
     * Loads the dataset, prints its statistics, runs n-fold cross validation and prints accuracies.
     *
     * @param args optional; {@code args[0]} overrides the dataset path
     *     (defaults to {@value #DEFAULT_DATASET_PATH})
     * @throws Exception propagated from dataset loading or the KeLP learning pipeline
     */
    public static void main(String[] args) throws Exception {
        String datasetPath = args.length > 0 ? args[0] : DEFAULT_DATASET_PATH;

        SimpleDataset dataset = new SimpleDataset();
        dataset.populate(datasetPath);

        // A single label instance serves statistics, evaluation and training alike.
        StringLabel positiveLabel = new StringLabel(POSITIVE_LABEL);
        printStatistics(dataset, positiveLabel);

        // Adds the vectorial representation consumed by the LinearKernel below.
        dataset.manipulate(
                new WLSubtreeMapper(GRAPH_REPRESENTATION_NAME, VECTORIAL_LINEARIZATION_NAME, WL_MAPPER_DEPTH));

        LinearKernelCombination kernel = buildKernel(dataset.getNumberOfExamples());
        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator(positiveLabel);
        // The two 1.0f arguments are the C-SVM cost parameters (per KeLP API — confirm semantics).
        BinaryCSvmClassification svm = new BinaryCSvmClassification(kernel, positiveLabel, 1.0f, 1.0f);

        List<?> foldEvaluators = ExperimentUtils.nFoldCrossValidation(N_FOLDS, svm, dataset, evaluator);

        // Iterate over the folds actually returned instead of assuming exactly N_FOLDS entries.
        float accuracySum = 0.0f;
        int foldCount = 0;
        for (Object foldResult : foldEvaluators) {
            float accuracy = ((BinaryClassificationEvaluator) foldResult).getPerformanceMeasure("accuracy");
            foldCount++;
            System.out.println("fold " + foldCount + " accuracy: " + accuracy);
            accuracySum += accuracy;
        }
        float meanAccuracy = foldCount == 0 ? Float.NaN : accuracySum / foldCount;
        System.out.println("MEAN ACC: " + meanAccuracy);
    }

    /** Prints example, positive and negative counts of {@code dataset} with respect to {@code positiveLabel}. */
    private static void printStatistics(SimpleDataset dataset, StringLabel positiveLabel) {
        System.out.println("Training set statistics");
        System.out.println("Examples number " + dataset.getNumberOfExamples());
        System.out.println("Positive examples " + dataset.getNumberOfPositiveExamples(positiveLabel));
        System.out.println("Negative examples " + dataset.getNumberOfNegativeExamples(positiveLabel));
    }

    /**
     * Builds the equally weighted sum of a linear kernel on the WL vectors and a shortest-path
     * kernel on the graphs, with a fixed-size cache of {@code cacheSize} entries.
     */
    private static LinearKernelCombination buildKernel(int cacheSize) {
        LinearKernelCombination combination = new LinearKernelCombination();
        combination.addKernel(1.0f, new LinearKernel(VECTORIAL_LINEARIZATION_NAME));
        combination.addKernel(1.0f, new ShortestPathKernel(GRAPH_REPRESENTATION_NAME));
        combination.setKernelCache(new FixSizeKernelCache(cacheSize));
        return combination;
    }
}
