package mltk.core;

import java.io.File;
import java.io.IOException;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.io.InstancesReader;
import mltk.core.io.InstancesWriter;
import mltk.util.Random;
import weka.core.json.JSONInstances;

/**
 * Command-line tool and helper methods for randomly partitioning a dataset into
 * train/validation splits or cross-validation folds.
 */
public class InstancesSplitter {

    /** Separator between the fields of the {@code -m} option (e.g. {@code "c:5"}). */
    private static final String MODE_SEPARATOR = ":";

    /** Command-line options for {@link InstancesSplitter#main(String[])}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-i", description = "input dataset path", required = true)
        String inputPath = null;

        @Argument(name = "-o", description = "output directory path", required = true)
        String outputDirPath = null;

        // Accepted forms: "c:<folds>", "c:<folds>:<trainRatio>" and "s:<ratio>".
        @Argument(name = "-m", description = "splitting mode:parameter. Splitting mode can be split (s) and cross validation (c) (default: c:5)")
        String crossValidationMode = "c:5";

        @Argument(name = "-s", description = "seed of the random number generator (default: 0)")
        long seed = 0;
    }

    /**
     * Randomly splits a dataset into two parts.
     *
     * <p>Works on a copy made with {@code new Instances(instances)} that is shuffled via
     * {@link Instances#shuffle()}. The first {@code (int) (size * ratio)} instances of the
     * shuffled copy form the first part and the rest form the second part.
     *
     * @param instances the dataset to split
     * @param ratio fraction of instances placed in the first part; must be in {@code [0, 1]}
     * @return a two-element array {@code {first, second}}
     * @throws IllegalArgumentException if {@code ratio} is NaN or outside {@code [0, 1]}
     */
    public static Instances[] split(Instances instances, double ratio) {
        // Written as a negated range check so that NaN is rejected as well.
        if (!(ratio >= 0.0 && ratio <= 1.0)) {
            throw new IllegalArgumentException("ratio must be in [0, 1]: " + ratio);
        }
        Instances shuffled = new Instances(instances);
        shuffled.shuffle();
        Instances first = new Instances(shuffled.attributes, shuffled.targetAtt);
        Instances second = new Instances(shuffled.attributes, shuffled.targetAtt);
        int cut = (int) (shuffled.size() * ratio);
        for (int i = 0; i < cut; i++) {
            first.add(shuffled.get(i));
        }
        for (int i = cut; i < shuffled.size(); i++) {
            second.add(shuffled.get(i));
        }
        return new Instances[] {first, second};
    }

    /**
     * Randomly partitions a dataset into {@code k} parts.
     *
     * <p>A shuffled copy is distributed round-robin, so part sizes differ by at most one.
     *
     * @param instances the dataset to partition
     * @param k the number of parts; must be at least 1
     * @return an array of {@code k} datasets
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static Instances[] split(Instances instances, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("number of parts must be at least 1: " + k);
        }
        Instances shuffled = new Instances(instances);
        shuffled.shuffle();
        Instances[] parts = new Instances[k];
        for (int p = 0; p < k; p++) {
            parts[p] = new Instances(shuffled.attributes, shuffled.targetAtt);
        }
        for (int i = 0; i < shuffled.size(); i++) {
            parts[i % k].add(shuffled.get(i));
        }
        return parts;
    }

    /**
     * Creates {@code k} cross-validation folds.
     *
     * @param instances the dataset to fold
     * @param k the number of folds; must be at least 1
     * @return a {@code k x 2} array; row {@code f} is {@code {train, test}} where {@code test} is
     *     the f-th random part and {@code train} is the union of all other parts
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static Instances[][] createCrossValidationFolds(Instances instances, int k) {
        Instances[] parts = split(instances, k);
        Instances[][] folds = new Instances[k][];
        for (int f = 0; f < k; f++) {
            folds[f] = new Instances[] {mergeExcept(instances, parts, f), parts[f]};
        }
        return folds;
    }

    /**
     * Creates {@code k} cross-validation folds whose training portion is further split into
     * train and validation sets.
     *
     * @param instances the dataset to fold
     * @param k the number of folds; must be at least 1
     * @param ratio fraction of each fold's training portion kept for training (the rest becomes
     *     validation); must be in {@code [0, 1]}
     * @return a {@code k x 3} array; row {@code f} is {@code {train, valid, test}}
     * @throws IllegalArgumentException if {@code k < 1} or {@code ratio} is outside {@code [0, 1]}
     */
    public static Instances[][] createCrossValidationFolds(Instances instances, int k, double ratio) {
        Instances[] parts = split(instances, k);
        Instances[][] folds = new Instances[k][];
        for (int f = 0; f < k; f++) {
            Instances[] trainValid = split(mergeExcept(instances, parts, f), ratio);
            folds[f] = new Instances[] {trainValid[0], trainValid[1], parts[f]};
        }
        return folds;
    }

