#include "Spectrum.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <utility>
#include "schneide_base/CSVReader.hpp"
#include "schneide_base/Logger.hpp"
#include "schneide_base/Numerical.hpp"

using namespace vera;
using namespace schneide;

// Defaulted special members, defined out of line in this translation unit.
Spectrum::Spectrum() = default;
Spectrum::~Spectrum() = default;

/// Returns a copy of the wavenumber interval covered by this spectrum.
Interval Spectrum::wavenumbers() const
{
  return this->mWavenumbers;
}

/// Returns a mutable reference to the wavenumber interval of this spectrum.
Interval& Spectrum::wavenumbers()
{
  return this->mWavenumbers;
}

std::vector<double> const& Spectrum::intensities() const
{
  return mIntensities;
}

std::vector<double>& Spectrum::intensities()
{
  return mIntensities;
}

/// Builds a spectrum from a wavenumber interval and one intensity per bin.
/// Both sink arguments are taken by value and moved into the members.
Spectrum::Spectrum(Interval wavenumbers, std::vector<double> intensities)
: mWavenumbers(std::move(wavenumbers))
, mIntensities(std::move(intensities))
{
}

/// Builds an N-bin spectrum from a line.
/// The wavenumber range is the line's x-interval. Bin i holds
/// `line.y().over(N).fromRelative(i)`.
Spectrum::Spectrum(Line line, std::size_t N)
: mWavenumbers(line.x()), mIntensities(N)
{
  auto values = line.y().over(N);
  std::size_t bin = 0;
  for (auto& intensity : mIntensities)
  {
    intensity = values.fromRelative(bin);
    ++bin;
  }
}

/// Writes the spectrum as a two-column CSV ("wavenumber,intensity").
/// There is one line per bin, and any existing file is truncated.
///
/// @param filename Path of the file to (over)write.
/// @throws std::invalid_argument if the file cannot be opened for writing.
/// @throws std::runtime_error if writing to or closing the file fails.
void Spectrum::saveCSV(std::string const& filename) const
{
  Logger::get()->info("Storing scan as {0}", filename);

  std::ofstream file(filename, std::ios::trunc);
  if (!file.good())
    throw std::invalid_argument(
      "Unable to open file " + filename + " for writing");

  auto range = mWavenumbers.over(mIntensities.size());
  for (std::size_t i = 0; i < mIntensities.size(); ++i)
  {
    // Use '\n' instead of std::endl: flushing after every row is needlessly
    // slow. The stream is flushed once by close() below.
    file << range.fromRelative(i) << ',' << mIntensities[i] << '\n';
  }

  // Surface write errors (e.g. a full disk) instead of silently leaving a
  // truncated file behind.
  file.close();
  if (!file)
    throw std::runtime_error("Failed to write file " + filename);
}

/// Saves the spectrum via saveCSV() to the file "<experimentTag>-<typeTag>.csv".
void Spectrum::saveTaggedCSV(
  std::string const& experimentTag, std::string const& typeTag) const
{
  saveCSV(fmt::format("{0}-{1}.csv", experimentTag, typeTag));
}

/// Intensity at an arbitrary wavenumber.
/// It is obtained by passing the fractional bin index of that wavenumber
/// (see index()) to interpolate().
double Spectrum::at(double wavenumber) const
{
  auto const fractionalIndex = index(wavenumber);
  return interpolate(mIntensities, fractionalIndex);
}

/// Maps a wavenumber to a fractional bin index.
/// The wavenumber interval is subdivided into binCount() parts. The result is
/// not clamped here; use indexCeil()/indexFloor() for clamped whole indices.
double Spectrum::index(double wavenumber) const
{
  auto bins = mWavenumbers.over(mIntensities.size());
  return bins.toRelative(wavenumber);
}

/// Clamps a (possibly negative) integer bin index to [0, binCount()].
/// The upper bound is binCount() itself, i.e. one past the last bin. Results
/// are therefore usable as an end-iterator offset, but must be checked before
/// they are used to index intensities().
std::size_t Spectrum::clampedIndex(long unclamped) const
{
  auto const upper = static_cast<long>(mIntensities.size());
  if (unclamped < 0L)
    return 0;
  if (unclamped > upper)
    return static_cast<std::size_t>(upper);
  return static_cast<std::size_t>(unclamped);
}

