package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput;

/**
 * Base class for binary Passive-Aggressive classification learners.
 *
 * <p>Two aggressiveness values are used when updating the model: {@code cp} is passed to
 * {@code computeWeight} for examples that belong to the target {@code label}, while the
 * inherited {@code c} is passed for all other examples (it is exposed here as {@code cn}).
 * When {@link #isFairness() fairness} is enabled, {@code cp} is re-derived from the class
 * distribution of the dataset at the beginning of {@link #learn(Dataset)}.
 *
 * <p>NOTE(review): {@code c}, {@code label} and {@code computeWeight} are inherited and not
 * declared in this file; their exact contracts should be checked in the superclasses.
 */
public abstract class PassiveAggressiveClassification extends PassiveAggressive implements ClassificationLearningAlgorithm {
    /** Loss variant; it only affects whether large-loss examples trigger an update (see {@link #learn(Example)}). */
    protected Loss loss = Loss.HINGE;

    /** Aggressiveness applied to examples of the target label; starts equal to the inherited {@code c}. */
    protected float cp = this.c;

    /** When true, {@link #learn(Dataset)} rebalances {@code cp} using the negative/positive ratio. */
    protected boolean fairness = false;

    /** The classifier updated by this algorithm; excluded from JSON serialization. */
    @JsonIgnore
    protected BinaryClassifier classifier;

    /** Supported loss variants. */
    public enum Loss {
        HINGE,
        RAMP
    }

    /**
     * @return whether {@code cp} is rebalanced on the class distribution in {@link #learn(Dataset)}
     */
    public boolean isFairness() {
        return this.fairness;
    }

    /**
     * @param fairness {@code true} to rebalance {@code cp} on the class distribution in
     *                 {@link #learn(Dataset)}
     */
    public void setFairness(boolean fairness) {
        this.fairness = fairness;
    }

    /**
     * @return the aggressiveness applied to examples of the target label
     */
    public float getCp() {
        return this.cp;
    }

    /**
     * @param cp the aggressiveness applied to examples of the target label
     */
    public void setCp(float cp) {
        this.cp = cp;
    }

    /**
     * @return the aggressiveness applied to examples that are not of the target label
     *         (stored in the inherited {@code c})
     */
    public float getCn() {
        return this.c;
    }

    /**
     * Sets the aggressiveness applied to examples that are not of the target label.
     * Unlike {@link #setC(float)}, this leaves {@code cp} untouched.
     *
     * @param cn the new value for the inherited {@code c}
     */
    public void setCn(float cn) {
        this.c = cn;
    }

    /**
     * Returns the inherited {@code c}. Hidden from JSON output; the same value is exposed
     * through {@link #getCn()}.
     *
     * @return the inherited {@code c}
     */
    @Override
    @JsonIgnore
    public float getC() {
        return this.c;
    }

    /**
     * Sets a single aggressiveness for both classes: delegates to the superclass setter and
     * also assigns the same value to {@code cp}.
     *
     * @param c the aggressiveness to use for every example
     */
    @Override
    @JsonProperty
    public void setC(float c) {
        super.setC(c);
        this.cp = c;
    }

    /**
     * @return the loss variant in use
     */
    public Loss getLoss() {
        return this.loss;
    }

    /**
     * @param loss the loss variant to use
     */
    public void setLoss(Loss loss) {
        this.loss = loss;
    }

    /**
     * @return the classifier updated by this algorithm
     */
    @Override
    public BinaryClassifier getPredictionFunction() {
        return this.classifier;
    }

    /**
     * Classifies a single example with the current model and, if it incurs a loss, adds it
     * to the model with a signed weight.
     *
     * <p>With {@code s} the absolute score for the target label, the loss is {@code 1 + s}
     * when the predicted class disagrees with the example, {@code 1 - s} when it agrees but
     * {@code s < 1}, and {@code 0} otherwise. No update is performed when the loss is
     * {@code 0}, nor when the loss variant is {@link Loss#RAMP} and the loss is {@code >= 2}.
     *
     * @param example the example to learn from
     * @return the prediction computed <em>before</em> the model update
     */
    @Override
    public BinaryMarginClassifierOutput learn(Example example) {
        BinaryMarginClassifierOutput prediction = this.classifier.predict(example);
        final boolean positive = example.isExampleOf(this.label);
        final float absScore = Math.abs(prediction.getScore(this.label).floatValue());

        float lossValue = 0.0f;
        if (prediction.isClassPredicted(this.label) != positive) {
            lossValue = 1.0f + absScore;
        } else if (absScore < 1.0f) {
            lossValue = 1.0f - absScore;
        }

        if (lossValue > 0.0f && (lossValue < 2.0f || this.loss != Loss.RAMP)) {
            // Examples of the target label use cp, all the others use c.
            final float aggressiveness = positive ? this.cp : this.c;
            float weight = computeWeight(example, lossValue,
                    this.classifier.getModel().getSquaredNorm(example), aggressiveness);
            if (!positive) {
                // Examples not belonging to the target label enter the model with a negative weight.
                weight = -weight;
            }
            getPredictionFunction().getModel().addExample(weight, example);
        }
        return prediction;
    }

    /**
     * Learns from a whole dataset by delegating to the superclass implementation. When
     * fairness is enabled, {@code cp} is first set to
     * {@code c * #negatives / #positives} computed on {@code dataset}.
     *
     * <p>If the dataset contains no example of the target label the ratio is undefined, so
     * {@code cp} is left unchanged instead of becoming {@code Infinity} or {@code NaN}
     * (which would otherwise persist and affect every later update on a positive example).
     *
     * @param dataset the training data
     */
    @Override
    public void learn(Dataset dataset) {
        if (this.fairness) {
            final float positives = dataset.getNumberOfPositiveExamples(this.label);
            if (positives > 0.0f) {
                final float negatives = dataset.getNumberOfNegativeExamples(this.label);
                this.cp = (this.c * negatives) / positives;
            }
        }
        super.learn(dataset);
    }
}
