#include <catch2/catch.hpp>
#include "schneide_base/Numerical.hpp"

using namespace Catch::literals;
using namespace schneide;

TEST_CASE("Gauss smoothes a peak")
{
  // A single spike (10.0) on a flat baseline (1.0). After smoothing, every
  // sample must lie strictly between the baseline and the original spike.
  auto const spiky = std::vector<double>{1.0, 1.0, 10.0, 1.0, 1.0};
  std::vector<double> const smoothed = gauss(spiky, 1.0);
  REQUIRE(smoothed.size() == 5);
  for (std::size_t i = 0; i < smoothed.size(); ++i)
  {
    REQUIRE(smoothed[i] < 10.0);
    REQUIRE(smoothed[i] > 1.0);
  }
}

TEST_CASE("Gauss leaves flat functions alone")
{
  // Smoothing a constant signal must reproduce the constant at every sample.
  // Values are compared approximately (`_a`) rather than exactly, since the
  // result comes out of floating-point arithmetic.
  std::vector<double> const flat(5, 1.0);
  auto const smoothed = gauss(flat, 1.0);
  REQUIRE(smoothed.size() == 5);
  for (auto const& sample : smoothed)
  {
    REQUIRE(sample == 1.0_a);
  }
}

TEST_CASE("Can move values by convolution")
{
  // Convolving with the kernel {1, 0, 0} is expected to shift the signal by
  // one position: the 3.0 at index 1 ends up at index 2.
  auto const expected = std::vector<double>{0.0, 0.0, 3.0, 0.0};
  auto const shifted = convolve({0.0, 3.0, 0.0, 0.0}, {1.0, 0.0, 0.0});
  REQUIRE(shifted == expected);
}

TEST_CASE("Can not interpolate empty values")
{
  // With no samples there is nothing to interpolate between, so any query
  // position must make interpolate() throw.
  std::vector<double> const noSamples{};
  REQUIRE_THROWS(interpolate(noSamples, 42.0));
}

TEST_CASE("Interpolate values")
{
  // The samples are addressed by position: index 0 -> 1.0, 1 -> 2.0, 2 -> 6.0.
  auto const samples = std::vector<double>{1.0, 2.0, 6.0};
  SECTION("Below zero returns first")
  {
    REQUIRE(interpolate(samples, -7.0) == samples.front());
  }
  SECTION("Past max returns last")
  {
    REQUIRE(interpolate(samples, 2.1) == samples.back());
  }
  SECTION("Integer return exact value")
  {
    // Querying exactly at a sample position must return that sample unchanged.
    for (std::size_t i = 0; i < samples.size(); ++i)
    {
      REQUIRE(interpolate(samples, static_cast<double>(i)) == samples[i]);
    }
  }
  SECTION("Interpolate in the middle")
  {
    // Between two neighbours the result is linear in the fractional position,
    // e.g. 1.25 -> 2.0 + 0.25 * (6.0 - 2.0) = 3.0.
    REQUIRE(interpolate(samples, 0.4) == 1.4_a);
    REQUIRE(interpolate(samples, 0.7) == 1.7_a);
    REQUIRE(interpolate(samples, 1.25) == 3.0_a);
    REQUIRE(interpolate(samples, 1.5) == 4.0_a);
  }
}

TEST_CASE("findRelative")
{
  // Ascending sample values; findRelative maps a value back to a fractional
  // index into this sequence (the inverse of interpolation).
  auto knots = std::vector<double>{3.0, 7.0, 14.0, 38.0};
  SECTION("below returns zero")
  {
    REQUIRE(findRelative(knots, 2.12) == 0.0);
  }
  SECTION("above returns N-1")
  {
    auto const lastIndex = static_cast<double>(knots.size() - 1);
    REQUIRE(findRelative(knots, 40.0) == lastIndex);
  }
  SECTION("in the middle")
  {
    // 5.0 is halfway from 3 to 7, 10.5 halfway from 7 to 14,
    // and 20.0 a quarter of the way from 14 to 38.
    REQUIRE(findRelative(knots, 5.0) == 0.5_a);
    REQUIRE(findRelative(knots, 10.5) == 1.5_a);
    REQUIRE(findRelative(knots, 20.0) == 2.25_a);
  }
}
