// gltf.ts — webgpu-based path tracer
// Path: src/gltf.ts (14.73 KB)
import { load } from "@loaders.gl/core";
import { GLTFImagePostprocessed, GLTFLoader, GLTFMeshPrimitivePostprocessed, GLTFPostprocessed, postProcessGLTF } from "@loaders.gl/gltf";
import { Mat3, Mat4, Mat4Like, Quat, QuatLike, Vec3, Vec3Like } from "gl-matrix";

/**
 * One triangle emitted by GLTF2.extractTriangles.
 *
 * All vectors are plain number arrays (fresh copies, never gl-matrix scratch objects).
 * Corner positions, normals and the centroid are in the space produced by the transform
 * passed to extractTriangles.
 */
interface Triangle {
  /** Average of the three corner positions. */
  centroid: number[];

  /** Corner positions, [x, y, z] each. */
  cornerA: number[];
  cornerB: number[];
  cornerC: number[];

  /** Per-corner shading normals, [x, y, z] each. */
  normalA: number[];
  normalB: number[];
  normalC: number[];

  /** Per-corner TEXCOORD_0 values, [u, v] each. */
  uvA: number[];
  uvB: number[];
  uvC: number[];

  /** Per-corner TANGENT values, [x, y, z, w] each. */
  tangentA: number[];
  tangentB: number[];
  tangentC: number[];

  /** Material index, or -1 when the primitive has no material. */
  mat: number;
}

/**
 * Flattened, numeric-only view of a glTF material as built by GLTF2.
 * Texture fields hold an index into the owning GLTF2's `textures` array, or -1 for "none".
 */
interface ProcessedMaterial {
  /** RGBA base color multiplier. */
  baseColorFactor: number[];
  /** Base color texture index, -1 when absent. */
  baseColorTexture: number;
  metallicFactor: number;
  roughnessFactor: number;
  /** Metallic/roughness texture index, -1 when absent. */
  metallicRoughnessTexture: number;
  /** Normal map texture index, -1 when absent. */
  normalTexture: number;
  /** RGB emissive multiplier. */
  emissiveFactor: number[];
  /** Emissive texture index, -1 when absent. */
  emissiveTexture: number;
  /** Alpha mode code from GLTF2.parseAlphaMode; 0 when the source material sets none, 1 for "MASK". */
  alphaMode: number;
  alphaCutoff: number;
  /** 1 when the material is double sided, otherwise 0. */
  doubleSided: number;
}

/**
 * A glTF texture that has been given GPU resources, together with the source image and
 * sampler description so it can be re-uploaded or packed into an atlas later.
 */
interface ProcessedTexture {
  /** Texture id taken from the post-processed glTF (primitive `string`, not the boxed `String`). */
  id: string,
  /** Sampler created from `samplerDescriptor`. */
  sampler: GPUSampler,
  /** GPU texture sized to the source image. */
  texture: GPUTexture,
  /** Default view of `texture`. */
  view: GPUTextureView,
  /** Post-processed glTF image the texture was created for. */
  source: GLTFImagePostprocessed,
  /** Descriptor used to create `sampler`. */
  samplerDescriptor: GPUSamplerDescriptor,
}

/** A node of the post-processed glTF document. */
type GLTFNode = NonNullable<GLTFPostprocessed["nodes"]>[number];

/**
 * Loads one glTF 2.0 asset through loaders.gl and flattens it into the data the path tracer
 * consumes: transformed triangles, numeric material records and WebGPU textures.
 *
 * Each triangle is transformed by `sceneTransform * yToZUp * nodeWorld`, where
 * `sceneTransform` is built from the constructor's `position`, `rotation` (a quaternion)
 * and `scale`, and `nodeWorld` is the product of the node's local transform with the
 * local transforms of all of its ancestors.
 */
export class GLTF2 {
  triangles: Array<Triangle>;
  materials: Array<ProcessedMaterial>; // could make material more explicit here, like with textures
  textures: Array<ProcessedTexture>;
  device: GPUDevice;
  gltfData!: GLTFPostprocessed;
  url: string;
  scale: number[];
  position: number[];
  rotation?: number[];

