package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;
import it.uniroma2.sag.kelp.predictionfunction.model.SupportVector;
import java.util.Random;

/**
 * Randomized Budget Perceptron: a budgeted meta-algorithm wrapping a base learner that must be
 * an {@link OnlineLearningAlgorithm}, a {@link KernelMethod} and a {@link BinaryLearningAlgorithm}.
 *
 * <p>In {@link #predictAndLearnWithAvailableBudget(Example)} examples are simply forwarded to the
 * base algorithm's {@code learn}. In {@link #predictAndLearnWithFullBudget(Example)}, whenever the
 * current model misclassifies an example, a support-vector slot chosen uniformly at random in
 * {@code [0, budget)} is overwritten with that example. (Which of the two hooks is invoked is
 * decided by {@link BudgetedLearningAlgorithm}, not visible here — presumably based on whether
 * the model already holds {@code budget} support vectors.)
 *
 * <p>The random generator is seeded from {@code initialSeed}, and {@link #reset()} re-seeds it,
 * so runs with the same seed and the same example order are reproducible.
 */
@JsonTypeName("randomizedPerceptron")
public class RandomizedBudgetPerceptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm {
    /** Seed used when none is configured explicitly. */
    private static final long DEFAULT_SEED = 1;

    /** Seed the random generator is (re)initialized with; kept so reset() can replay it. */
    private long initialSeed = DEFAULT_SEED;

    /** Source of the support-vector indices to overwrite; derived state, so not serialized. */
    @JsonIgnore
    private Random randomGenerator = new Random(DEFAULT_SEED);

    /** Wrapped learner; setBaseAlgorithm guarantees it is also a KernelMethod and binary. */
    private OnlineLearningAlgorithm baseAlgorithm;

    /** Creates an unconfigured instance (needed e.g. for JSON deserialization). */
    public RandomizedBudgetPerceptron() {
    }

    /**
     * Creates a fully configured Randomized Budget Perceptron.
     *
     * @param budget maximum number of support vectors
     * @param baseAlgorithm wrapped learner; must also implement KernelMethod and
     *     BinaryLearningAlgorithm
     * @param seed seed for the random choice of the support vector to replace
     * @param label the positive class label
     * @throws IllegalArgumentException if {@code baseAlgorithm} does not implement the required
     *     interfaces
     */
    public RandomizedBudgetPerceptron(int budget, OnlineLearningAlgorithm baseAlgorithm, long seed, Label label) {
        setBudget(budget);
        setBaseAlgorithm(baseAlgorithm);
        setSeed(seed);
        setLabel(label);
    }

    /**
     * Returns the seed the random generator is initialized (and re-initialized on reset) with.
     * Exposing it lets bean-based serializers persist the seed alongside {@link #setSeed(long)}.
     *
     * @return the configured seed
     */
    public long getSeed() {
        return this.initialSeed;
    }

    /**
     * Sets the seed and immediately re-seeds the random generator with it.
     *
     * @param seed the new seed
     */
    public void setSeed(long seed) {
        this.initialSeed = seed;
        this.randomGenerator.setSeed(seed);
    }

    /**
     * Creates an untrained copy with the same budget, seed and label, wrapping a duplicate of
     * the base algorithm.
     *
     * @return the new instance
     */
    @Override
    public RandomizedBudgetPerceptron duplicate() {
        RandomizedBudgetPerceptron copy = new RandomizedBudgetPerceptron();
        copy.setBudget(this.budget);
        copy.setBaseAlgorithm(this.baseAlgorithm.duplicate());
        copy.setSeed(this.initialSeed);
        // The label is read in predictAndLearnWithFullBudget, so the copy must carry it too;
        // skip when unset to avoid passing null to setLabel.
        Label currentLabel = getLabel();
        if (currentLabel != null) {
            copy.setLabel(currentLabel);
        }
        return copy;
    }

    /** Resets the base algorithm and re-seeds the random generator with the configured seed. */
    @Override
    public void reset() {
        this.baseAlgorithm.reset();
        this.randomGenerator.setSeed(this.initialSeed);
    }

    /**
     * Predicts on {@code example}. On a mistake, overwrites a uniformly random support-vector
     * slot with the example. The weight is +1 for positive examples and -1 otherwise.
     *
     * @param example the incoming example
     * @return the prediction computed before any model update
     */
    @Override
    protected Prediction predictAndLearnWithFullBudget(Example example) {
        Prediction prediction = this.baseAlgorithm.getPredictionFunction().predict(example);
        Label positiveLabel = getLabel();
        boolean isPositive = example.isExampleOf(positiveLabel);
        boolean predictedPositive = prediction.getScore(positiveLabel).floatValue() > 0.0f;
        if (predictedPositive != isPositive) {
            // Budget is full: replace a random existing support vector instead of growing the model.
            int slot = this.randomGenerator.nextInt(this.budget);
            float weight = isPositive ? 1.0f : -1.0f;
            BinaryKernelMachineModel model =
                    (BinaryKernelMachineModel) this.baseAlgorithm.getPredictionFunction().getModel();
            model.setSupportVector(new SupportVector(weight, example), slot);
        }
        return prediction;
    }

    /**
     * Sets the wrapped learner.
     *
     * @param learningAlgorithm the learner to wrap
     * @throws IllegalArgumentException if it is not simultaneously an OnlineLearningAlgorithm,
     *     a KernelMethod and a BinaryLearningAlgorithm (including when it is null)
     */
    @Override
    public void setBaseAlgorithm(LearningAlgorithm learningAlgorithm) {
        if (!(learningAlgorithm instanceof OnlineLearningAlgorithm)
                || !(learningAlgorithm instanceof KernelMethod)
                || !(learningAlgorithm instanceof BinaryLearningAlgorithm)) {
            throw new IllegalArgumentException("a valid baseAlgorithm for the Randomized Budget Perceptron must implement OnlineLearningAlgorithm, BinaryLearningAlgorithm and KernelMethod");
        }
        this.baseAlgorithm = (OnlineLearningAlgorithm) learningAlgorithm;
    }

    /** @return the wrapped online learner */
    @Override
    public OnlineLearningAlgorithm getBaseAlgorithm() {
        return this.baseAlgorithm;
    }

    /** @return the base algorithm's prediction function */
    @Override
    public PredictionFunction getPredictionFunction() {
        return this.baseAlgorithm.getPredictionFunction();
    }

    /** @return the kernel of the base algorithm (guaranteed a KernelMethod by setBaseAlgorithm) */
    @Override
    public Kernel getKernel() {
        return ((KernelMethod) this.baseAlgorithm).getKernel();
    }

    /**
     * Sets the kernel on the base algorithm.
     *
     * @param kernel the kernel to use
     */
    @Override
    public void setKernel(Kernel kernel) {
        ((KernelMethod) this.baseAlgorithm).setKernel(kernel);
    }

    /**
     * Delegates learning to the base algorithm unchanged.
     *
     * @param example the incoming example
     * @return the prediction returned by the base algorithm's {@code learn}
     */
    @Override
    protected Prediction predictAndLearnWithAvailableBudget(Example example) {
        return this.baseAlgorithm.learn(example);
    }

    /**
     * Sets the prediction function on the base algorithm.
     *
     * @param predictionFunction the prediction function to install
     */
    @Override
    public void setPredictionFunction(PredictionFunction predictionFunction) {
        this.baseAlgorithm.setPredictionFunction(predictionFunction);
    }
}
