// any_hit.wgsl — WebGPU-based path tracer
// Path: src/shaders/any_hit.wgsl (4.57 KB)
// Shadow-ray visibility query.
// Builds a ray that starts at offset_ray(pos, normal) (presumably nudged off the
// surface to avoid self-intersection — confirm in offset_ray) and points along
// `light_dir`. Returns trace_any's verdict for distances up to `light_distance`:
// true when something blocks the light, false when the light is visible.
fn is_occluded(pos: vec3f, normal: vec3f, light_dir: vec3f, light_distance: f32) -> bool {
    var probe: Ray;
    probe.origin = offset_ray(pos, normal);
    probe.direction = light_dir;
    return trace_any(probe, light_distance);
}

// Any-hit (shadow) query: returns true as soon as anything blocks `ray` at a
// distance 0.001 < t < t_max, and false when nothing does.
//
// Two kinds of occluder are tested:
//   * the analytic plane z = 0 (the original "kill sunlight" floor), and
//   * the triangle BVH stored in node_tree / tri_lut / objects.
// BVH layout as read here: a node with primitive_count == 0 is internal, with its two
// children at indices left_child and left_child + 1. Any other node is a leaf whose
// primitives are tri_lut.primitive_indices[left_child .. left_child + primitive_count).
// Any hit ends the query, so children are not ordered by distance.
fn trace_any(ray: Ray, t_max: f32) -> bool {
    // Ground plane first: it costs one divide, and a hit makes the BVH walk unnecessary.
    // The final answer is the OR of both tests, so the order does not change the result.
    let floor_z = 0.0;
    let denom = ray.direction.z;
    if (abs(denom) > 1e-6) {
        let t = (floor_z - ray.origin.z) / denom;
        if (t > 0.001 && t < t_max) {
            return true;
        }
    }

    // Fixed-size stack of deferred node indices.
    var node_idx_stack: array<u32, 64>;
    var stack_ptr: i32 = 0;
    var current_node_idx: u32 = 0u;

    while (true) {
        let node = node_tree.nodes[current_node_idx];
        let primitive_count = u32(node.primitive_count);
        let child_or_prim_idx = u32(node.left_child);

        if (primitive_count == 0u) {
            // Internal node: test both child boxes, pruning with t_max.
            let left_child_idx = child_or_prim_idx;
            let right_child_idx = child_or_prim_idx + 1u;
            let hit_left = any_hit_aabb(ray, node_tree.nodes[left_child_idx].min_corner, node_tree.nodes[left_child_idx].max_corner, t_max);
            let hit_right = any_hit_aabb(ray, node_tree.nodes[right_child_idx].min_corner, node_tree.nodes[right_child_idx].max_corner, t_max);

            if (hit_left) {
                current_node_idx = left_child_idx;
                // Defer the right child if there is room. On overflow, only that
                // subtree is skipped. Aborting the whole traversal here would also
                // lose the left child and every pending node, missing more occluders.
                if (hit_right && stack_ptr < 64) {
                    node_idx_stack[stack_ptr] = right_child_idx;
                    stack_ptr += 1;
                }
                continue;
            }
            if (hit_right) {
                current_node_idx = right_child_idx;
                continue;
            }
            // Neither child overlaps the ray segment: fall through to pop.
        } else {
            // Leaf: any triangle hit within (0.001, t_max) is an occlusion.
            for (var i = 0u; i < primitive_count; i += 1u) {
                let prim_index = tri_lut.primitive_indices[child_or_prim_idx + i];
                let triangle = objects.triangles[i32(prim_index)];
                if (any_hit_triangle(ray, triangle, 0.001, t_max)) {
                    return true;
                }
            }
        }

        // Resume from the most recently deferred node, or finish when none remain.
        if (stack_ptr == 0) {
            break;
        }
        stack_ptr -= 1;
        current_node_idx = node_idx_stack[stack_ptr];
    }

    // Traversal finished without finding an occluder.
    return false;
}

// Ray/box slab test used for BVH pruning.
// Returns true when `ray` overlaps the axis-aligned box [aabb_min, aabb_max] somewhere
// in the parameter range [0.001, t_max].
fn any_hit_aabb(ray: Ray, aabb_min: vec3f, aabb_max: vec3f, t_max: f32) -> bool {
    // WGSL does not guarantee IEEE infinities: 1.0 / 0.0 (and 0 * inf) may produce an
    // indeterminate value on some backends. Axis-aligned rays are common (e.g. a sun
    // straight overhead), so near-zero components are replaced with a tiny epsilon
    // that keeps their sign, which keeps the reciprocal finite. The test is only used
    // for pruning, so the resulting (negligible) conservativeness is harmless.
    let dir_eps = 1e-8;
    let signed_eps = select(vec3f(dir_eps), vec3f(-dir_eps), ray.direction < vec3f(0.0));
    let safe_dir = select(ray.direction, signed_eps, abs(ray.direction) < vec3f(dir_eps));
    let inverse_dir = vec3f(1.0) / safe_dir;

    // Entry/exit distances for each pair of slab planes.
    let t_lo = (aabb_min - ray.origin) * inverse_dir;
    let t_hi = (aabb_max - ray.origin) * inverse_dir;
    let t_enter = min(t_lo, t_hi);
    let t_exit = max(t_lo, t_hi);
    let t_near = max(max(t_enter.x, t_enter.y), t_enter.z);
    let t_far = min(min(t_exit.x, t_exit.y), t_exit.z);

    return t_near <= t_far && t_far >= 0.001 && t_near <= t_max;
}

// TODO: simplify by reusing hit_triangle only, and move all post-hit shading data out.
//
// Möller–Trumbore ray/triangle test reduced to a yes/no answer: returns true when
// `ray` crosses `tri` at a distance strictly between dist_min and dist_max.
fn any_hit_triangle(ray: Ray, tri: Triangle, dist_min: f32, dist_max: f32) -> bool {
    let ab = tri.corner_b - tri.corner_a;
    let ac = tri.corner_c - tri.corner_a;

    let p = cross(ray.direction, ac);
    let det = dot(ab, p);
    // A near-zero determinant means the ray is (almost) parallel to the triangle plane.
    if (abs(det) < EPSILON) {
        return false;
    }
    let inv_det = 1.0 / det;

    // First barycentric coordinate.
    let a_to_origin = ray.origin - tri.corner_a;
    let bary_u = dot(a_to_origin, p) * inv_det;
    if (bary_u < 0.0 || bary_u > 1.0) {
        return false;
    }

    // Second barycentric coordinate; u + v must also stay inside the triangle.
    let q = cross(a_to_origin, ab);
    let bary_v = dot(ray.direction, q) * inv_det;
    if (bary_v < 0.0 || (bary_u + bary_v) > 1.0) {
        return false;
    }

    // Distance along the ray to the intersection point.
    let t = dot(ac, q) * inv_det;
    return t > dist_min && t < dist_max;
}