  // Scratch objects reused by the transform helpers to avoid per-vertex allocations.
  // Every helper copies its result into a fresh array before returning, so none of these escape.
  tempVec3_0: Vec3 = Vec3.create();
  tempVec3_1: Vec3 = Vec3.create();
  tempVec3_2: Vec3 = Vec3.create();
  tempVec3_3: Vec3 = Vec3.create();
  tempVec3_4: Vec3 = Vec3.create();
  tempMat4_0: Mat4 = Mat4.create();
  tempMat4_1: Mat4 = Mat4.create();
  tempMat3_0: Mat3 = Mat3.create();
  tempMat3_1: Mat3 = Mat3.create();

  constructor(device: GPUDevice, url: string, scale: number[], position: number[], rotation?: number[]) {
    this.triangles = [];
    this.materials = [];
    this.textures = [];
    this.device = device;
    this.url = url;
    this.scale = scale;
    this.position = position;
    this.rotation = rotation;
  }

  /**
   * Fetches and parses `url`, then fills `triangles`, `materials` and `textures`.
   * @returns `[triangles, materials, textures]` — the arrays stored on this instance.
   */
  async initialize() {
    const loaded = await load(this.url, GLTFLoader);
    this.gltfData = postProcessGLTF(loaded);
    this.traverseNodes();
    return [this.triangles, this.materials, this.textures];
  }

  /** Transforms a point by `matrix` and returns it as a new [x, y, z] array. */
  transformVec3(inputVec: ArrayLike<number>, matrix: Mat4): number[] {
    const v = this.tempVec3_0;
    Vec3.set(v, inputVec[0], inputVec[1], inputVec[2]);
    Vec3.transformMat4(v, v, matrix);
    return [v[0], v[1], v[2]];
  }

  /**
   * Transforms a normal by the inverse-transpose of `transformMatrix` and renormalizes it.
   * Prefer computing the normal matrix once and reusing it when transforming many normals.
   */
  transformNormal(inputNormal: ArrayLike<number>, transformMatrix: Mat4): number[] {
    this.computeNormalMatrix(this.tempMat3_0, transformMatrix);
    return this.applyNormalMatrix(inputNormal, this.tempMat3_0);
  }

  /**
   * Writes transpose(inverse(transform)) restricted to its upper-left 3x3 into `out`.
   * A singular transform has no inverse; the plain upper 3x3 is used then, so the result is
   * never built from stale scratch data.
   */
  private computeNormalMatrix(out: Mat3, transform: Mat4): void {
    // tempMat4_1 is scratch only, so callers may pass tempMat4_0 as `transform`.
    const scratch = this.tempMat4_1;
    if (Mat4.invert(scratch, transform)) {
      Mat4.transpose(scratch, scratch);
      Mat3.fromMat4(out, scratch);
    } else {
      Mat3.fromMat4(out, transform);
    }
  }

  /** Applies a precomputed normal matrix and renormalizes; returns a new [x, y, z] array. */
  private applyNormalMatrix(inputNormal: ArrayLike<number>, normalMatrix: Mat3): number[] {
    const n = this.tempVec3_0;
    Vec3.set(n, inputNormal[0], inputNormal[1], inputNormal[2]);
    Vec3.transformMat3(n, n, normalMatrix);
    Vec3.normalize(n, n);
    return [n[0], n[1], n[2]];
  }

  /**
   * Transforms a glTF tangent [x, y, z, w]. The direction goes through the linear part of the
   * model matrix (tangents lie in the surface, unlike normals). `w` is the bitangent sign and is
   * multiplied by `handedness` so that mirrored instances keep a correct bitangent.
   */
  private applyTangentMatrix(tangent: ArrayLike<number>, linearMatrix: Mat3, handedness: number): number[] {
    const t = this.tempVec3_0;
    Vec3.set(t, tangent[0], tangent[1], tangent[2]);
    Vec3.transformMat3(t, t, linearMatrix);
    Vec3.normalize(t, t);
    return [t[0], t[1], t[2], tangent[3] * handedness];
  }

