// Per-draw lighting constants read by CelShadingPixelShader.
// Layout: four 16-byte constant registers (64 bytes total); the explicit
// padding fields fill each float3 / float triple out to a full register.
// NOTE(review): presumably mirrored by a C++ struct that fills this buffer;
// field order and padding must stay in sync with it — confirm on the CPU side.
cbuffer LightBuffer
{
    float4 diffuseColor; // Colour multiplied into the final fragment colour.
    float3 lightDirection; // Directional light vector; normalized in the shader before use (sign convention: see shader).
    float padding; // Pads lightDirection out to a full 16-byte register.
    float3 lightPosition; // Point-light position, compared against the fragment's world position.
    float padding2; // Pads lightPosition out to a full 16-byte register.
    float constantAttenuation; // Constant term c of the attenuation 1 / (c + l*d + q*d*d), d = distance to light.
    float linearAttenuation; // Linear term l of the attenuation, multiplied by d.
    float quadraticAttenuation; // Quadratic term q of the attenuation, multiplied by d*d.
    float padding3; // Pads the three attenuation terms out to a full 16-byte register.
};

// Base colour texture, sampled at the interpolated UV (input.tex).
// NOTE(review): no explicit register() is given, so binding slots are assigned
// by the compiler (presumably t0 / s0) — confirm against the application code.
Texture2D shaderTexture;
// Sampler state used for every shaderTexture lookup.
SamplerState SampleType;

// Interpolated per-pixel inputs. Member order and semantics must match the
// output signature of the vertex shader feeding this stage (not visible here).
struct PixelInputType
{
    float4 position : SV_POSITION; // Clip-space position (consumed by the rasterizer; unused by the shader body).
    float3 normal : NORMAL; // Surface normal; renormalized per pixel. Assumed to be in world space like worldPos — TODO confirm in the vertex shader.
    float2 tex : TEXCOORD0; // UV coordinates used to sample shaderTexture.
    float3 worldPos : TEXCOORD1; // World-space fragment position, used for the point-light vector and distance.
};

////////////////////////////////////////////////////////////////////////////////
// CelShadingPixelShader
//
// Shades a textured fragment with a quantized ("toon") lighting model that
// combines one directional light and one distance-attenuated point light.
//
// input   - interpolated vertex data (normal, UV, world-space position).
// returns - the lit fragment colour. RGB is the texture colour times
//           diffuseColor times a quantized light level; alpha is
//           textureColor.a * diffuseColor.a and is deliberately NOT scaled by
//           lighting, so dark bands do not become transparent when blending.
////////////////////////////////////////////////////////////////////////////////
float4 CelShadingPixelShader(PixelInputType input) : SV_TARGET
{
    // Guards divisions against zero (coincident light/fragment, all-zero
    // attenuation coefficients).
    const float kEpsilon = 1e-5f;

    // Sample the pixel color from the texture.
    float4 textureColor = shaderTexture.Sample(SampleType, input.tex);

    // Interpolation denormalizes the normal, so renormalize it per pixel.
    float3 normal = normalize(input.normal);

    // Vector from the fragment to the point light. Its length is the distance
    // used for attenuation; dividing by the (guarded) length normalizes it
    // without producing NaN when the fragment sits exactly on the light.
    float3 toLightVector = lightPosition - input.worldPos;
    float lightDistance = length(toLightVector);
    float3 toLight = toLightVector / max(lightDistance, kEpsilon);

    // Directional light contribution (Lambert term).
    // NOTE(review): this uses +lightDirection. If the CPU side supplies the
    // direction the light travels (the common convention), the term should use
    // -lightDirection — confirm against the code that fills LightBuffer.
    float directionalLightIntensity = saturate(dot(normal, normalize(lightDirection)));

    // Point light contribution: Lambert term scaled by distance attenuation
    // 1 / (c + l*d + q*d^2). The denominator is clamped to avoid inf/NaN.
    float facingPointLight = saturate(dot(normal, toLight));
    float attenuationDenominator = constantAttenuation
                                 + linearAttenuation * lightDistance
                                 + quadraticAttenuation * lightDistance * lightDistance;
    float attenuation = 1.0f / max(attenuationDenominator, kEpsilon);
    float positionalLightIntensity = facingPointLight * attenuation;

    // Combine the lights. Only the point light is attenuated: a directional
    // light has no position, so it must not fade with distance to lightPosition.
    float lightIntensity = max(directionalLightIntensity, positionalLightIntensity);

    // Apply a step function to create the cel shading effect.
    if (lightIntensity > 0.75f)
    {
        lightIntensity = 1.0f; // Brightest level
    }
    else if (lightIntensity > 0.5f)
    {
        lightIntensity = 0.7f; // Mid-bright level
    }
    else if (lightIntensity > 0.25f)
    {
        lightIntensity = 0.4f; // Mid-dark level
    }
    else
    {
        lightIntensity = 0.1f; // Darkest level
    }

    // Cheap "shadow" approximation: this is not an occlusion test. It darkens
    // fragments whose surface faces away from (or nearly edge-on to) the
    // point light, reusing the point-light Lambert term computed above.
    if (facingPointLight < 0.1f)
    {
        lightIntensity *= 0.5f;
    }

    // Light only the colour channels; keep the material's alpha intact.
    float4 finalColor;
    finalColor.rgb = textureColor.rgb * diffuseColor.rgb * lightIntensity;
    finalColor.a = textureColor.a * diffuseColor.a;

    return finalColor;
}