package it.uniroma2.sag.kelp.examples.main;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.BinaryClassificationEvaluator;
import it.uniroma2.sag.kelp.utils.evaluation.Evaluator;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import java.io.File;
import java.util.List;
import org.slf4j.impl.SimpleLogger;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/main/ClassificationDemo.class */
/**
 * Command-line demo that trains a classification algorithm on one dataset and
 * evaluates the learned classifier on another.
 *
 * <p>Expected arguments, in order: the training file path, the test file path
 * and the path of a file describing the learning algorithm, which is
 * deserialized through {@link JacksonSerializerWrapper}.
 */
public class ClassificationDemo {
    /**
     * Loads the learning algorithm and both datasets, prints a few dataset
     * statistics, trains the algorithm on the training set and prints the
     * accuracy (plus F1 when the training set exposes exactly two labels)
     * measured on the test set.
     *
     * @param strArr command-line arguments: {@code trainFilePath testFilePath learningAlgorithmInputPath}
     * @throws Exception if loading the algorithm, reading a dataset, learning or evaluating fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 3) {
            System.err.println("Usage: trainFilePath testFilePath learningAlgorithmInputPath");
            // A wrong invocation is a failure: report it with a non-zero status so that
            // calling scripts do not mistake it for a successful run.
            System.exit(1);
        }
        System.setProperty(SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "WARN");

        String trainFilePath = strArr[0];
        String testFilePath = strArr[1];
        String learningAlgorithmPath = strArr[2];

        // Read the learning algorithm description and echo it back in serialized form.
        JacksonSerializerWrapper serializer = new JacksonSerializerWrapper();
        ClassificationLearningAlgorithm learningAlgorithm = (ClassificationLearningAlgorithm) serializer
                .readValue(new File(learningAlgorithmPath), ClassificationLearningAlgorithm.class);
        System.out.println(serializer.writeValueAsString(learningAlgorithm));

        SimpleDataset trainingSet = new SimpleDataset();
        trainingSet.populate(trainFilePath);
        SimpleDataset testSet = new SimpleDataset();
        testSet.populate(testFilePath);

        System.out.println("Dataset statistics");
        System.out.print("Training Example number ");
        System.out.println(trainingSet.getNumberOfExamples());
        System.out.print("Testing Example number ");
        System.out.println(testSet.getNumberOfExamples());

        // Labels are taken from the training set only; the per-label counts of the
        // test set are reported against those same labels.
        List<Label> trainingLabels = trainingSet.getClassificationLabels();
        for (Label label : trainingLabels) {
            System.out.println("Training Label " + label.toString() + ": " + trainingSet.getNumberOfPositiveExamples(label));
            System.out.println("Test Label " + label.toString() + ": " + testSet.getNumberOfPositiveExamples(label));
        }

        // With exactly two labels the task is handled as binary: only the first label
        // is handed to the algorithm (presumably the other one acts as the negative
        // class -- confirm against the KeLP documentation). Otherwise all labels are used.
        boolean isBinaryTask = trainingLabels.size() == 2;
        if (isBinaryTask) {
            learningAlgorithm.setLabels(trainingLabels.subList(0, 1));
        } else {
            learningAlgorithm.setLabels(trainingLabels);
        }

        learningAlgorithm.learn(trainingSet);
        Classifier classifier = learningAlgorithm.getPredictionFunction();

        // The binary evaluator is built on the single label given to the algorithm;
        // the multiclass evaluator on the full label list of the training set.
        Evaluator evaluator;
        if (isBinaryTask) {
            evaluator = new BinaryClassificationEvaluator(learningAlgorithm.getLabels().get(0));
        } else {
            evaluator = new MulticlassClassificationEvaluator(trainingSet.getClassificationLabels());
        }

        for (Example testExample : testSet.getExamples()) {
            evaluator.addCount(testExample, classifier.predict(testExample));
        }

        System.out.println("ACC: " + evaluator.getPerformanceMeasure("accuracy", new Object[0]));
        if (isBinaryTask) {
            // F1 is only requested from the binary evaluator.
            System.out.println("F1: " + evaluator.getPerformanceMeasure("F1", new Object[0]));
        }
    }
}
