//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Contrib/RoughMultiLayerContribution.cpp
//! @brief     Implements class RoughMultiLayerContribution.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Contrib/RoughMultiLayerContribution.h"
#include <cmath>
#include <numbers>
#include <vector>
using std::numbers::pi;
#include "Base/Util/Assert.h"
#include "Resample/Element/DiffuseElement.h"
#include "Resample/Flux/ScalarFlux.h"
#include "Resample/Processed/ReSample.h"
#include "Sample/Interface/LayerInterface.h"
#include "Sample/Interface/LayerRoughness.h"
#include "Sample/Multilayer/Layer.h"
#include "Sample/Multilayer/MultiLayer.h"
#include "cerfcpp.h"

// As we say in our 2020 paper (Sect 5.6), diffuse scattering from rough interfaces
// is modelled after Schlomka et al, Phys Rev B, 51, 2311 (1995). They give credit
// for the basic modelling idea to Sinha et al (1988), and for the matrix elements
// to Holy et al (1993), Holy and Baumbach (1994), Sinha (1994) and de Boer (unpublished).
//
// Specifically, we implemented the differential cross section according to lines 2-3
// in Eq 3 of Schlomka et al. Unfortunately, this equation has an incorrect prefactor k^2,
// and so had our implementation up to release 20.0. The correct prefactor is k^4, as in
// Eq 17 of Holy et al (1993). This was corrected in release 20.1 (issue #553).

namespace {

//! Weight function for the four terms on the upper side of an interface:
//! returns 0.5 * cerfcx(w) with w = -mul_I(z) / sqrt(2).
//! cerfcx is libcerf's scaled complementary error function, exp(w^2) * erfc(w);
//! mul_I(z) is presumably i*z (judging by its name).
complex_t h_above(complex_t z)
{
    const complex_t w = -mul_I(z) / std::sqrt(2.0);
    return 0.5 * cerfcx(w);
}
//! Weight function for the four terms on the lower side of an interface:
//! returns 0.5 * cerfcx(w) with w = +mul_I(z) / sqrt(2), i.e. the mirror of h_above.
complex_t h_below(complex_t z)
{
    const complex_t w = mul_I(z) / std::sqrt(2.0);
    return 0.5 * cerfcx(w);
}

//! Contrast across the interface between slice `i_layer` and slice `i_layer + 1`:
//! refractiveIndex2(wavelength) of the upper slice minus that of the lower slice
//! (refractiveIndex2 is presumably the squared refractive index n^2).
complex_t get_refractive_term(const ReSample& re_sample, size_t i_layer, double wavelength)
{
    const auto n2_of_slice = [&re_sample, wavelength](size_t idx) {
        return re_sample.avgeSlice(idx).material().refractiveIndex2(wavelength);
    };
    return n2_of_slice(i_layer) - n2_of_slice(i_layer + 1);
}

//! Sum of the eight amplitude terms associated with the interface between slice `i_layer`
//! and slice `i_layer + 1`.
//!
//! Four terms use the scalar fluxes of slice `i_layer` (above the interface). Their T and R
//! amplitudes are phase-shifted by exp(+-i kz d), where d is that slice's thicknessOr0().
//! The other four use the fluxes of slice `i_layer + 1` (below the interface) unshifted.
//! Each product of amplitudes is weighted by h_above / h_below. Their argument is the
//! corresponding wavevector transfer times the roughness sigma from
//! bottomRoughness(i_layer), or sigma = 0 if there is none.
//!
//! Throws std::runtime_error if any of the four fluxes is not a ScalarFlux (polarized case).
complex_t get_sum8terms(const ReSample& re_sample, size_t i_layer, const DiffuseElement& ele)
{
    // Naming: in/out = incoming/outgoing beam; up/lo = upper/lower side of the interface.
    const auto* in_up = dynamic_cast<const ScalarFlux*>(ele.fluxIn(i_layer));
    const auto* out_up = dynamic_cast<const ScalarFlux*>(ele.fluxOut(i_layer));
    const auto* in_lo = dynamic_cast<const ScalarFlux*>(ele.fluxIn(i_layer + 1));
    const auto* out_lo = dynamic_cast<const ScalarFlux*>(ele.fluxOut(i_layer + 1));
    const bool all_scalar = in_up && out_up && in_lo && out_lo;
    if (!all_scalar)
        throw std::runtime_error(
            "Rough interfaces not yet supported for polarized simulation (issue #564)");

    // Upper side: amplitudes carried across the slice thickness d (presumably 0 for a
    // semi-infinite slice, judging by the name thicknessOr0).
    const double d = re_sample.avgeSlice(i_layer).thicknessOr0();
    const complex_t ki_up = in_up->getScalarKz();
    const complex_t kf_up = out_up->getScalarKz();
    const complex_t Ti_up = in_up->getScalarT() * exp_I(ki_up * d);
    const complex_t Ri_up = in_up->getScalarR() * exp_I(-ki_up * d);
    const complex_t Tf_up = out_up->getScalarT() * exp_I(kf_up * d);
    const complex_t Rf_up = out_up->getScalarR() * exp_I(-kf_up * d);
    const complex_t q_sum_up = -ki_up - kf_up;
    const complex_t q_diff_up = -ki_up + kf_up;

    // Lower side: amplitudes used exactly as the fluxes return them.
    const complex_t ki_lo = in_lo->getScalarKz();
    const complex_t kf_lo = out_lo->getScalarKz();
    const complex_t q_sum_lo = -ki_lo - kf_lo;
    const complex_t q_diff_lo = -ki_lo + kf_lo;

    const LayerRoughness* rough = re_sample.averageSlices().bottomRoughness(i_layer);
    const double sigma = (rough != nullptr) ? rough->sigma() : 0.;

    // Accumulated strictly left to right: TT, TR, RT, RR above, then TT, TR, RT, RR below.
    complex_t sum = Ti_up * Tf_up * ::h_above(q_sum_up * sigma);
    sum += Ti_up * Rf_up * ::h_above(q_diff_up * sigma);
    sum += Ri_up * Tf_up * ::h_above(-q_diff_up * sigma);
    sum += Ri_up * Rf_up * ::h_above(-q_sum_up * sigma);
    sum += in_lo->getScalarT() * out_lo->getScalarT() * ::h_below(q_sum_lo * sigma);
    sum += in_lo->getScalarT() * out_lo->getScalarR() * ::h_below(q_diff_lo * sigma);
    sum += in_lo->getScalarR() * out_lo->getScalarT() * ::h_below(-q_diff_lo * sigma);
    sum += in_lo->getScalarR() * out_lo->getScalarR() * ::h_below(-q_sum_lo * sigma);
    return sum;
}

} // namespace


