#ifndef SPLAT_BRDF_DEFINED
#define SPLAT_BRDF_DEFINED

// Distance-based roughness used for contact hardening of probe reflections.
//   distanceIntersectionToShadedPoint - distance from the shaded point to the proxy intersection
//   distanceIntersectionToProbeCenter - distance from the probe capture point to that intersection
//   perceptualRoughness               - roughness before the adjustment
// The roughness is scaled by the ratio of the two distances and clamped to [0, perceptualRoughness],
// so the result is never rougher than the input. The final lerp uses the input roughness as its own
// blend weight, which pulls the result back toward the input as the surface gets rougher.
float ComputeDistanceBaseRoughness(float distanceIntersectionToShadedPoint, float distanceIntersectionToProbeCenter, float perceptualRoughness){
    float distanceRatio = distanceIntersectionToShadedPoint / distanceIntersectionToProbeCenter;
    float scaledRoughness = clamp(distanceRatio * perceptualRoughness, 0, perceptualRoughness);

    return lerp(scaledRoughness, perceptualRoughness, perceptualRoughness);
}

//SOURCE - https://github.com/Unity-Technologies/Graphics/blob/504e639c4e07492f74716f36acf7aad0294af16e/Packages/com.unity.render-pipelines.core/ShaderLibrary/GeometricTools.hlsl#L78
// Returns the ray parameter t at which the ray (start + t * dir) reaches the far side of the
// axis-aligned box [boxMin, boxMax]. t is a distance when dir is normalized.
// This simplified version assumes that we only care about the result when `start` is inside the box.
// NOTE: Arithmetic is the same as the HDRP original; only the layout and local names differ.
float IntersectRayAABBSimple(float3 start, float3 dir, float3 boxMin, float3 boxMax){
    float3 invDir = rcp(dir);

    // Ray parameters at which the ray crosses the min-side and max-side planes of each axis
    float3 tToMinPlanes = (boxMin - start) * invDir;
    float3 tToMaxPlanes = (boxMax - start) * invDir;

    // Per axis, keep the plane the ray is travelling towards
    float tExitX = (dir.x > 0.0) ? tToMaxPlanes.x : tToMinPlanes.x;
    float tExitY = (dir.y > 0.0) ? tToMaxPlanes.y : tToMinPlanes.y;
    float tExitZ = (dir.z > 0.0) ? tToMaxPlanes.z : tToMinPlanes.z;

    // The smallest of the three is returned
    return min(min(tExitX, tExitY), tExitZ);
}

//SOURCE - https://github.com/Unity-Technologies/Graphics/blob/504e639c4e07492f74716f36acf7aad0294af16e/Packages/com.unity.render-pipelines.high-definition/Runtime/Lighting/LightEvaluation.hlsl
// Box-projects the reflection direction R against the box [boxMin, boxMax].
// Returns the projection distance, which can be used as distanceIntersectionToShadedPoint in
// ComputeDistanceBaseRoughness.
// On exit, R holds the unnormalized corrected direction used to fetch the cubemap. Its length is the
// distance from the capture point (probePos) to the intersection, so length(R) can be reused as the
// distanceIntersectionToProbeCenter parameter of ComputeDistanceBaseRoughness.
// NOTE: Modified to be much simpler, and to work with the Built-In Render Pipeline (BIRP)
float EvaluateLight_EnvIntersection(float3 worldSpacePosition, inout float3 R, float3 boxMin, float3 boxMax, float3 probePos){
    float hitDistance = IntersectRayAABBSimple(worldSpacePosition, R, boxMin, boxMax);

    // Point where the reflection ray meets the box, then re-expressed relative to the probe position
    float3 hitPoint = worldSpacePosition + hitDistance * R;
    R = hitPoint - probePos;

    return hitDistance;
}


