import React, { Component } from "react";
import PropTypes from "prop-types";
import Button from "@mui/material/Button";
import Dialog from "@mui/material/Dialog";
import DialogContent from "@mui/material/DialogContent";
import DialogTitle from "@mui/material/DialogTitle";
import Slide from "@mui/material/Slide";
import ReplayRoundedIcon from "@mui/icons-material/ReplayRounded";
import {
  Stepper,
  Step,
  StepButton,
  MobileStepper,
  TextField,
  LinearProgress,
} from "@mui/material";
import { ValidatorForm } from "react-material-ui-form-validator";
import AddCircleOutlineIcon from "@mui/icons-material/AddCircleOutline";
import Backend from "../../common/utils/Backend";
import Draggable from "react-draggable";
import Paper from "@mui/material/Paper";
import withStyles from "@mui/styles/withStyles";
import { KeyboardArrowLeft, KeyboardArrowRight } from "@mui/icons-material";

import TrainModelStep0 from "./TrainAIModelComponents/TrainModelStep0";
import TrainModelStep1 from "./TrainAIModelComponents/TrainModelStep1";
import TrainModelStep2 from "./TrainAIModelComponents/TrainModelStep2";
import TrainModelStep3 from "./TrainAIModelComponents/TrainModelStep3";
import TrainModelStep4 from "./TrainAIModelComponents/TrainModelStep4";

// JSS rule set for this component; withStyles injects the generated class
// names as the `classes` prop
const styles = () => {
  // body of the dialog: no vertical scrolling at this level
  const dialogContent = {
    overflowY: "hidden",
    minHeight: "100%",
  };

  // overrides applied to the nested MUI paper / dialog-content elements
  const tableDialog = {
    "& .MuiPaper-root": { maxWidth: "none" },
    "& .MuiDialogContent-root": { padding: 0 },
  };

  // three-row grid: title, stretching content area, footer
  const dialogRowContainer = {
    width: "100%",
    height: "100%",
    display: "grid",
    gridTemplateRows: "auto 1fr auto",
    overflow: "hidden",
  };

  // middle grid row; this is the part that scrolls
  const content = {
    padding: "0 24px",
    overflow: "auto",
  };

  // read-only fields and buttons laid out two per row
  const trainModelInfo = { margin: "5px", width: "calc(50% - 10px)" };

  return {
    dialogContent,
    tableDialog,
    dialogRowContainer,
    content,
    trainModelInfo,
  };
};

/**
 * Looks for the text enclosed by two markers inside `dataString`.
 *
 * Both markers are inserted verbatim into a regular expression (as a
 * lookbehind and a lookahead), so callers must pass them already escaped
 * where they contain regex metacharacters.
 *
 * @param {string} dataString - text to search
 * @param {string} stringBefore - regex source that must precede the result
 * @param {string} stringAfter - regex source that must follow the result
 * @returns {RegExpMatchArray|null} the match array (element 0 is the enclosed
 *   text, matched greedily) or null when nothing matches
 */
function getStringBetween(dataString, stringBefore, stringAfter) {
  const enclosedPattern = new RegExp(
    `(?<=${stringBefore}).*(?=${stringAfter})`
  );
  return dataString.match(enclosedPattern);
}

function PaperComponent(props) {
  return (
    <Draggable
      handle="#draggable-dialog-title"
      cancel={'[class*="MuiDialogContent-root"]'}
    >
      <Paper {...props} />
    </Draggable>
  );
}

const Transition = React.forwardRef(function Transition(props, ref) {
  return <Slide direction="up" ref={ref} {...props} />;
});

class TrainAIModelDialog extends Component {
  constructor(props) {
    super(props);
    this.selectedModel = { name: "newModel" };
    this.tileProgress = 0;
    this.buildingModelProgress = 0;
    this.processingFilesCount = 0;
    this.multiplier = 0.0;
  }