  /**
   * Encodes a glTF alphaMode string: "OPAQUE" -> 0, "MASK" -> 1, anything else ("BLEND") -> 2.
   * A material without an alphaMode is also encoded as 0 by the caller.
   */
  parseAlphaMode(alphaMode: string): number {
    if (alphaMode === "OPAQUE") return 0;
    if (alphaMode === "MASK") return 1;
    return 2;
  }

  /**
   * Appends one Triangle per index triple of `primitive` to `targetArray`.
   *
   * Positions are transformed by `transform`. Normals use its inverse-transpose; without a
   * NORMAL attribute, a flat face normal is used instead. TANGENT xyz uses its linear 3x3.
   * A missing TEXCOORD_0 falls back to [0, 0] and a missing TANGENT to [1, 0, 0, 1].
   * A primitive without indices is read as the implicit list 0..n-1.
   *
   * Only triangle lists (mode 4, the glTF default) are supported; other modes are skipped
   * because reading strips/fans/lines/points as lists produces garbage triangles.
   */
  extractTriangles(primitive: GLTFMeshPrimitivePostprocessed, transform: Mat4, targetArray: Array<Triangle>) {
    if (primitive.mode !== undefined && primitive.mode !== 4) return;

    const positions = primitive.attributes["POSITION"].value;
    const vertexCount = positions.length / 3;
    const indices: ArrayLike<number> =
      primitive.indices?.value ?? Uint32Array.from({ length: vertexCount }, (_, i) => i);
    const normals = primitive.attributes["NORMAL"] ? primitive.attributes["NORMAL"].value : null;
    const uvCoords = primitive.attributes["TEXCOORD_0"] ? primitive.attributes["TEXCOORD_0"].value : null;
    const tangents = primitive.attributes["TANGENT"] ? primitive.attributes["TANGENT"].value : null;

    // Material index parsed from the trailing digits of the material id; -1 when there is none.
    const mat = parseInt(primitive.material?.id.match(/\d+$/)?.[0] ?? "-1");

    // Derived matrices depend only on `transform`, so compute them once per primitive.
    const normalMatrix = this.tempMat3_0;
    this.computeNormalMatrix(normalMatrix, transform);
    const linearMatrix = this.tempMat3_1;
    Mat3.fromMat4(linearMatrix, transform);
    // A negative determinant mirrors the geometry: winding and tangent handedness both flip.
    const handedness = Mat3.determinant(linearMatrix) < 0 ? -1 : 1;

    const defaultUV = [0, 0];
    const defaultTangent = [1, 0, 0, 1];

    const cornerOf = (v: number): number[] =>
      this.transformVec3([positions[v * 3], positions[v * 3 + 1], positions[v * 3 + 2]], transform);
    const normalOf = (data: ArrayLike<number>, v: number): number[] =>
      this.applyNormalMatrix([data[v * 3], data[v * 3 + 1], data[v * 3 + 2]], normalMatrix);
    const uvOf = (v: number): number[] =>
      uvCoords ? [uvCoords[v * 2], uvCoords[v * 2 + 1]] : defaultUV;
    const tangentOf = (v: number): number[] =>
      tangents
        ? this.applyTangentMatrix(
          [tangents[v * 4], tangents[v * 4 + 1], tangents[v * 4 + 2], tangents[v * 4 + 3]],
          linearMatrix,
          handedness,
        )
        : defaultTangent;

    // Scratch for the flat-normal fallback; distinct from tempVec3_0 used by the helpers above.
    const edge1 = this.tempVec3_2;
    const edge2 = this.tempVec3_3;
    const faceNormal = this.tempVec3_4;

    // `i + 2 < length` ignores a trailing partial triple instead of reading past the end.
    for (let i = 0; i + 2 < indices.length; i += 3) {
      const ai = indices[i];
      const bi = indices[i + 1];
      const ci = indices[i + 2];

      const cornerA = cornerOf(ai);
      const cornerB = cornerOf(bi);
      const cornerC = cornerOf(ci);

      let normalA: number[];
      let normalB: number[];
      let normalC: number[];
      if (normals) {
        normalA = normalOf(normals, ai);
        normalB = normalOf(normals, bi);
        normalC = normalOf(normals, ci);
      } else {
        Vec3.set(edge1, cornerB[0] - cornerA[0], cornerB[1] - cornerA[1], cornerB[2] - cornerA[2]);
        Vec3.set(edge2, cornerC[0] - cornerA[0], cornerC[1] - cornerA[1], cornerC[2] - cornerA[2]);
        Vec3.cross(faceNormal, edge1, edge2);
        Vec3.normalize(faceNormal, faceNormal);
        const flat = [faceNormal[0] * handedness, faceNormal[1] * handedness, faceNormal[2] * handedness];
        normalA = flat;
        normalB = flat;
        normalC = flat;
      }

      const centroid = [
        (cornerA[0] + cornerB[0] + cornerC[0]) / 3,
        (cornerA[1] + cornerB[1] + cornerC[1]) / 3,
        (cornerA[2] + cornerB[2] + cornerC[2]) / 3,
      ];

      targetArray.push({
        centroid,
        cornerA, cornerB, cornerC,
        normalA, normalB, normalC, mat,
        uvA: uvOf(ai), uvB: uvOf(bi), uvC: uvOf(ci),
        tangentA: tangentOf(ai), tangentB: tangentOf(bi), tangentC: tangentOf(ci),
      });
    }
  }

