package it.uniroma2.sag.kelp.learningalgorithm.regression.liblinear;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvrFunction;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Tron;
import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;
import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction;
import java.util.Arrays;
import java.util.List;

/**
 * Linear regression learner built on the LIBLINEAR-style solver classes shipped with KeLP.
 *
 * <p>Training builds a {@link Problem} in {@code REGRESSION} mode from the configured
 * representation and target label. It then minimizes an {@link L2R_L2_SvrFunction} objective
 * with the {@link Tron} solver (tolerance {@code 0.001}). The resulting weights are stored as the
 * hyperplane of the {@link BinaryLinearModel} held by the wrapped
 * {@link UnivariateLinearRegressionFunction}. The model bias is always set to zero.
 *
 * <p>Exactly one {@link Label} can be learned at a time.
 */
@JsonTypeName("liblinearregression")
public class LibLinearRegression implements LinearMethod, RegressionLearningAlgorithm, BinaryLearningAlgorithm {
    /** The single label this regressor learns. */
    private Label label;
    /** Cost weight assigned uniformly to every training example. */
    private double c;

    /** Prediction function holding the learned model; excluded from JSON via {@code @JsonIgnore}. */
    @JsonIgnore
    private UnivariateLinearRegressionFunction regressionFunction;
    /**
     * Value passed as the {@code p} argument of {@link L2R_L2_SvrFunction}
     * (presumably the epsilon-insensitive margin — confirm against the solver).
     */
    private double p;
    /** Identifier of the vector representation used for training and prediction. */
    private String representation;

    /**
     * Creates a regressor for a given label with explicit hyper-parameters.
     *
     * @param label the label to learn
     * @param c the per-example cost weight
     * @param p the {@code p} parameter forwarded to the SVR objective
     * @param representation identifier of the vector representation to use
     */
    public LibLinearRegression(Label label, double c, double p, String representation) {
        this();
        setLabel(label);
        this.c = c;
        this.p = p;
        setRepresentation(representation);
    }

    /**
     * Creates a regressor with explicit hyper-parameters; the label must be set later.
     *
     * @param c the per-example cost weight
     * @param p the {@code p} parameter forwarded to the SVR objective
     * @param representation identifier of the vector representation to use
     */
    public LibLinearRegression(double c, double p, String representation) {
        this();
        this.c = c;
        this.p = p;
        setRepresentation(representation);
    }

    /**
     * Creates a regressor with {@code c = 1.0} and {@code p = 0.1f}, backed by a fresh
     * {@link UnivariateLinearRegressionFunction} holding an empty {@link BinaryLinearModel}.
     */
    public LibLinearRegression() {
        this.c = 1.0;
        // A float literal widened to double (0.10000000149011612), matching the historical default.
        this.p = 0.1f;
        UnivariateLinearRegressionFunction function = new UnivariateLinearRegressionFunction();
        this.regressionFunction = function;
        function.setModel(new BinaryLinearModel());
    }

    /** @return the per-example cost weight */
    public double getC() {
        return c;
    }

    /** @param c the per-example cost weight to use in the next training run */
    public void setC(double c) {
        this.c = c;
    }

    /** @return the {@code p} parameter forwarded to the SVR objective */
    public double getP() {
        return p;
    }

    /** @param p the {@code p} parameter forwarded to the SVR objective */
    public void setP(double p) {
        this.p = p;
    }

    /** @return the identifier of the vector representation in use */
    @Override
    public String getRepresentation() {
        return representation;
    }

    /**
     * Sets the representation identifier here and on the wrapped model.
     *
     * @param representation identifier of the vector representation to use
     */
    @Override
    public void setRepresentation(String representation) {
        this.representation = representation;
        regressionFunction.getModel().setRepresentation(representation);
    }

    /**
     * Sets the label to learn; the list must contain exactly one element.
     *
     * @param labels a singleton list with the label to learn
     * @throws IllegalArgumentException if {@code labels} does not have exactly one element
     */
    @Override
    public void setLabels(List<Label> labels) {
        boolean singleLabel = labels.size() == 1;
        if (!singleLabel) {
            throw new IllegalArgumentException("LibLinear algorithm is a binary method which can learn a single Label");
        }
        label = labels.get(0);
        regressionFunction.setLabels(labels);
    }

    /** @return a one-element list wrapping the current label (which may be {@code null}) */
    @Override
    public List<Label> getLabels() {
        return Arrays.asList(label);
    }

    /** @return the label being learned */
    @Override
    public Label getLabel() {
        return label;
    }

    /** @param label the label to learn; delegates to {@link #setLabels(List)} */
    @Override
    public void setLabel(Label label) {
        List<Label> wrapped = Arrays.asList(label);
        setLabels(wrapped);
    }

    /**
     * Trains the linear regressor on {@code dataset} and stores the learned weights, the current
     * representation and a zero bias in the wrapped model.
     *
     * @param dataset the training examples
     */
    @Override
    public void learn(Dataset dataset) {
        // Every example receives the same cost weight c.
        double[] exampleWeights = new double[dataset.getNumberOfExamples()];
        Arrays.fill(exampleWeights, c);

        Problem problem = new Problem(dataset, representation, label, Problem.LibLinearSolverType.REGRESSION);
        L2R_L2_SvrFunction objective = new L2R_L2_SvrFunction(problem, exampleWeights, p);
        Tron solver = new Tron(objective, 0.001);

        // The solver fills the weight vector in place.
        double[] weights = new double[problem.n];
        solver.tron(weights);

        regressionFunction.getModel().setHyperplane(problem.getW(weights));
        regressionFunction.getModel().setRepresentation(representation);
        regressionFunction.getModel().setBias(0.0f);
    }

    /**
     * Returns a fresh, untrained regressor with the same representation, {@code c} and {@code p}.
     * The label is not copied.
     *
     * @return the new instance
     */
    @Override
    public LibLinearRegression duplicate() {
        LibLinearRegression copy = new LibLinearRegression();
        copy.setRepresentation(representation);
        copy.setC(c);
        copy.setP(p);
        return copy;
    }

    /** Resets the wrapped prediction function. */
    @Override
    public void reset() {
        regressionFunction.reset();
    }

    /** @return the wrapped regression function holding the learned model */
    @Override
    public UnivariateLinearRegressionFunction getPredictionFunction() {
        return regressionFunction;
    }

    /**
     * Replaces the wrapped prediction function.
     *
     * @param predictionFunction must be a {@link UnivariateLinearRegressionFunction}; otherwise a
     *     {@link ClassCastException} is thrown
     */
    @Override
    public void setPredictionFunction(PredictionFunction predictionFunction) {
        regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction;
    }
}