  /**
   * Parses one chunk of training log output and mirrors what it finds into
   * `projectContext.aiStateObject` (failure state, dataset progress, epoch and
   * step counters, metrics), then forces a re-render.
   *
   * Values taken from the log are stored as the regex match arrays returned by
   * `String.prototype.match`; they are only ever stringified for display.
   *
   * @param {string} dataString - raw text emitted by the training process
   * @param {string} modelType - "object detection", "instance segmentation" or
   *   "classification"; every other value is parsed as segmentation output
   * @param {boolean} isVdlModel - for "instance segmentation" only: parse the
   *   vdl log format instead of the comdl one
   */
  updateTrainingProgressObject = (dataString, modelType, isVdlModel) => {
    const { aiStateObject } = this.props.projectContext;
    aiStateObject.modelType = modelType;

    // Marks the run as failed, re-enables the train button and re-renders.
    const failTraining = (errorLabel, showErrorDialog = false) => {
      aiStateObject.errorLabel = errorLabel;
      if (showErrorDialog) {
        window.openErrorDialog(dataString);
      }
      aiStateObject.trainingFinished = true;
      aiStateObject.trainingSuccessful = false;
      aiStateObject.showTrainingButton = true;
      aiStateObject.buildingModel = false;
      this.forceUpdate();
    };

    // Stores a metric only when the log chunk actually contained it.
    const setMetric = (label, value) => {
      if (value) {
        aiStateObject.metricsDict[label] = value;
      }
    };

    // Switches the UI from "building dataset" to "training" once an epoch
    // counter shows up in the log.
    const publishEpochs = (epochs) => {
      if (!epochs) return;
      aiStateObject.buildingModel = false;
      aiStateObject.showTrainingProgress = true;
      aiStateObject.epochsProgress = `${epochs}/${aiStateObject.maxEpochs}`;
    };

    const outOfMemory = dataString.includes("CUDA out of memory");

    if (dataString.includes("Traceback (most recent call last):")) {
      const tileSizeSuspected =
        modelType === "instance segmentation" &&
        (dataString.includes("FileNotFoundError: [Errno 2]") || outOfMemory);
      if (tileSizeSuspected) {
        failTraining("Failed: Try a smaller tile size!");
      } else {
        // unknown failure: show the full traceback to the user
        failTraining("Failed", true);
      }
      return;
    }

    if (outOfMemory) {
      // out-of-memory reported without a traceback; nothing else in this
      // chunk is worth parsing, so stop here like the traceback case does
      failTraining("Failed: Try a smaller tile size / batch size!");
      return;
    }

    if (dataString.includes("training done")) {
      aiStateObject.trainingFinished = true;
      aiStateObject.showTrainingButton = true;
      this.forceUpdate();
      return;
    }

    if (dataString.includes("Train on")) {
      // the first number of a "Train on" line is used as the epoch total;
      // keep the previous value when the line carries no number at all
      const maxEpochs = dataString.match(/[0-9]+/);
      if (maxEpochs) {
        aiStateObject.maxEpochs = maxEpochs;
      }
    }

    if (dataString.includes("processing dataset: ")) {
      const datasetProgress = parseFloat(
        dataString.split("processing dataset:")[1]
      );
      // ignore unparsable values instead of storing the string "NaN"
      if (Number.isFinite(datasetProgress)) {
        aiStateObject.overallProgress = datasetProgress.toFixed(2);
      }
    }

    // "<done>/<total>" step counter, reported the same way by every format
    const stepsProgress = dataString.match(/[0-9]+\/[0-9]+/);
    if (stepsProgress) {
      aiStateObject.stepsProgress = stepsProgress;
    }

    const isInstanceSegmentation = modelType === "instance segmentation";
    let epochsProgress = null;

    if (
      modelType === "object detection" ||
      (isInstanceSegmentation && !isVdlModel)
    ) {
      // instance segmentation and object detection data (comdl models)
      epochsProgress = getStringBetween(dataString, " \\[", "]\\[");
      setMetric("Training Loss", dataString.match(/(?<=loss:).*/));

      // validation lines list their metrics as "name: value, name: value"
      // after the first "]<TAB>"
      if (dataString.includes("Epoch(val)")) {
        const metrics = dataString.substring(dataString.indexOf("]\t") + 2);
        for (const entry of metrics.split(", ")) {
          const [metricName, metricValue] = entry.split(": ");
          // skip fragments that are not a "name: value" pair
          if (metricValue !== undefined) {
            aiStateObject.metricsDict[metricName] = metricValue;
          }
        }
      }
    } else if (isInstanceSegmentation) {
      // instance segmentation data (vdl models)
      const zeroBasedEpoch = getStringBetween(dataString, "Epoch:", " Loss:");
      if (zeroBasedEpoch) {
        // displayed epoch is the logged epoch plus one
        epochsProgress = String(Number.parseInt(zeroBasedEpoch[0], 10) + 1);
      }

      if (dataString.includes("Training:")) {
        setMetric(
          "Training Loss",
          getStringBetween(dataString, "Loss: ", " Best:")
        );
      } else if (dataString.includes("Evaluating:")) {
        setMetric(
          "Validation Loss",
          getStringBetween(dataString, "Loss: ", ": ")
        );
      }
    } else if (modelType === "classification") {
      // classification data
      epochsProgress = dataString.match(/(?<=Epoch:).*(?=, Loss)/);

      if (dataString.includes("Valid")) {
        setMetric("Validation Loss", dataString.match(/(?<=Loss:).*(?=, F1)/));
      } else {
        setMetric("Training Loss", dataString.match(/(?<=Loss:).*(?=, Best)/));
      }
    } else {
      // segmentation data
      epochsProgress = dataString.match(/(?<=Epoch:).*(?=, Loss)/);

      if (dataString.includes("Valid")) {
        setMetric(
          "Validation Loss",
          dataString.match(/(?<=Loss:).*(?=, Accuracy)/)
        );
        setMetric(
          "Validation mean IoU",
          dataString.match(/(?<=IoU:).*(?=, : )/)
        );
      } else {
        // each training metric is stored only when it matched, so a line
        // with an IoU but no loss can no longer blank out "Training Loss"
        setMetric("Training Loss", dataString.match(/(?<=Loss:).*(?=, Best)/));
        setMetric("Training mean IoU", dataString.match(/(?<=IoU:).*(?=, : )/));
      }
    }

    // only update if Epochs are included
    publishEpochs(epochsProgress);
    this.forceUpdate();
  };