  /**
   * Converts the loaded glTF into this instance's textures, materials and triangles.
   * Requires `gltfData` to be set (done by `initialize`).
   */
  traverseNodes() {
    const textureRemap = this.processTextures();
    this.processMaterials(textureRemap);
    this.processNodes();
  }

  /**
   * Creates GPU resources for every glTF texture that has an image.
   * @returns for each glTF texture index, its index in `this.textures`, or -1 if it was skipped.
   */
  private processTextures(): number[] {
    const remap: number[] = [];
    (this.gltfData.textures ?? []).forEach((texture, gltfIndex) => {
      if (!texture.source?.image) {
        // Empty textures get no entry; materials pointing at them are remapped to -1 ("none").
        remap[gltfIndex] = -1;
        return;
      }
      const gpuTexture = this.device.createTexture({
        size: {
          width: texture.source.image.width ?? 0,
          height: texture.source.image.height ?? 0,
          depthOrArrayLayers: 1,
        },
        format: "rgba8unorm",
        usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST | GPUTextureUsage.RENDER_ATTACHMENT,
      });
      const view = gpuTexture.createView({ format: "rgba8unorm" });

      // TODO: Process gltfData.samplers[texture.sampler] if it exists
      const samplerDescriptor: GPUSamplerDescriptor = {
        magFilter: "linear", minFilter: "linear",
        addressModeU: "repeat", addressModeV: "repeat",
      };
      const sampler = this.device.createSampler(samplerDescriptor);

      remap[gltfIndex] = this.textures.length;
      this.textures.push({
        id: texture.id,
        texture: gpuTexture,
        view: view,
        sampler: sampler,
        source: texture.source,
        samplerDescriptor: samplerDescriptor,
      });
    });
    return remap;
  }

  /**
   * Builds `this.materials` from the glTF materials, applying glTF defaults.
   * Texture indices are translated via `textureRemap` so they stay valid when textures were skipped.
   */
  private processMaterials(textureRemap: number[]) {
    if (!this.gltfData.materials) return;
    const textureIndex = (gltfIndex: number | undefined): number =>
      gltfIndex === undefined || gltfIndex < 0 ? -1 : (textureRemap[gltfIndex] ?? -1);

    this.materials = this.gltfData.materials.map((mat) => ({
      baseColorFactor: mat.pbrMetallicRoughness?.baseColorFactor ?? [1.0, 1.0, 1.0, 1.0],
      baseColorTexture: textureIndex(mat.pbrMetallicRoughness?.baseColorTexture?.index),
      metallicFactor: mat.pbrMetallicRoughness?.metallicFactor ?? 1.0,
      roughnessFactor: mat.pbrMetallicRoughness?.roughnessFactor ?? 1.0,
      metallicRoughnessTexture: textureIndex(mat.pbrMetallicRoughness?.metallicRoughnessTexture?.index),
      normalTexture: textureIndex(mat.normalTexture?.index),
      emissiveFactor: mat.emissiveFactor ?? [0.0, 0.0, 0.0],
      emissiveTexture: textureIndex(mat.emissiveTexture?.index),
      alphaMode: mat.alphaMode ? this.parseAlphaMode(mat.alphaMode) : 0,
      alphaCutoff: mat.alphaCutoff ?? 0.5,
      doubleSided: mat.doubleSided ? 1 : 0,
    }));
  }