/// Rounds index(wavenumber) up to a whole bin index.
/// The result is clamped to [0, binCount()] by clampedIndex().
std::size_t Spectrum::indexCeil(double wavenumber) const
{
  auto const roundedUp = std::ceil(index(wavenumber));
  return clampedIndex(static_cast<long>(roundedUp));
}

/// Rounds index(wavenumber) down to a whole bin index.
/// The result is clamped to [0, binCount()] by clampedIndex().
std::size_t Spectrum::indexFloor(double wavenumber) const
{
  auto const roundedDown = std::floor(index(wavenumber));
  return clampedIndex(static_cast<long>(roundedDown));
}

/// Returns `length()` of the wavenumber interval subdivided into binCount()
/// parts. gauss() uses it to convert wavenumber widths into bins.
double Spectrum::binWidth() const
{
  auto const bins = binCount();
  return mWavenumbers.over(bins).length();
}

std::size_t Spectrum::binCount() const
{
  return mIntensities.size();
}

/// Replaces every intensity by its natural logarithm.
/// Intensities that are not strictly positive (including NaN) become 0.
Spectrum vera::logOrZero(Spectrum spectrum)
{
  auto& values = spectrum.intensities();
  for (std::size_t i = 0; i < values.size(); ++i)
  {
    double const v = values[i];
    values[i] = (v > 0.0) ? std::log(v) : 0.0;
  }
  return spectrum;
}

/// Negates every intensity; the wavenumber interval is left unchanged.
Spectrum vera::operator-(Spectrum lhs)
{
  for (auto& value : lhs.intensities())
    value = -value;
  return lhs;
}

/// Bin-wise difference lhs - rhs.
/// @throws std::invalid_argument if the wavenumber intervals or bin counts
///         differ.
Spectrum vera::operator-(Spectrum lhs, Spectrum const& rhs)
{
  bool const mismatched = lhs.wavenumbers() != rhs.wavenumbers() ||
                          lhs.intensities().size() != rhs.intensities().size();
  if (mismatched)
    throw std::invalid_argument("Spectrum structures do not match.");

  auto& difference = lhs.intensities();
  auto const& subtrahend = rhs.intensities();
  for (std::size_t bin = 0; bin < difference.size(); ++bin)
    difference[bin] -= subtrahend[bin];

  return lhs;
}

/// Bin-wise quotient lhs / rhs.
/// Zero divisors get no special handling; they follow plain floating-point
/// division.
/// @throws std::invalid_argument if the wavenumber intervals or bin counts
///         differ.
Spectrum vera::operator/(Spectrum lhs, Spectrum const& rhs)
{
  bool const mismatched = lhs.wavenumbers() != rhs.wavenumbers() ||
                          lhs.intensities().size() != rhs.intensities().size();
  if (mismatched)
    throw std::invalid_argument("Spectrum structures do not match.");

  auto& quotient = lhs.intensities();
  auto const& divisor = rhs.intensities();
  for (std::size_t bin = 0; bin < quotient.size(); ++bin)
    quotient[bin] /= divisor[bin];

  return lhs;
}

/// Smooths the intensities with schneide::gauss.
/// The kernel width is given in wavenumber units and converted to bins via
/// binWidth().
///
/// @param spectrum        Spectrum to smooth. It is taken by value and its
///                        intensities are moved into the smoothing routine.
/// @param sigmaWaveNumber Kernel width in wavenumber units; only its magnitude
///                        is used. Exactly 0 returns the spectrum unchanged.
Spectrum vera::gauss(Spectrum spectrum, double sigmaWaveNumber)
{
  if (sigmaWaveNumber == 0.0)
    return spectrum;

  // binWidth() depends on binCount(), so it must be evaluated before the
  // intensities are moved out. Inside a single call expression, the order in
  // which the arguments (and a by-value vector parameter) are initialised is
  // unspecified. The vector could then be moved-from before binWidth() runs.
  auto const sigmaInBins = std::abs(sigmaWaveNumber / spectrum.binWidth());

  return Spectrum(spectrum.wavenumbers(),
    schneide::gauss(std::move(spectrum.intensities()), sigmaInBins));
}