  componentDidMount = () => {
    window.updateTrainingData = this.updateTrainingData;
  };

  componentWillUnmount = () => {
    window.updateTrainingData = this.updateTrainingProgressObject;
  };

  // Remembers the model chosen in the dialog on the instance. No re-render is
  // triggered here; steps 3 and 4 receive this.selectedModel on the next
  // render.
  setSelectedModel = (model) => {
    this.selectedModel = model;
  };

  // Entry point exposed as window.updateTrainingData while mounted; simply
  // forwards its arguments to updateTrainingProgressObject.
  updateTrainingData = (dataString, modelType, isVdlModel) =>
    this.updateTrainingProgressObject(dataString, modelType, isVdlModel);

  /**
   * Writes a top-level entry of the training form data and re-renders.
   *
   * @param {string} key - property of formData to set
   * @param {*} value - new value
   */
  updateFormDataGeneral = (key, value) => {
    const { formData } = this.props.projectContext.aiStateObject;
    formData[key] = value;
    this.forceUpdate();
  };

  /**
   * Adds a structure to, or removes it from, formData.selStructures and
   * re-renders. Adding is a no-op when the structure is already selected.
   *
   * @param {*} structure - structure entry to toggle
   * @param {boolean} checked - truthy to select, falsy to deselect
   */
  updateFormDataStructures = (structure, checked) => {
    const selected =
      this.props.projectContext.aiStateObject.formData["selStructures"];

    if (!checked) {
      const position = selected.indexOf(structure);
      if (position > -1) {
        selected.splice(position, 1);
      }
    } else if (!selected.includes(structure)) {
      selected.push(structure);
    }

    this.forceUpdate();
  };

  /**
   * Writes one entry of formData.advancedSettings and re-renders.
   *
   * @param {string} key - advanced setting to set
   * @param {*} value - new value
   */
  updateFormDataAdvSettings = (key, value) => {
    const { formData } = this.props.projectContext.aiStateObject;
    const advancedSettings = formData["advancedSettings"];
    advancedSettings[key] = value;
    this.forceUpdate();
  };

  /**
   * Writes a per-architecture value below formData.advancedSettings[key] and
   * re-renders.
   *
   * @param {string} key - advanced setting that holds one value per architecture
   * @param {*} value - new value
   * @param {string} architecture - entry below `key` to set
   */
  updateFormDataAdvSettingsEncoders = (key, value, architecture) => {
    const { formData } = this.props.projectContext.aiStateObject;
    const perArchitecture = formData["advancedSettings"][key];
    perArchitecture[architecture] = value;
    this.forceUpdate();
  };

