import { Vec3 } from "gl-matrix";
/**
 * Finite sentinel magnitude for "empty" bounds: an empty box has min = +MAX_BOUND and
 * max = -MAX_BOUND on every axis. Any min/max accumulation seeded from this value only
 * produces correct bounds for coordinates within ±MAX_BOUND.
 */
const MAX_BOUND = 999_999;
/**
 * A triangle as consumed by the BVH builder.
 * The builder reads the three corners to compute bounds, and reads `centroid` to bin and
 * partition triangles. It never computes the centroid itself, so callers must supply it
 * (presumably (cornerA + cornerB + cornerC) / 3 — confirm with whoever produces this data).
 */
interface Triangle {
cornerA: Vec3;
cornerB: Vec3;
cornerC: Vec3;
centroid: Vec3;
}
/**
 * Accumulator for one SAH bin along a single axis: how many triangle centroids fell into
 * the bin, and the bounds enclosing the corners of those triangles.
 */
interface Bin {
// number of triangles whose centroid was assigned to this bin
instanceCount: number;
// union of the corners of those triangles
bounds: AABB;
}
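// Documentation only: the logical layout of one BVH node. The BVH class does not use this
// interface; it stores the same fields in parallel typed arrays (nodesMin, nodesMax,
// nodesLeft, nodesInstanceCount). Trailing numbers are per-field sizes in bytes
// (3 x float32 = 12, uint32 = 4), i.e. 32 bytes per node.
// interface BVHNode {
// min: Vec3; // 12
// max: Vec3; // 12
// left: number; // 4
// instanceCount: number; // 4
// }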
/**
 * Axis-aligned bounding box that can be grown to enclose points or other boxes.
 *
 * A freshly constructed box is empty: `bmin` is +Infinity and `bmax` is -Infinity on every
 * axis, so the first `grow` snaps it onto that point. Seeding with ±Infinity (rather than a
 * finite sentinel) keeps the bounds correct for coordinates of any magnitude.
 */
class AABB {
  bmin: Vec3;
  bmax: Vec3;

  constructor() {
    this.bmin = Vec3.fromValues(Infinity, Infinity, Infinity);
    this.bmax = Vec3.fromValues(-Infinity, -Infinity, -Infinity);
  }

  /** True while nothing has been added to the box (min exceeds max on some axis). */
  isEmpty(): boolean {
    return this.bmin[0] > this.bmax[0] || this.bmin[1] > this.bmax[1] || this.bmin[2] > this.bmax[2];
  }

  /** Expand the box so that it contains point `p`. */
  grow(p: Vec3) {
    Vec3.min(this.bmin, this.bmin, p);
    Vec3.max(this.bmax, this.bmax, p);
  }

  /** Expand the box so that it contains `aabb`; an empty `aabb` leaves this box unchanged. */
  growAABB(aabb: AABB) {
    // Emptiness is checked structurally, so a genuine box whose min happens to equal some
    // sentinel value is never mistaken for an empty one.
    if (!aabb.isEmpty()) {
      this.grow(aabb.bmin);
      this.grow(aabb.bmax);
    }
  }

  /**
   * Half the surface area (xy + yz + zx). The factor 2 is omitted because SAH only compares
   * costs relative to each other. An empty box has area 0, so `count * area()` stays 0
   * (instead of NaN from `0 * Infinity`) for an empty partition.
   */
  area(): number {
    if (this.isEmpty()) return 0;
    const ex = this.bmax[0] - this.bmin[0];
    const ey = this.bmax[1] - this.bmin[1];
    const ez = this.bmax[2] - this.bmin[2];
    return ex * ey + ey * ez + ez * ex;
  }
}
/**
 * Bounding volume hierarchy over a list of triangles, built top-down with a binned
 * surface-area heuristic (SAH) and a median-split fallback.
 *
 * Node data is stored as parallel typed arrays indexed by node number (node 0 is the root):
 * - `nodesMin` / `nodesMax`: the node's bounds, 3 floats (x, y, z) per node.
 * - `nodesLeft`: for a leaf, the first position in `triIdx` it owns; for an interior node,
 *   the index of its left child (the right child is always the next index).
 * - `nodesInstanceCount`: the number of triangles in a leaf, or 0 for an interior node.
 *
 * `triIdx` is a permutation of triangle indices; each leaf owns a contiguous slice of it.
 * The constructor only allocates storage; call `construct()` to build the tree.
 */
class BVH {
  nodesMin: Float32Array; // min x,y,z for each node
  nodesMax: Float32Array; // max x,y,z for each node
  nodesLeft: Uint32Array; // leaf: first triIdx slot; interior: left child index
  nodesInstanceCount: Uint32Array; // leaf: triangle count; interior: 0
  nodesUsed: number; // number of node slots handed out so far
  triangles: Triangle[];
  triIdx: Uint32Array; // triangle indices, reordered in place during construction

  /**
   * Allocate storage for a tree over `triangles`.
   * @param triangles Triangles to index; only `cornerA/B/C` and `centroid` are read.
   */
  constructor(triangles: Triangle[]) {
    this.triangles = triangles;
    this.triIdx = new Uint32Array(triangles.length);
    this.nodesUsed = 0;
    // A binary tree whose leaves each hold at least one triangle has at most 2n - 1 nodes.
    // Keep at least one slot so an empty input still gets a root; 2 * 0 - 1 would be a
    // negative typed-array length, which throws a RangeError.
    const maxNodes = Math.max(1, 2 * triangles.length - 1);
    this.nodesMin = new Float32Array(maxNodes * 3);
    this.nodesMax = new Float32Array(maxNodes * 3);
    this.nodesLeft = new Uint32Array(maxNodes);
    this.nodesInstanceCount = new Uint32Array(maxNodes);
  }

  /**
   * Build the tree. This resets `triIdx` to identity order, makes node 0 a leaf that holds
   * every triangle, computes its bounds, and splits it recursively.
   */
  construct() {
    for (let i = 0; i < this.triangles.length; i++) {
      this.triIdx[i] = i;
    }
    this.nodesLeft[0] = 0;
    this.nodesInstanceCount[0] = this.triangles.length;
    this.nodesUsed = 1;
    this.bounding(0);
    this.subdivide(0);
  }

  /**
   * Recompute the bounds of node `nodeIdx` from the corners of the triangles it owns
   * (`triIdx[nodesLeft, nodesLeft + nodesInstanceCount)`). A node that owns no triangles
   * (including an interior node) gets the inverted sentinel box: min = MAX_BOUND, max = -MAX_BOUND.
   */
  bounding(nodeIdx: number) {
    const off = nodeIdx * 3;
    const count = this.nodesInstanceCount[nodeIdx];
    const start = this.nodesLeft[nodeIdx];
    if (count === 0) {
      this.nodesMin.fill(MAX_BOUND, off, off + 3);
      this.nodesMax.fill(-MAX_BOUND, off, off + 3);
      return;
    }
    // Accumulate in scalars seeded with ±Infinity. Coordinates beyond ±MAX_BOUND are then
    // still bounded correctly, and no temporary arrays are allocated per corner.
    let minX = Infinity;
    let minY = Infinity;
    let minZ = Infinity;
    let maxX = -Infinity;
    let maxY = -Infinity;
    let maxZ = -Infinity;
    for (let i = 0; i < count; i++) {
      const { cornerA: a, cornerB: b, cornerC: c } = this.triangles[this.triIdx[start + i]];
      minX = Math.min(minX, a[0], b[0], c[0]);
      minY = Math.min(minY, a[1], b[1], c[1]);
      minZ = Math.min(minZ, a[2], b[2], c[2]);
      maxX = Math.max(maxX, a[0], b[0], c[0]);
      maxY = Math.max(maxY, a[1], b[1], c[1]);
      maxZ = Math.max(maxZ, a[2], b[2], c[2]);
    }
    this.nodesMin[off + 0] = minX;
    this.nodesMin[off + 1] = minY;
    this.nodesMin[off + 2] = minZ;
    this.nodesMax[off + 0] = maxX;
    this.nodesMax[off + 1] = maxY;
    this.nodesMax[off + 2] = maxZ;
  }

  /**
   * Recursively split leaf `nodeIdx`. Nodes with 2 or fewer triangles are left as leaves.
   * The split plane comes from `findBestPlane`. If its SAH cost is not lower than keeping the
   * node as a leaf, a median split along the node's longest axis is used instead. If the
   * chosen plane leaves one side empty, the node stays a leaf.
   * Assumes `bounding(nodeIdx)` has already been called for this node.
   */
  subdivide(nodeIdx: number) {
    const count = this.nodesInstanceCount[nodeIdx];
    // not enough primitives
    if (count <= 2) return;
    const start = this.nodesLeft[nodeIdx];
    const [sahSplit, sahAxis, sahCost] = this.findBestPlane(nodeIdx);
    let split = sahSplit;
    let axis = sahAxis;

    // Cost of keeping this node as a leaf: count x (half) surface area of its bounds.
    // The extent is held in float32, like the stored node bounds.
    const off = nodeIdx * 3;
    const extent = new Float32Array(3);
    for (let k = 0; k < 3; k++) {
      extent[k] = this.nodesMax[off + k] - this.nodesMin[off + k];
    }
    const parentArea = extent[0] * extent[1] + extent[1] * extent[2] + extent[2] * extent[0];
    const parentCost = count * parentArea;

    // SAH found nothing cheaper than a leaf, or reported Infinity because it was not
    // applicable: fall back to a median split along the node's longest axis.
    if (sahCost >= parentCost) {
      axis = this.longestAxis(extent);
      split = this.medianCentroid(start, count, axis);
    }

    const mid = this.partition(start, count, axis, split);
    const leftCount = mid - start;
    // The plane did not separate anything; keep this node as a leaf.
    if (leftCount === 0 || leftCount === count) return;

    // Children occupy two consecutive slots, and each takes a contiguous slice of triIdx.
    const leftIdx = this.nodesUsed++;
    const rightIdx = this.nodesUsed++;
    this.nodesLeft[leftIdx] = start;
    this.nodesInstanceCount[leftIdx] = leftCount;
    this.nodesLeft[rightIdx] = mid;
    this.nodesInstanceCount[rightIdx] = count - leftCount;
    // This node becomes interior: nodesLeft now points at the left child, and count 0 marks it.
    this.nodesLeft[nodeIdx] = leftIdx;
    this.nodesInstanceCount[nodeIdx] = 0;

    this.bounding(leftIdx);
    this.bounding(rightIdx);
    this.subdivide(leftIdx);
    this.subdivide(rightIdx);
  }

  /**
   * Choose a split plane for leaf `nodeIdx` using binned SAH.
   *
   * Centroids are binned along every axis whose centroid spread is at least EPSILON. The bin
   * count scales with the triangle count and is clamped to [8, 32]. Each boundary between
   * adjacent bins is a candidate plane, scored as leftCount * leftArea + rightCount * rightArea.
   * An axis whose centroids are (nearly) coincident is skipped, and the other axes are still
   * evaluated. So a planar, axis-aligned cluster still gets SAH on its two remaining axes.
   *
   * If the centroids are (nearly) coincident on every axis, SAH cannot separate them. The
   * result is then a median split along the longest axis of the triangles' corner bounds,
   * with cost Infinity to signal that no SAH plane was found. An empty node returns
   * [0, -1, Infinity].
   *
   * @returns [split position, axis (0 = x, 1 = y, 2 = z), cost]
   */
  findBestPlane(nodeIdx: number): [number, number, number] {
    const EPSILON = 1e-5;
    const count = this.nodesInstanceCount[nodeIdx];
    const start = this.nodesLeft[nodeIdx];
    if (count === 0) return [0, -1, Infinity];

    // Bounds of the triangle centroids.
    const centroidMin = [Infinity, Infinity, Infinity];
    const centroidMax = [-Infinity, -Infinity, -Infinity];
    for (let i = 0; i < count; i++) {
      const centroid = this.triangles[this.triIdx[start + i]].centroid;
      for (let axis = 0; axis < 3; axis++) {
        centroidMin[axis] = Math.min(centroidMin[axis], centroid[axis]);
        centroidMax[axis] = Math.max(centroidMax[axis], centroid[axis]);
      }
    }

    // Only axes with real centroid spread can be binned meaningfully.
    const binnableAxes = [0, 1, 2].filter((axis) => centroidMax[axis] - centroidMin[axis] >= EPSILON);
    if (binnableAxes.length === 0) {
      const axis = this.longestAxis(this.cornerExtent(start, count));
      return [this.medianCentroid(start, count, axis), axis, Infinity];
    }

    // adaptive bin count, clamped between 8 and 32
    const BINS = Math.max(8, Math.min(32, count));
    const JITTER = 1e-6;
    let bestAxis = -1;
    let bestSplit = 0;
    let bestCost = Infinity;

    for (const axis of binnableAxes) {
      const axisMin = centroidMin[axis];
      const axisMax = centroidMax[axis];
      const axisScale = BINS / (axisMax - axisMin);
      const binWidth = (axisMax - axisMin) / BINS;

      // Distribute triangles into bins by centroid; bin bounds enclose the triangles' corners.
      const bins: Bin[] = Array.from({ length: BINS }, () => ({
        instanceCount: 0,
        bounds: new AABB(),
      }));
      for (let i = 0; i < count; i++) {
        const tri = this.triangles[this.triIdx[start + i]];
        // The maximum centroid maps to index BINS; clamp it into the last bin.
        const bin = bins[Math.min(BINS - 1, Math.floor((tri.centroid[axis] - axisMin) * axisScale))];
        bin.instanceCount++;
        bin.bounds.grow(tri.cornerA);
        bin.bounds.grow(tri.cornerB);
        bin.bounds.grow(tri.cornerC);
      }

      // Sweep from both ends. Entry i describes everything left / right of candidate
      // plane i, the boundary between bins i and i + 1.
      const leftCount = new Array<number>(BINS - 1).fill(0);
      const leftArea = new Array<number>(BINS - 1).fill(0);
      const rightCount = new Array<number>(BINS - 1).fill(0);
      const rightArea = new Array<number>(BINS - 1).fill(0);
      const leftBox = new AABB();
      for (let i = 0, sum = 0; i < BINS - 1; i++) {
        sum += bins[i].instanceCount;
        leftCount[i] = sum;
        leftBox.growAABB(bins[i].bounds);
        leftArea[i] = leftBox.area();
      }
      const rightBox = new AABB();
      for (let i = BINS - 1, sum = 0; i > 0; i--) {
        sum += bins[i].instanceCount;
        rightCount[i - 1] = sum;
        rightBox.growAABB(bins[i].bounds);
        rightArea[i - 1] = rightBox.area();
      }

      for (let i = 0; i < BINS - 1; i++) {
        const cost = leftCount[i] * leftArea[i] + rightCount[i] * rightArea[i];
        if (cost < bestCost) {
          bestCost = cost;
          bestAxis = axis;
          // Workaround introduced for heavily axis-aligned meshes (e.g. Lucy, Sponza). The
          // jitter nudges the plane toward +axis, so centroids within JITTER of the boundary
          // go left when partitioning with `centroid < split`.
          // NOTE(review): binning places a centroid exactly on the boundary in the right-hand
          // bin, so the jitter can move such triangles to the side SAH did not count them on.
          // JITTER is also absolute, so its effect depends on scene scale.
          bestSplit = axisMin + binWidth * (i + 1) + JITTER;
        }
      }
    }
    return [bestSplit, bestAxis, bestCost];
  }

  /**
   * Partition `triIdx[start, start + count)` in place. Triangles whose centroid on `axis` is
   * `< split` come first.
   * @returns the absolute `triIdx` position of the first triangle in the right-hand group.
   */
  private partition(start: number, count: number, axis: number, split: number): number {
    let i = start;
    let j = start + count - 1;
    while (i <= j) {
      if (this.triangles[this.triIdx[i]].centroid[axis] < split) {
        i++;
      } else {
        const tmp = this.triIdx[i];
        this.triIdx[i] = this.triIdx[j];
        this.triIdx[j] = tmp;
        j--;
      }
    }
    return i;
  }

  /** Index (0, 1 or 2) of the largest component of `extent`; ties favour the lower axis. */
  private longestAxis(extent: ArrayLike<number>): number {
    let axis = 0;
    if (extent[1] > extent[0]) axis = 1;
    if (extent[2] > extent[axis]) axis = 2;
    return axis;
  }

  /**
   * Median of the centroids' `axis` coordinate over `triIdx[start, start + count)`.
   * For an even count this is the upper of the two middle values.
   */
  private medianCentroid(start: number, count: number, axis: number): number {
    const values: number[] = [];
    for (let i = 0; i < count; i++) {
      values.push(this.triangles[this.triIdx[start + i]].centroid[axis]);
    }
    values.sort((a, b) => a - b);
    return values[Math.floor(count / 2)];
  }

  /** Per-axis extent of the corners of the triangles in `triIdx[start, start + count)`. */
  private cornerExtent(start: number, count: number): number[] {
    const lo = [Infinity, Infinity, Infinity];
    const hi = [-Infinity, -Infinity, -Infinity];
    for (let i = 0; i < count; i++) {
      const tri = this.triangles[this.triIdx[start + i]];
      for (let axis = 0; axis < 3; axis++) {
        lo[axis] = Math.min(lo[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
        hi[axis] = Math.max(hi[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
      }
    }
    return [hi[0] - lo[0], hi[1] - lo[1], hi[2] - lo[2]];
  }
}
export default BVH;