/// Writes several spectrums into one ';'-separated CSV file named
/// "<experimentTag>-<typeTag>.csv". All spectrums must share the same
/// wavenumber interval and bin count.
///
/// The first row is a header: "Wavenumber" followed by each spectrum's name.
/// Each following row holds a bin's wavenumber (two decimals) and that bin's
/// intensity for every spectrum, in the order given. Nothing is written when
/// @p spectrums is empty.
///
/// @throws std::runtime_error if the spectrums differ in wavenumber interval
///         or bin count, or if writing the file fails.
/// @throws std::invalid_argument if the file cannot be opened for writing.
void vera::saveTaggedCSVFor(
  std::vector<std::pair<std::string, Ptr<Spectrum>>> spectrums,
  std::string const& experimentTag, std::string const& typeTag)
{
  if (spectrums.empty())
    return;

  // Make sure all spectrums have the same format. Bind by reference rather
  // than copying the smart pointer.
  auto const& front = spectrums.front().second;
  auto const N = front->intensities().size();
  for (auto const& each : spectrums)
  {
    auto const& spectrum = each.second;
    if (spectrum->wavenumbers() != front->wavenumbers() ||
        spectrum->intensities().size() != N)
      throw std::runtime_error("Incompatible spectrums");
  }

  constexpr auto separator = ";";
  std::string filename(fmt::format("{0}-{1}.csv", experimentTag, typeTag));
  Logger::get()->info("Storing scan as {0}", filename);

  std::ofstream file(filename, std::ios::trunc);
  if (!file.good())
    throw std::invalid_argument(
      "Unable to open file " + filename + " for writing");

  // Write the header line
  file << "Wavenumber";
  for (auto const& source : spectrums)
  {
    file << separator << source.first;
  }
  file << '\n';

  // One row per bin. Use '\n' rather than std::endl to avoid a flush per row;
  // the stream is flushed once by close() below.
  auto range = front->wavenumbers().over(N);
  for (std::size_t i = 0; i < N; ++i)
  {
    file << fmt::format("{0:.2f}", range.fromRelative(i));
    for (auto const& source : spectrums)
    {
      file << separator << source.second->intensities()[i];
    }
    file << '\n';
  }

  // Report write failures (e.g. a full disk) instead of silently leaving a
  // truncated file behind.
  file.close();
  if (!file)
    throw std::runtime_error("Failed to write file " + filename);
}

double Spectrum::wavenumberAt(std::size_t index) const
{
  auto range = wavenumbers().over(mIntensities.size());
  return range.fromRelative(static_cast<double>(index));
}

/// Loads spectrums from a ';'-separated CSV in the layout written by
/// saveTaggedCSVFor().
///
/// The file has a header row ("Wavenumber;name1;name2;...") followed by one
/// row per bin. Column 0 of each row is the wavenumber; the remaining columns
/// are the intensities of the named spectrums. Every resulting spectrum spans
/// the wavenumber interval from the first data row to the last data row.
///
/// @param filename Path of the CSV file to read.
/// @return Map from column name to spectrum. If two columns share a name, the
///         later column wins.
/// @throws Whatever CSVReader::parseFile and the cell accessors throw. Rows and
///         cells are read with at(), so a missing data row or a short row is
///         reported by that accessor (presumably std::out_of_range; depends on
///         CSVReader's return type).
std::unordered_map<std::string, Ptr<Spectrum>> vera::loadSpectrums(
  std::string const& filename)
{
  auto parsed = CSVReader::parseFile(filename, ';');
  auto const rowCount = parsed.size();

  // Row 0 is the header. Column 0 labels the wavenumber column and every
  // further column names one spectrum. Bind instead of copying the row.
  auto&& header = parsed.at(0);
  std::vector<std::string> names;
  for (std::size_t i = 1; i < header.size(); ++i)
    names.push_back(header.at(i).asString());

  // One intensity column per named spectrum. Every data row contributes one
  // bin to each column. rowCount >= 1 here because at(0) succeeded.
  std::vector<std::vector<double>> intensities(names.size());
  for (auto& column : intensities)
    column.reserve(rowCount - 1);

  for (std::size_t row = 1; row < rowCount; ++row)
  {
    // Look the row up once instead of once per cell.
    auto&& cells = parsed.at(row);
    for (std::size_t i = 0; i < names.size(); ++i)
      intensities[i].push_back(cells.at(i + 1).asDouble());
  }

  auto from = parsed.at(1).at(0).asDouble();
  auto to = parsed.at(rowCount - 1).at(0).asDouble();

  Interval wavenumbers(from, to);
  std::unordered_map<std::string, Ptr<Spectrum>> result;
  result.reserve(names.size());
  for (std::size_t i = 0; i < names.size(); ++i)
  {
    result[names[i]] =
      std::make_shared<Spectrum>(wavenumbers, std::move(intensities[i]));
  }
  return result;
}