double Compute::roughMultiLayerContribution(const ReSample& re_sample, const DiffuseElement& ele)
{
    if (ele.alphaMean() < 0.0)
        return 0;

    const size_t n_slices = re_sample.numberOfSlices();
    R3 q = ele.meanQ();
    double wavelength = ele.wavelength();
    double autocorr(0.0);
    complex_t crosscorr(0.0, 0.0);

    std::vector<complex_t> rterm(n_slices - 1);
    std::vector<complex_t> sterm(n_slices - 1);

    for (size_t i = 0; i + 1 < n_slices; i++) {
        rterm[i] = ::get_refractive_term(re_sample, i, wavelength);
        sterm[i] = ::get_sum8terms(re_sample, i, ele);
    }

    // auto correlation in each layer (first term in final expression in Eq (3) of Schlomka et al)
    for (size_t i = 0; i + 1 < n_slices; i++)
        if (const LayerRoughness* rough = re_sample.avgeSlice(i + 1).topRoughness())
            autocorr += std::norm(rterm[i]) * std::norm(sterm[i]) * rough->spectralFunction(q);

    // cross correlation between layers (second term in loc. cit.)
    if (re_sample.sample().crossCorrLength() != 0.0) {
        for (size_t j = 0; j < n_slices - 1; j++) {
            for (size_t k = 0; k < n_slices - 1; k++) {
                if (j == k)
                    continue;
                crosscorr += rterm[j] * sterm[j] * re_sample.crossCorrSpectralFun(q, j, k)
                             * std::conj(rterm[k]) * std::conj(sterm[k]);
            }
        }
    }
    // TODO clarify complex vs double
    const double k0 = (2 * pi) / wavelength;
    return pow(k0, 4) / 16 / pi / pi * (autocorr + crosscorr.real());
}