  /**
   * Toggles a channel (by its name) in advancedSettings.flChannels, keeps
   * advancedSettings.in_channels equal to the number of selected channels and
   * re-renders.
   *
   * @param {{name: string}} channel - channel whose name is toggled
   * @param {boolean} checked - truthy to select, falsy to deselect
   */
  updateFormDataAdvSettingsFlChannels = (channel, checked) => {
    const advancedSettings =
      this.props.projectContext.aiStateObject.formData["advancedSettings"];
    const selectedChannels = advancedSettings["flChannels"];

    if (!checked) {
      const position = selectedChannels.indexOf(channel.name);
      if (position > -1) {
        selectedChannels.splice(position, 1);
      }
    } else if (!selectedChannels.includes(channel.name)) {
      selectedChannels.push(channel.name);
    }

    advancedSettings["in_channels"] = selectedChannels.length;
    this.forceUpdate();
  };

  /**
   * Writes one entry of formData.metaData and re-renders.
   *
   * @param {string} key - meta data field to set
   * @param {*} value - new value
   */
  updateFormDataMetaData = (key, value) => {
    const { formData } = this.props.projectContext.aiStateObject;
    const metaData = formData["metaData"];
    metaData[key] = value;
    this.forceUpdate();
  };

  /**
   * Adds an augmentation to, or removes it from, formData.augmentations
   * depending on the checkbox state carried by the event, then re-renders.
   * Like the structure and channel updaters, adding is a no-op when the entry
   * is already present, so repeated "checked" events cannot create duplicates.
   *
   * @param {{currentTarget: {checked: boolean}}} e - checkbox change event
   * @param {*} augmentation - augmentation entry to toggle
   */
  updateFormDataAugmentations = (e, augmentation) => {
    const augmentations =
      this.props.projectContext.aiStateObject.formData["augmentations"];

    if (e.currentTarget.checked) {
      if (!augmentations.includes(augmentation)) {
        augmentations.push(augmentation);
      }
    } else {
      const index = augmentations.indexOf(augmentation);
      if (index > -1) {
        augmentations.splice(index, 1);
      }
    }
    this.forceUpdate();
  };