/// Locates the maximum intensity near @p peakPosition and spans a straight
/// base line under it.
///
/// The peak is the largest intensity between indexCeil(peakPosition -
/// searchWidth) and indexFloor(peakPosition + searchWidth). The base-line end
/// points start about @p isolation wavenumbers to either side of the peak.
/// From there they walk outwards as long as the next bin is not higher than
/// the current one, i.e. into the nearest local minimum. Spectrums ordered
/// high -> low in wavenumber are supported.
///
/// @param spectrum     Spectrum to analyse; must contain at least one bin.
/// @param peakPosition Expected peak position (wavenumber).
/// @param searchWidth  Half-width of the peak search window (wavenumber units).
/// @param isolation    Starting distance (wavenumber units) of the base-line
///                     end points from the found peak.
/// @throws std::invalid_argument if the spectrum has no bins.
PeakInfo vera::PeakInfo::computeFrom(Spectrum const& spectrum,
  double peakPosition, double searchWidth, double isolation)
{
  auto const& intensities = spectrum.intensities();
  auto const N = intensities.size();
  if (N == 0)
    throw std::invalid_argument(
      "Cannot compute peak info of an empty spectrum.");

  auto begin = intensities.begin();
  auto peakBegin = spectrum.indexCeil(peakPosition - searchWidth) + begin;
  auto peakEnd = spectrum.indexFloor(peakPosition + searchWidth) + begin;

  // Support high -> low spectrums
  if (peakBegin > peakEnd)
  {
    std::swap(peakBegin, peakEnd);
  }

  // NOTE(review): max_element treats peakEnd as exclusive, so the bin at
  // peakEnd itself is never searched. Confirm whether that is intended.
  // For an empty window max_element returns peakEnd, which may be one past
  // the last bin (indexCeil/indexFloor clamp to binCount()); clamp it.
  auto const peakIndex = std::min(
    static_cast<std::size_t>(std::max_element(peakBegin, peakEnd) - begin),
    N - 1);
  auto peakWavenumber = spectrum.wavenumberAt(peakIndex);
  auto peakIntensity = intensities[peakIndex];

  auto leftMin = spectrum.indexFloor(peakWavenumber - isolation);
  auto rightMin = spectrum.indexCeil(peakWavenumber + isolation);

  // Support high -> low spectrums
  if (leftMin > rightMin)
    std::swap(leftMin, rightMin);

  // indexFloor/indexCeil may return binCount() (one past the last bin).
  // Clamp so every access below stays in range.
  leftMin = std::min(leftMin, N - 1);
  rightMin = std::min(rightMin, N - 1);

  // Move to the left into a local minimum
  for (; leftMin > 0; --leftMin)
  {
    if (intensities[leftMin - 1] > intensities[leftMin])
      break;
  }

  // Move to the right into a local minimum
  for (; rightMin + 1 < N; ++rightMin)
  {
    if (intensities[rightMin + 1] > intensities[rightMin])
      break;
  }

  Line base{{spectrum.wavenumberAt(leftMin), spectrum.wavenumberAt(rightMin)},
    {intensities[leftMin], intensities[rightMin]}};

  return {peakWavenumber, peakIntensity, base};
}

/// Height of the peak above the base line, with the base line evaluated at
/// the peak's wavenumber.
double vera::PeakInfo::relativePeak() const
{
  auto const baseAtPeak = baseLine(peakWavenumber);
  return peakIntensity - baseAtPeak;
}