    /**
     * Concatenates every element of {@code parts} except {@code parts[skip]} into a new dataset
     * that uses the attributes and target attribute of {@code template}.
     */
    private static Instances mergeExcept(Instances template, Instances[] parts, int skip) {
        Instances merged = new Instances(template.attributes, template.targetAtt);
        for (int p = 0; p < parts.length; p++) {
            if (p != skip) {
                merged.instances.addAll(parts[p].instances);
            }
        }
        return merged;
    }

    /**
     * Entry point: reads a dataset and writes either cross-validation folds ({@code -m c:k} or
     * {@code -m c:k:ratio}) or a single train/validation split ({@code -m s:ratio}).
     *
     * <p>Invalid options print usage and exit with status 1.
     *
     * @param args command-line arguments, see {@link Options}
     * @throws Exception if reading the input or writing any output fails
     */
    public static void main(String[] args) throws Exception {
        Options opts = new Options();
        CmdLineParser parser = new CmdLineParser(InstancesSplitter.class, opts);
        String mode = null;
        int numFolds = 0;
        double ratio = 0.0;
        boolean hasRatio = false;
        try {
            parser.parse(args);
            String[] tokens = opts.crossValidationMode.split(MODE_SEPARATOR);
            if (tokens.length < 2 || tokens.length > 3) {
                throw new IllegalArgumentException("Invalid splitting mode: " + opts.crossValidationMode);
            }
            mode = tokens[0];
            // NumberFormatException is an IllegalArgumentException, so bad numbers print usage too.
            if ("c".equals(mode)) {
                numFolds = Integer.parseInt(tokens[1]);
                if (numFolds < 1) {
                    throw new IllegalArgumentException("Number of folds must be at least 1: " + numFolds);
                }
                hasRatio = tokens.length == 3;
                if (hasRatio) {
                    ratio = Double.parseDouble(tokens[2]);
                }
            } else if ("s".equals(mode) && tokens.length == 2) {
                ratio = Double.parseDouble(tokens[1]);
                hasRatio = true;
            } else {
                throw new IllegalArgumentException("Invalid splitting mode: " + opts.crossValidationMode);
            }
            if (hasRatio && !(ratio >= 0.0 && ratio <= 1.0)) {
                throw new IllegalArgumentException("Ratio must be in [0, 1]: " + ratio);
            }
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }

        Random.getInstance().setSeed(opts.seed);
        // NOTE(review): -r is optional; this assumes InstancesReader.read accepts a null attribute path.
        Instances instances = InstancesReader.read(opts.attPath, opts.inputPath);
        // Without -r, name the output files after the input dataset instead of failing with an NPE.
        String prefix = fileStem(opts.attPath != null ? opts.attPath : opts.inputPath);
        ensureDirectory(opts.outputDirPath);

        if ("c".equals(mode)) {
            Instances[][] folds = hasRatio
                    ? createCrossValidationFolds(instances, numFolds, ratio)
                    : createCrossValidationFolds(instances, numFolds);
            writeFolds(folds, opts.outputDirPath, prefix);
        } else {
            Instances[] parts = split(instances, ratio);
            String base = opts.outputDirPath + File.separator + prefix;
            InstancesWriter.write(parts[0], base + ".attr", base + ".train");
            InstancesWriter.write(parts[1], base + ".valid");
        }
    }

    /**
     * Writes each fold into its own {@code cv.<i>} subdirectory of {@code outputDirPath}.
     *
     * <p>Two-part folds produce {@code .train.all}/{@code .test}. Three-part folds produce
     * {@code .train}/{@code .valid}/{@code .test}. In both cases the attribute file is written
     * with the training portion.
     */
    private static void writeFolds(Instances[][] folds, String outputDirPath, String prefix) throws Exception {
        for (int i = 0; i < folds.length; i++) {
            String foldDir = outputDirPath + File.separator + "cv." + i;
            ensureDirectory(foldDir);
            String base = foldDir + File.separator + prefix;
            Instances[] fold = folds[i];
            if (fold.length == 3) {
                InstancesWriter.write(fold[0], base + ".attr", base + ".train");
                InstancesWriter.write(fold[1], base + ".valid");
                InstancesWriter.write(fold[2], base + ".test");
            } else {
                InstancesWriter.write(fold[0], base + ".attr", base + ".train.all");
                InstancesWriter.write(fold[1], base + ".test");
            }
        }
    }

    /** Returns the file name of {@code path} up to (not including) its first '.'. */
    private static String fileStem(String path) {
        String name = new File(path).getName();
        int dot = name.indexOf('.');
        return dot < 0 ? name : name.substring(0, dot);
    }

    /**
     * Creates {@code path} and any missing parent directories.
     *
     * @throws IOException if the directory does not exist and cannot be created
     */
    private static void ensureDirectory(String path) throws IOException {
        File dir = new File(path);
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Cannot create directory: " + dir);
        }
    }
}