// Samples the reflection probe cubemaps for the given reflection direction.
//   reflDir   - reflection direction at the shaded point
//   worldPos  - world-space position of the shaded point
//   roughness - roughness used to pick the cubemap mip level
// When UNITY_SPECCUBE_BOX_PROJECTION is defined and a probe's ProbePosition.w is positive, the
// direction is box-projected and the roughness is blended toward a distance-based roughness by
// _ContactHardening. unity_SpecCube1 is only sampled and blended in when unity_SpecCube0_BoxMin.w
// is below 0.99999.
float3 GetEnvironmentReflections(float3 reflDir, float3 worldPos, float roughness){
    // ---- First probe (unity_SpecCube0) ----
    float3 sampleDir0 = reflDir;
    float probeRoughness0 = roughness;
    #ifdef UNITY_SPECCUBE_BOX_PROJECTION
        if (unity_SpecCube0_ProbePosition.w > 0){
            float hitDistance0 = EvaluateLight_EnvIntersection(worldPos, sampleDir0, unity_SpecCube0_BoxMin.xyz, unity_SpecCube0_BoxMax.xyz, unity_SpecCube0_ProbePosition.xyz);
            // sampleDir0 now runs from the probe position to the box hit, so its length is that distance
            float probeToHit0 = length(sampleDir0);
            float hardenedRoughness0 = ComputeDistanceBaseRoughness(hitDistance0, probeToHit0, probeRoughness0);
            probeRoughness0 = lerp(roughness, hardenedRoughness0, _ContactHardening);
        }
    #endif
    // Remap roughness before converting it to a mip level
    probeRoughness0 = probeRoughness0 * (1.7 - 0.7 * probeRoughness0);
    float mipLevel0 = probeRoughness0 * UNITY_SPECCUBE_LOD_STEPS;
    float4 encodedSample0 = UNITY_SAMPLE_TEXCUBE_LOD(unity_SpecCube0, sampleDir0, mipLevel0);
    float3 reflection = DecodeHDR(encodedSample0, unity_SpecCube0_HDR);

    // ---- Second probe (unity_SpecCube1), blended in by unity_SpecCube0_BoxMin.w ----
    float probeBlend = unity_SpecCube0_BoxMin.w;
    [branch]
    if (probeBlend < 0.99999){
        float3 sampleDir1 = reflDir;
        float probeRoughness1 = roughness;
        #ifdef UNITY_SPECCUBE_BOX_PROJECTION
            if (unity_SpecCube1_ProbePosition.w > 0){
                float hitDistance1 = EvaluateLight_EnvIntersection(worldPos, sampleDir1, unity_SpecCube1_BoxMin.xyz, unity_SpecCube1_BoxMax.xyz, unity_SpecCube1_ProbePosition.xyz);
                float probeToHit1 = length(sampleDir1);
                float hardenedRoughness1 = ComputeDistanceBaseRoughness(hitDistance1, probeToHit1, probeRoughness1);
                probeRoughness1 = lerp(roughness, hardenedRoughness1, _ContactHardening);
            }
        #endif
        probeRoughness1 = probeRoughness1 * (1.7 - 0.7 * probeRoughness1);
        float mipLevel1 = probeRoughness1 * UNITY_SPECCUBE_LOD_STEPS;
        // The second cubemap is sampled through the first cubemap's sampler
        float4 encodedSample1 = UNITY_SAMPLE_TEXCUBE_SAMPLER_LOD(unity_SpecCube1, unity_SpecCube0, sampleDir1, mipLevel1);
        float3 reflection1 = DecodeHDR(encodedSample1, unity_SpecCube1_HDR);
        // A blend weight of 1 keeps only the first probe, 0 keeps only the second
        reflection = lerp(reflection1, reflection, probeBlend);
    }
    return reflection;
}