  render() {
    const { classes, dialog, projectContext, ...propsWithoutClasses } =
      this.props;

    const handleClickOpen = () => {
      const { projectContext } = this.props;
      projectContext.aiStateObject.open = true;
      this.forceUpdate();
    };

    const handleClose = (e) => {
      const { projectContext } = this.props;
      projectContext.aiStateObject.open = false;
      e.preventDefault();
      this.forceUpdate();
    };

    const initTraining = () => {
      projectContext.aiStateObject.overallProgress = 0.0;
      projectContext.aiStateObject.showTrainingButton = false;
      projectContext.aiStateObject.showTrainingProgress = false;
      projectContext.aiStateObject.trainingSuccessful = true;
      projectContext.aiStateObject.buildingModel = true;
      projectContext.aiStateObject.trainingFinished = false;
      this.tileProgress = 0;
      this.buildingModelProgress = 0;
      this.processingFilesCount = 0;
      this.forceUpdate();
    };

    const handleNext = (datasetOnly = false) => {
      const { projectContext } = this.props;
      const uniqueName =
        projectContext.aiStateObject.formData.metaData.uniqueName;
      const validChar =
        projectContext.aiStateObject.formData.metaData.validChar;
      const newModel = projectContext.aiStateObject.formData.metaData.newModel;
      if (typeof datasetOnly !== "boolean") {
        datasetOnly = false;
      }
      if (projectContext.aiStateObject.activeStep === 4) {
        if (projectContext.aiStateObject.formData.selStructures.length === 0) {
          projectContext.aiStateObject.activeStep = 1;
          this.forceUpdate();
          return;
        }

        // skip redirect if dataset only or existing model,
        if (!datasetOnly && newModel) {
          // else check for valid name
          if (!(uniqueName && validChar)) {
            projectContext.aiStateObject.activeStep = 3;
            this.forceUpdate();
            return;
          }
        }

        initTraining();

        projectContext.aiStateObject.formData["selStructures"].forEach(
          (selStructure) => {
            let idx = this.props.structures.findIndex(
              (element) => element === selStructure
            );

            if (idx > -1) {
              projectContext.aiStateObject.formData["structureIndices"].push(
                idx
              );
            }
          }
        );

        projectContext.aiStateObject.open = false;
        let data = {
          parameters: projectContext.aiStateObject.formData,
          projectId: this.props.projectId,
        };
        data.parameters.datasetOnly = datasetOnly;

        this.props.onSave(() => {
          // send parameters to backend for ai training
          projectContext.aiStateObject.startTrainingTime = new Date();
          Backend.aiTrainingSignalR(
            data,
            (progress) => {
              console.log("Training Result");
              console.log(progress);
            },
            () => {
              console.log("finished!");
            },
            (error) => {
              window.openErrorDialog(error);
            }
          );
        });
        this.forceUpdate();
        return;
      }
      projectContext.aiStateObject.activeStep += 1;
      this.forceUpdate();
    };

    const handleBack = () => {
      projectContext.aiStateObject.activeStep -= 1;
      this.forceUpdate();
    };

    const onStopTraining = () => {
      Backend.stopAITraining((result) => {
        console.log("stopping ai training", result);
      });
    };

    const steps = [
      "Select Model Type",
      "Select Structures",
      "Augmentations",
      "Meta Data",
      "Settings",
    ];

    return (
      <div>
        {dialog === "TrainModelPages" && (
          <div>
            {projectContext.aiStateObject.showTrainingButton && (
              <div
                onClick={handleClickOpen}
                style={{ textAlign: "center", cursor: "pointer", margin: 15 }}
              >
                <AddCircleOutlineIcon fontSize="default" />
                <span
                  style={{
                    fontWeight: "bold",
                    marginLeft: 10,
                    position: "relative",
                    top: 2,
                  }}
                >
                  Train Custom AI Model
                </span>
              </div>
            )}

            {projectContext.aiStateObject.buildingModel && (
              <div
                style={{
                  margin: 20,
                }}
              >
                <TextField
                  variant="standard"
                  style={{
                    marginTop: "10px",
                    textAlign: "center",
                    width: "100%",
                  }}
                  label="Creating dataset:"
                  fullWidth
                  value={
                    projectContext.aiStateObject.overallProgress > 0
                      ? projectContext.aiStateObject.overallProgress + "%"
                      : ""
                  }
                  InputProps={{
                    readOnly: true,
                    disableUnderline: true,
                  }}
                />
                <LinearProgress
                  variant={
                    projectContext.aiStateObject.overallProgress > 0
                      ? "determinate"
                      : "indeterminate"
                  }
                  value={Number(projectContext.aiStateObject.overallProgress)}
                />
              </div>
            )}
            {projectContext.aiStateObject.showTrainingProgress && (
              <div
                style={{
                  marginTop: 20,
                }}
              >
                {(projectContext.aiStateObject.modelType ===
                  "object detection" ||
                  projectContext.aiStateObject.modelType ===
                    "instance segmentation" ||
                  projectContext.aiStateObject.modelType === "segmentation" ||
                  projectContext.aiStateObject.modelType ===
                    "classification") && (
                  <div>
                    <TextField
                      className={classes.trainModelInfo}
                      label="Epochs:"
                      value={projectContext.aiStateObject.epochsProgress}
                      InputProps={{
                        readOnly: true,
                      }}
                    />
                    <TextField
                      className={classes.trainModelInfo}
                      label="Steps:"
                      value={projectContext.aiStateObject.stepsProgress}
                      InputProps={{
                        readOnly: true,
                      }}
                    />
                    <br />
                    {Object.entries(
                      projectContext.aiStateObject.metricsDict
                    ).map(([key, value]) => {
                      return (
                        <TextField
                          key={key}
                          className={classes.trainModelInfo}
                          label={key}
                          value={value}
                          InputProps={{
                            readOnly: true,
                          }}
                        />
                      );
                    })}
                  </div>
                )}

                {!projectContext.aiStateObject.showTrainingButton && (
                  <div>
                    <Button
                      disabled={true}
                      className={classes.trainModelInfo}
                      variant="contained"
                      color="primary"
                      onClick={onStopTraining}
                    >
                      Interrupt Training
                    </Button>
                  </div>
                )}

                {projectContext.aiStateObject.showOptimizationProgress && (
                  <div>
                    <TextField
                      style={{
                        marginTop: "10px",
                        textAlign: "center",
                        width: "100%",
                      }}
                      label="Optimize Deep Learning Model:"
                      fullWidth
                      value={
                        projectContext.aiStateObject.modelOptimizeProgress + "%"
                      }
                      InputProps={{
                        readOnly: true,
                      }}
                    />
                    <LinearProgress
                      variant="determinate"
                      value={projectContext.aiStateObject.modelOptimizeProgress}
                    />
                  </div>
                )}

                {projectContext.aiStateObject.trainingFailed && (
                  <div>
                    <h4
                      style={{
                        textAlign: "center",
                        marginTop: "20px",
                        color: "red",
                      }}
                    >
                      {projectContext.aiStateObject.errorLabel}
                    </h4>
                    <Button
                      style={{ marginTop: "5px" }}
                      fullWidth
                      tooltip="Go back to the start page"
                      variant="contained"
                      color="primary"
                      startIcon={<ReplayRoundedIcon />}
                      onClick={() => {
                        projectContext.aiStateObject.buildingModel = false;
                        projectContext.aiStateObject.showTrainingProgress = false;
                        projectContext.aiStateObject.hideStartTraining = false;
                        projectContext.aiStateObject.trainingFinished = false;
                        projectContext.aiStateObject.epochsProgress = "";
                        projectContext.aiStateObject.stepsProgress = "";
                        projectContext.aiStateObject.meanIoUProgress = "";
                        projectContext.aiStateObject.lossProgress = "";
                        projectContext.aiStateObject.valmeanIoUProgress = "";
                        projectContext.aiStateObject.vallossProgress = "";
                        this.forceUpdate();
                      }}
                    >
                      Restart Training
                    </Button>
                  </div>
                )}
              </div>
            )}
            {projectContext.aiStateObject.trainingFinished && (
              <div>
                <h4
                  style={{
                    textAlign: "center",
                    marginTop: "20px",
                    color: projectContext.aiStateObject.trainingSuccessful
                      ? "green"
                      : "red",
                  }}
                >
                  {projectContext.aiStateObject.trainingSuccessful
                    ? "Successful"
                    : projectContext.aiStateObject.errorLabel}
                </h4>
              </div>
            )}
            <Dialog
              open={projectContext.aiStateObject.open}
              TransitionComponent={Transition}
              keepMounted
              onClose={handleClose}
              PaperComponent={PaperComponent}
              className={classes.tableDialog}
            >
              <DialogContent className={classes.dialogContent}>
                <ValidatorForm
                  onSubmit={handleNext}
                  style={{ width: "900px", height: "600px" }}
                >
                  <div className={classes.dialogRowContainer}>
                    <DialogTitle
                      style={{ cursor: "move" }}
                      id="draggable-dialog-title"
                    >
                      {projectContext.aiStateObject.activeStep == 0
                        ? "Select Model Type"
                        : projectContext.aiStateObject.activeStep == 1
                        ? "Select Structures"
                        : projectContext.aiStateObject.activeStep == 2
                        ? "Augmentations"
                        : projectContext.aiStateObject.activeStep == 3
                        ? "Meta Data"
                        : "Settings"}
                    </DialogTitle>
                    <div className={classes.content}>
                      {projectContext.aiStateObject.activeStep === 0 && (
                        <TrainModelStep0
                          key="0"
                          formData={projectContext.aiStateObject.formData}
                          updateFormDataGeneral={this.updateFormDataGeneral}
                          {...propsWithoutClasses}
                        />
                      )}
                      {projectContext.aiStateObject.activeStep === 1 && (
                        <TrainModelStep1
                          key="1"
                          formData={projectContext.aiStateObject.formData}
                          updateFormDataGeneral={this.updateFormDataGeneral}
                          updateFormDataStructures={
                            this.updateFormDataStructures
                          }
                          structures={this.props.structures}
                          roiLayers={this.props.roiLayers}
                          {...propsWithoutClasses}
                        />
                      )}
                      {projectContext.aiStateObject.activeStep === 2 && (
                        <TrainModelStep2
                          key="2"
                          formData={projectContext.aiStateObject.formData}
                          updateFormDataGeneral={this.updateFormDataGeneral}
                          updateFormDataAugmentations={
                            this.updateFormDataAugmentations
                          }
                          {...propsWithoutClasses}
                        />
                      )}
                      {projectContext.aiStateObject.activeStep === 3 && (
                        <TrainModelStep3
                          key="3"
                          formData={projectContext.aiStateObject.formData}
                          updateFormDataGeneral={this.updateFormDataGeneral}
                          updateFormDataMetaData={this.updateFormDataMetaData}
                          updateFormDataAdvSettings={
                            this.updateFormDataAdvSettings
                          }
                          newModelName={
                            projectContext.aiStateObject.newModelName
                          }
                          selectedModel={this.selectedModel}
                          setSelectedModel={this.setSelectedModel}
                        />
                      )}
                      {projectContext.aiStateObject.activeStep === 4 && (
                        <TrainModelStep4
                          key="4"
                          formData={projectContext.aiStateObject.formData}
                          ome={this.props.ome}
                          histogramConfig={
                            this.props.histogramConfig[this.props.fileId]
                          }
                          selectedModel={this.selectedModel}
                          structures={this.props.structures}
                          roiLayers={this.props.roiLayers}
                          updateFormDataAdvSettings={
                            this.updateFormDataAdvSettings
                          }
                          updateFormDataAdvSettingsEncoders={
                            this.updateFormDataAdvSettingsEncoders
                          }
                          updateFormDataAdvSettingsFlChannels={
                            this.updateFormDataAdvSettingsFlChannels
                          }
                          createDataSet={() => {
                            handleNext(true);
                          }}
                        />
                      )}
                    </div>

                    <div style={{ background: "#fafafa" }}>
                      <Stepper
                        alternativeLabel
                        nonLinear
                        activeStep={projectContext.aiStateObject.activeStep}
                        style={{
                          width: "80%",
                          left: "10%",
                          position: "relative",
                          marginBottom: "-50px",
                          paddingBottom: "10px",
                          marginTop: "25px",
                          background: "#fafafa",
                        }}
                      >
                        {steps.map((label, index) => {
                          const stepProps = {};
                          const buttonProps = {};
                          return (
                            <Step key={label} {...stepProps}>
                              <StepButton
                                onClick={() => {
                                  projectContext.aiStateObject.activeStep =
                                    index;
                                  this.forceUpdate();
                                }}
                                {...buttonProps}
                              >
                                {label}
                              </StepButton>
                            </Step>
                          );
                        })}
                      </Stepper>

                      <MobileStepper
                        variant="dots"
                        steps={5}
                        position="static"
                        activeStep={projectContext.aiStateObject.activeStep}
                        nextButton={
                          <Button size="small" type="submit">
                            {projectContext.aiStateObject.activeStep === 4
                              ? "Train"
                              : "Next"}
                            <KeyboardArrowRight />
                          </Button>
                        }
                        backButton={
                          <Button
                            size="small"
                            disabled={
                              projectContext.aiStateObject.activeStep < 1
                            }
                            onClick={handleBack}
                          >
                            <KeyboardArrowLeft />
                            {"Back"}
                          </Button>
                        }
                      />
                    </div>
                  </div>
                </ValidatorForm>
              </DialogContent>
            </Dialog>
          </div>
        )}
      </div>
    );
  }
}

// define the component's interface
TrainAIModelDialog.propTypes = {
  // class names injected by withStyles
  classes: PropTypes.object.isRequired,
  handleOptionsClose: PropTypes.func,
  // list the selected structures are looked up in when training starts
  structures: PropTypes.array,
  roiLayers: PropTypes.array,
  formDataAICockpit: PropTypes.object,
  setAvailableModels: PropTypes.func,
  setFormDataAICockpit: PropTypes.func,
  projectId: PropTypes.string,
  // called with a callback that starts the training request
  onSave: PropTypes.func,
  ome: PropTypes.object,
  // the component only renders content for "TrainModelPages"
  dialog: PropTypes.string,
  projectStringProperties: PropTypes.object,
  // indexed by fileId for the settings step
  histogramConfig: PropTypes.object,
  // holds aiStateObject, which render and every updater dereference
  // unconditionally, so the component cannot work without it
  projectContext: PropTypes.object.isRequired,
  fileId: PropTypes.string,
};

export default withStyles(styles)(TrainAIModelDialog);