  /** Composes translation * rotation * scale from the constructor's position/rotation/scale. */
  private buildSceneTransform(): Mat4 {
    const translation = Mat4.create();
    const rotation = Mat4.create();
    const scaling = Mat4.create();
    const quat = Quat.create();
    Mat4.fromTranslation(translation, (this.position || [0, 0, 0]) as Vec3Like);
    Quat.normalize(quat, (this.rotation || [0, 0, 0, 1]) as QuatLike);
    Mat4.fromQuat(rotation, quat);
    Mat4.fromScaling(scaling, (this.scale || [1, 1, 1]) as Vec3Like);

    const sceneTransform = Mat4.create();
    Mat4.multiply(sceneTransform, rotation, scaling);
    Mat4.multiply(sceneTransform, translation, sceneTransform);
    return sceneTransform;
  }

  /**
   * Normalizes a glTF cross-reference. A number is an index into `pool`. An object is looked up
   * by its `id` in `byId` when a map is given; otherwise the object itself is returned. This
   * accepts both resolved objects and raw indices, in case post-processing left some unresolved.
   */
  private static resolveRef<T>(ref: unknown, pool: readonly T[], byId?: ReadonlyMap<string, T>): T | undefined {
    if (typeof ref === "number") return pool[ref];
    if (typeof ref !== "object" || ref === null) return undefined;
    if (!byId) return ref as T;
    const id = (ref as { id?: unknown }).id;
    return typeof id === "string" ? byId.get(id) : undefined;
  }

  /**
   * Walks the node hierarchy of the active scene and extracts triangles for every mesh.
   * It uses `gltfData.scene`, falling back to the first scene, and then to every parentless
   * node when no scene lists any nodes.
   */
  private processNodes() {
    const nodes: GLTFNode[] = this.gltfData.nodes ?? [];
    const meshes = this.gltfData.meshes ?? [];
    const scenes = this.gltfData.scenes ?? [];
    const meshById = new Map(meshes.map((m) => [m.id, m] as const));
    const resolveNode = (ref: unknown) => GLTF2.resolveRef(ref, nodes);
    const isNode = (n: GLTFNode | undefined): n is GLTFNode => n !== undefined;

    // glTF is Y-up; the renderer is Z-up.
    const yToZUp = Mat4.fromValues(
      1, 0, 0, 0,
      0, 0, 1, 0,
      0, -1, 0, 0,
      0, 0, 0, 1
    );
    const rootTransform = Mat4.create();
    Mat4.multiply(rootTransform, this.buildSceneTransform(), yToZUp);

    // Root nodes: those listed by the active scene, or every node that is nobody's child.
    const activeScene = GLTF2.resolveRef(this.gltfData.scene, scenes) ?? scenes[0];
    let roots = ((activeScene?.nodes ?? []) as unknown[]).map(resolveNode).filter(isNode);
    if (roots.length === 0) {
      const childNodes = new Set<GLTFNode>();
      for (const node of nodes) {
        for (const ref of (node.children ?? []) as unknown[]) {
          const child = resolveNode(ref);
          if (child) childNodes.add(child);
        }
      }
      roots = nodes.filter((node) => !childNodes.has(node));
    }

    // Scratch for local transforms; each result is consumed before the next call.
    const tMat = Mat4.create();
    const rMat = Mat4.create();
    const sMat = Mat4.create();
    const local = Mat4.create();
    const localTransformOf = (node: GLTFNode): Mat4 => {
      if (node.matrix) {
        Mat4.copy(local, node.matrix as Mat4Like);
        return local;
      }
      Mat4.fromTranslation(tMat, (node.translation || [0, 0, 0]) as Vec3Like);
      Mat4.fromQuat(rMat, (node.rotation || [0, 0, 0, 1]) as QuatLike);
      Mat4.fromScaling(sMat, (node.scale || [1, 1, 1]) as Vec3Like);
      Mat4.multiply(local, rMat, sMat);
      Mat4.multiply(local, tMat, local); // T * R * S
      return local;
    };

    const visited = new Set<GLTFNode>();
    const visit = (node: GLTFNode, parentWorld: Mat4): void => {
      // Guards against cycles or shared children in malformed files.
      if (visited.has(node)) return;
      visited.add(node);

      const world = Mat4.create();
      Mat4.multiply(world, parentWorld, localTransformOf(node));

      const mesh = GLTF2.resolveRef(node.mesh, meshes, meshById);
      mesh?.primitives.forEach((primitive: GLTFMeshPrimitivePostprocessed) => {
        this.extractTriangles(primitive, world, this.triangles);
      });

      for (const ref of (node.children ?? []) as unknown[]) {
        const child = resolveNode(ref);
        if (child) visit(child, world);
      }
    };

    for (const root of roots) visit(root, rootTransform);
  }
}