// Computes the specular part of the lighting for one fragment and writes it into ld:
//   - ld.reflectionCol    receives environment reflections, LTCGI specular, SSR and lightmap
//                         specular (ld.lmSpec), depending on which keywords are enabled
//   - ld.specHighlightCol receives the direct-light GGX highlight when _SPECULAR_HIGHLIGHTS_ON is defined
// In the forward base pass both may then be scaled by a specular occlusion factor.
// The function also assigns specularTint and specularOcclusion, which are not declared in this
// block (they are expected to be declared elsewhere in the shader).
//   ld - lighting inputs/outputs (view and light directions, attenuation, accumulated colors)
//   id - surface inputs (normal, albedo, metallic, roughness, occlusion)
//   i  - interpolated vertex data; worldPos and grabUV are read here
void CalculateBRDF(inout LightingData ld, InputData id, v2f i){

    // Squared roughness, floored at 0.003 (presumably to keep the GGX terms below well-behaved on
    // perfectly smooth surfaces - TODO confirm)
    float roughSq = max(id.roughness * id.roughness, 0.003);
    // abs() keeps the term non-negative when the normal faces away from the view direction
    float NdotV = abs(dot(id.normal, ld.viewDir));
    // Specular color: dielectric default for non-metals, albedo for metals
    specularTint = lerp(unity_ColorSpaceDielectricSpec.rgb, id.albedo, id.metallic);

    float3 reflDir = reflect(-ld.viewDir, id.normal);
    // Dims reflections as roughness increases: 1 / (roughSq^2 + 1)
    float surfaceReduction = 1.0 / (roughSq*roughSq + 1.0);
    float grazingTerm = saturate((1-id.roughness) + (1-ld.omr));
    // The cosine passed to FresnelLerp goes from 1 to NdotV as _FresnelStrength*_FresnelToggle goes from 0 to 1
    float3 fresnel = FresnelLerp(specularTint, grazingTerm, lerp(1, NdotV, _FresnelStrength*_FresnelToggle));
    // Horizon fade: 1 unless dot(reflDir, normal) is negative, in which case it drops below 1
    float horizon = min(1 + dot(reflDir, id.normal), 1);
    // Common multiplier applied to every reflection source below
    float3 reflAdjust = fresnel * surfaceReduction * horizon * horizon;

    // Reflection probe contribution
    #if defined(_REFLECTIONS_ON)
        float3 reflCol = GetEnvironmentReflections(reflDir, i.worldPos, id.roughness) * _ReflectionStrength;
        ld.reflectionCol += (reflCol * reflAdjust * id.occlusion);
    #endif

    // LTCGI specular contribution (ld.ltcgiSpecularity is computed elsewhere)
    #if LTCGI_ENABLED
        ld.reflectionCol += ld.ltcgiSpecularity * reflAdjust * id.occlusion;
    #endif
    
    // Screen-space reflections: blended over whatever has been accumulated so far
    #if defined(_SSR_ON)
        float4 ssr = 0;
        [branch]
        // Runs when SSR strength is positive and either _VRSSR is 1, or _VRSSR is 0 and IsNotVR() is true
        if (((_VRSSR == 0 && IsNotVR()) || _VRSSR == 1) && _SSRStrength > 0){
            ssr = GetSSR(i.worldPos, ld.viewDir, reflDir, id.normal, 1-id.roughness, id.albedo, id.metallic, i.grabUV);
            // With edge fade off, the SSR alpha becomes a hard 0/1 mask
            if (_SSREdgeFade == 0)
                ssr.a = ssr.a > 0 ? 1 : 0;
            ssr.rgb *= reflAdjust * id.occlusion;
            ld.reflectionCol = lerp(ld.reflectionCol, ssr.rgb, ssr.a * saturate(_SSRStrength));
        }
    #endif

    // Direct-light specular highlight: Smith-joint GGX visibility * GGX distribution * Fresnel
    #if defined(_SPECULAR_HIGHLIGHTS_ON)
        float3 halfVector = Unity_SafeNormalize(ld.lightDir + ld.viewDir);
        float LdotH = saturate(dot(ld.lightDir, halfVector));
        float NdotH = saturate(dot(id.normal, halfVector));
        float3 fresnelTerm = FresnelTerm(specularTint, LdotH);
        float V = SmithJointGGXVisibilityTerm(ld.NdotL, NdotV, roughSq);
        float D = GGXTerm(NdotH, roughSq);
        float specularTerm = V * D * UNITY_PI;
        ld.specHighlightCol = ld.directCol * fresnelTerm * specularTerm * _SpecularHighlightStrength;
        // Fade the highlight out where SSR covers the pixel (ssr.a stays 0 if the SSR branch did not run)
        #if defined(_SSR_ON)
            ld.specHighlightCol *= (1-ssr.a);
        #endif
    #endif

    // Lightmap specular (ld.lmSpec) is added unconditionally, after the SSR blend
    ld.reflectionCol += (ld.lmSpec * reflAdjust * UNITY_PI * _BakeryLMSpecStrength * 0.1);

    // Specular occlusion, forward base pass only
    #if defined(UNITY_PASS_FORWARDBASE)
        #if defined(LIGHTMAP_ON) || defined(DYNAMICLIGHTMAP_ON)
            // Lightmapped: occlusion is derived from ld.indirectCol, run through Desaturate and the
            // contrast / HDR / brightness / tint controls
            if (_SpecularOcclusionToggle == 1){
                float3 lightmap = Desaturate(ld.indirectCol);
                lightmap = GetContrast(lightmap, _SpecularOcclusionContrast);
                lightmap = lerp(lightmap, GetHDR(lightmap), _SpecularOcclusionHDR);
                lightmap *= _SpecularOcclusionBrightness;
                lightmap *= _SpecularOcclusionTint;
                // LTCGI diffuse light is added so areas it lights are occluded less
                #if LTCGI_ENABLED
                    lightmap += ld.ltcgiDiffuse;
                #endif
                specularOcclusion = saturate(lerp(1, lightmap, _SpecularOcclusionStrength));
                ld.reflectionCol *= specularOcclusion;
                ld.specHighlightCol *= specularOcclusion;
            }
        #else
            // Not lightmapped: occlusion is derived from light attenuation and ld.VNdotL, with a
            // fixed 0.9 blend before the user strength is applied
            // NOTE(review): unlike the lightmapped path, ld.specHighlightCol is not scaled here - confirm this is intended
            if (_SpecularOcclusionToggle == 1){
                specularOcclusion = lerp(1, ld.atten * saturate((ld.VNdotL * ld.VNdotL) + ld.VNdotL), 0.9);
                // _SpecularOcclusionToggle is already known to equal 1 inside this branch, so the product is just the strength
                specularOcclusion = saturate(lerp(1, specularOcclusion, _SpecularOcclusionToggle*_SpecularOcclusionStrength));
                ld.reflectionCol *= specularOcclusion;
            }
        #endif
    #endif
}

#endif