package it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.solver.LibCSvmSolver;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.solver.SvmSolution;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;
import java.util.Arrays;

/**
 * Binary C-SVM learning algorithm built on top of {@link LibCSvmSolver}.
 *
 * <p>Examples that belong to the configured label are treated as positives ({@code +1}).
 * Every other example is treated as a negative ({@code -1}). After solving, each example
 * with a non-zero alpha is added to the {@link BinaryKernelMachineModel} with weight
 * {@code y_i * alpha_i}. The model bias is set to {@code -rho}, as returned by the solver.
 *
 * <p>When fairness is enabled, the positive-class cost {@code cp} is recomputed before each
 * training run as {@code cn * #negatives / #positives}.
 */
@JsonTypeName("binaryCSvmClassification")
public class BinaryCSvmClassification extends LibCSvmSolver
        implements ClassificationLearningAlgorithm, KernelMethod {

    /** Whether {@code cp} is rebalanced from the class ratio at the start of {@link #learn}. */
    private boolean fairness;

    /** Classifier populated by {@link #learn(Dataset)}; excluded from JSON serialization. */
    @JsonIgnore
    protected BinaryKernelMachineClassifier classifier;

    /** Creates a learner with an empty classifier and fairness disabled. */
    public BinaryCSvmClassification() {
        initializeClassifier();
    }

    /**
     * Creates a learner with fairness disabled.
     *
     * @param kernel the kernel used by the solver and stored in the learned model
     * @param label the label whose examples are treated as positives
     * @param cp the cost applied to positive examples
     * @param cn the cost applied to negative examples
     */
    public BinaryCSvmClassification(Kernel kernel, Label label, float cp, float cn) {
        this(kernel, label, cp, cn, false);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param kernel the kernel used by the solver and stored in the learned model
     * @param label the label whose examples are treated as positives
     * @param cp the cost applied to positive examples (overwritten at learn time if
     *     {@code fairness} is {@code true})
     * @param cn the cost applied to negative examples
     * @param fairness whether to rebalance {@code cp} from the class ratio before learning
     */
    public BinaryCSvmClassification(Kernel kernel, Label label, float cp, float cn, boolean fairness) {
        super(kernel, cp, cn);
        initializeClassifier();
        setLabel(label);
        this.fairness = fairness;
    }

    /**
     * Returns a new, untrained learner with the same kernel, label, costs, fairness flag and
     * {@code eps}.
     */
    @Override
    public BinaryCSvmClassification duplicate() {
        BinaryCSvmClassification copy =
                new BinaryCSvmClassification(this.kernel, this.label, this.cp, this.cn, this.fairness);
        copy.setEps(this.eps);
        return copy;
    }

    /**
     * Builds the linear term passed to the solver: {@code -1} for every example, as in the
     * standard C-SVM dual objective.
     */
    private float[] getCSvmP(Dataset dataset) {
        float[] p = new float[dataset.getNumberOfExamples()];
        Arrays.fill(p, -1.0f);
        return p;
    }

    @Override
    public BinaryKernelMachineClassifier getPredictionFunction() {
        return this.classifier;
    }

    /** Replaces the classifier with a fresh one backed by an empty model. */
    private void initializeClassifier() {
        this.classifier = new BinaryKernelMachineClassifier();
        this.classifier.setModel(new BinaryKernelMachineModel());
    }

    /** Returns whether {@code cp} is rebalanced from the class ratio before learning. */
    public boolean isFairness() {
        return this.fairness;
    }

    /**
     * Trains the classifier on {@code dataset}.
     *
     * <p>NOTE(review): this method never calls {@link #reset()}. If the model's
     * {@code addExample} appends, support vectors from a previous call will remain. Confirm
     * whether callers are expected to reset first.
     *
     * @param dataset the training examples
     */
    @Override
    public void learn(Dataset dataset) {
        if (isFairness()) {
            rebalanceCosts(dataset);
        }
        train(dataset, getCSvmAlpha(dataset));
    }

    /**
     * Sets {@code cp = cn * #negatives / #positives}. If the dataset has no positive
     * examples, the configured {@code cp} is kept. Otherwise the division would produce
     * Infinity or NaN and corrupt the solver.
     */
    private void rebalanceCosts(Dataset dataset) {
        float positives = dataset.getNumberOfPositiveExamples(this.label);
        float negatives = dataset.getNumberOfNegativeExamples(this.label);
        if (positives == 0) {
            info("fairness skipped: no positive examples; cn=" + this.cn + " cp=" + this.cp);
            return;
        }
        this.cp = (this.cn * negatives) / positives;
        info("cn=" + this.cn + " cp=" + this.cp);
    }

    /**
     * Runs the solver and stores the support vectors, bias and kernel in the classifier's
     * model.
     *
     * @param dataset the training examples
     * @param initialAlphas the starting alpha values passed to the solver
     */
    private void train(Dataset dataset, float[] initialAlphas) {
        int n = dataset.getNumberOfExamples();
        float[] p = getCSvmP(dataset);
        int[] y = new int[n];
        for (int i = 0; i < n; i++) {
            y[i] = dataset.getExamples().get(i).isExampleOf(this.label) ? 1 : -1;
        }

        SvmSolution solution = solve(n, dataset, p, y, initialAlphas);
        this.classifier.getModel().setBias(-solution.getRho());

        float[] alphas = solution.getAlphas();
        for (int i = 0; i < n; i++) {
            // Only examples with a non-zero alpha (support vectors) contribute to the model.
            if (alphas[i] != 0.0f) {
                this.classifier.getModel().addExample(y[i] * alphas[i], dataset.getExamples().get(i));
            }
        }
        this.classifier.getModel().setKernel(this.kernel);
    }

    /** Resets the underlying classifier. */
    @Override
    public void reset() {
        this.classifier.reset();
    }

    /** Enables or disables class-ratio rebalancing of {@code cp} at learn time. */
    public void setFairness(boolean fairness) {
        this.fairness = fairness;
    }

    @Override
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
    }

    /**
     * Sets the classifier to be populated by {@link #learn(Dataset)}.
     *
     * @throws ClassCastException if {@code predictionFunction} is not a
     *     {@link BinaryKernelMachineClassifier}
     */
    @Override
    public void setPredictionFunction(PredictionFunction predictionFunction) {
        this.classifier = (BinaryKernelMachineClassifier) predictionFunction;
    }
}