/**
 * Concatenates the output of several initialized GLTF2 loaders into one scene.
 *
 * Textures and materials are appended in input order. Each GLTF2 numbers its textures and
 * materials from 0, so indices are rebased as they are copied:
 * - material texture indices >= 0 are shifted by the number of textures from earlier GLTFs,
 *   and other values (e.g. -1 for "no texture") pass through unchanged;
 * - triangle material indices >= 0 are shifted by the number of materials from earlier GLTFs,
 *   and anything else becomes -1.
 *
 * The inputs are not modified. Triangles are shallow copies whose fields are all own properties.
 *
 * @param gltfs GLTF2 instances whose `initialize()` has resolved.
 * @returns the combined `triangles`, `materials` and `textures`, plus
 *   `largestTextureDimensions` (the maximum width and maximum height, tracked independently).
 */
export function combineGLTFs(gltfs: GLTF2[]) {
  const triangles: Triangle[] = [];
  const materials: ProcessedMaterial[] = [];
  const textures: ProcessedTexture[] = [];
  const largestTextureDimensions = { width: 0, height: 0 };

  let textureOffset = 0;
  let materialOffset = 0;

  // Rebase a per-GLTF texture index; invalid/negative indices are kept as-is.
  const offsetTextureIndex = (idx: number): number =>
    typeof idx === "number" && idx >= 0 ? idx + textureOffset : idx;

  for (const gltf of gltfs) {
    const gltfTextures = gltf.textures ?? [];
    const gltfMaterials = gltf.materials ?? [];

    // Textures are appended unchanged; the atlas is sized from the largest width/height seen.
    for (const texture of gltfTextures) {
      const texWidth = texture.source.image.width as number;
      const texHeight = texture.source.image.height as number;
      textures.push(texture);
      if (texWidth > largestTextureDimensions.width) {
        largestTextureDimensions.width = texWidth;
      }
      if (texHeight > largestTextureDimensions.height) {
        largestTextureDimensions.height = texHeight;
      }
    }

    for (const src of gltfMaterials) {
      materials.push({
        alphaCutoff: src.alphaCutoff,
        alphaMode: src.alphaMode,
        baseColorFactor: src.baseColorFactor,
        baseColorTexture: offsetTextureIndex(src.baseColorTexture),
        doubleSided: src.doubleSided,
        emissiveFactor: src.emissiveFactor,
        emissiveTexture: offsetTextureIndex(src.emissiveTexture),
        metallicFactor: src.metallicFactor,
        metallicRoughnessTexture: offsetTextureIndex(src.metallicRoughnessTexture),
        normalTexture: offsetTextureIndex(src.normalTexture),
        roughnessFactor: src.roughnessFactor,
      });
    }

    for (const tri of gltf.triangles ?? []) {
      // Spread rather than Object.create(tri): a prototype-backed copy would expose only `mat`
      // as an own property, hiding every other field from Object.keys, spread,
      // JSON.stringify and structuredClone.
      triangles.push({ ...tri, mat: tri.mat >= 0 ? tri.mat + materialOffset : -1 });
    }

    textureOffset += gltfTextures.length;
    materialOffset += gltfMaterials.length;
  }

  return { triangles, materials, textures, largestTextureDimensions };
}