Skip to content

Commit 71097fe

Browse files
committed
feat(🪼): Create root.beginRenderPass API for more flexibility (#852)
1 parent 825e046 commit 71097fe

File tree

6 files changed

+454
-41
lines changed

6 files changed

+454
-41
lines changed

packages/typegpu/src/core/pipeline/renderPipeline.ts

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ export interface TgpuRenderPipeline<Output extends IOLayout = IOLayout>
6363
): void;
6464
}
6565

66+
export interface INTERNAL_TgpuRenderPipeline {
67+
readonly core: RenderPipelineCore;
68+
readonly priors: TgpuRenderPipelinePriors;
69+
}
70+
6671
export type FragmentOutToTargets<T extends IOLayout> = T extends IOData
6772
? GPUColorTargetState
6873
: T extends Record<string, unknown>
@@ -193,6 +198,10 @@ export function INTERNAL_createRenderPipeline(
193198
return new TgpuRenderPipelineImpl(new RenderPipelineCore(options), {});
194199
}
195200

201+
export function isRenderPipeline(value: unknown): value is TgpuRenderPipeline {
202+
return (value as TgpuRenderPipeline)?.resourceType === 'render-pipeline';
203+
}
204+
196205
// --------------
197206
// Implementation
198207
// --------------
@@ -214,20 +223,22 @@ type Memo = {
214223
catchall: [number, TgpuBindGroup] | null;
215224
};
216225

217-
class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
226+
class TgpuRenderPipelineImpl
227+
implements TgpuRenderPipeline, INTERNAL_TgpuRenderPipeline
228+
{
218229
public readonly resourceType = 'render-pipeline';
219230

220231
constructor(
221-
private readonly _core: RenderPipelineCore,
222-
private readonly _priors: TgpuRenderPipelinePriors,
232+
public readonly core: RenderPipelineCore,
233+
public readonly priors: TgpuRenderPipelinePriors,
223234
) {}
224235

225236
get label() {
226-
return this._core.label;
237+
return this.core.label;
227238
}
228239

229240
$name(label?: string | undefined): this {
230-
this._core.label = label;
241+
this.core.label = label;
231242
return this;
232243
}
233244

@@ -244,20 +255,20 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
244255
resource: (TgpuBuffer<AnyWgslData> & Vertex) | TgpuBindGroup,
245256
): TgpuRenderPipeline {
246257
if (isBindGroupLayout(definition)) {
247-
return new TgpuRenderPipelineImpl(this._core, {
248-
...this._priors,
258+
return new TgpuRenderPipelineImpl(this.core, {
259+
...this.priors,
249260
bindGroupLayoutMap: new Map([
250-
...(this._priors.bindGroupLayoutMap ?? []),
261+
...(this.priors.bindGroupLayoutMap ?? []),
251262
[definition, resource as TgpuBindGroup],
252263
]),
253264
});
254265
}
255266

256267
if (isVertexLayout(definition)) {
257-
return new TgpuRenderPipelineImpl(this._core, {
258-
...this._priors,
268+
return new TgpuRenderPipelineImpl(this.core, {
269+
...this.priors,
259270
vertexLayoutMap: new Map([
260-
...(this._priors.vertexLayoutMap ?? []),
271+
...(this.priors.vertexLayoutMap ?? []),
261272
[definition, resource as TgpuBuffer<AnyWgslData> & Vertex],
262273
]),
263274
});
@@ -269,17 +280,17 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
269280
withColorAttachment(
270281
attachment: AnyFragmentColorAttachment,
271282
): TgpuRenderPipeline {
272-
return new TgpuRenderPipelineImpl(this._core, {
273-
...this._priors,
283+
return new TgpuRenderPipelineImpl(this.core, {
284+
...this.priors,
274285
colorAttachment: attachment,
275286
});
276287
}
277288

278289
withDepthStencilAttachment(
279290
attachment: DepthStencilAttachment,
280291
): TgpuRenderPipeline {
281-
return new TgpuRenderPipelineImpl(this._core, {
282-
...this._priors,
292+
return new TgpuRenderPipelineImpl(this.core, {
293+
...this.priors,
283294
depthStencilAttachment: attachment,
284295
});
285296
}
@@ -290,12 +301,12 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
290301
firstVertex?: number,
291302
firstInstance?: number,
292303
): void {
293-
const memo = this._core.unwrap();
294-
const { branch, fragmentFn } = this._core.options;
304+
const memo = this.core.unwrap();
305+
const { branch, fragmentFn } = this.core.options;
295306

296307
const colorAttachments = connectAttachmentToShader(
297308
fragmentFn.shell.targets,
298-
this._priors.colorAttachment ?? {},
309+
this.priors.colorAttachment ?? {},
299310
).map((attachment) => {
300311
if (isTexture(attachment.view)) {
301312
return {
@@ -311,12 +322,12 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
311322
colorAttachments,
312323
};
313324

314-
if (this._core.label !== undefined) {
315-
renderPassDescriptor.label = this._core.label;
325+
if (this.core.label !== undefined) {
326+
renderPassDescriptor.label = this.core.label;
316327
}
317328

318-
if (this._priors.depthStencilAttachment !== undefined) {
319-
const attachment = this._priors.depthStencilAttachment;
329+
if (this.priors.depthStencilAttachment !== undefined) {
330+
const attachment = this.priors.depthStencilAttachment;
320331
if (isTexture(attachment.view)) {
321332
renderPassDescriptor.depthStencilAttachment = {
322333
...attachment,
@@ -340,19 +351,19 @@ class TgpuRenderPipelineImpl implements TgpuRenderPipeline {
340351
pass.setBindGroup(idx, branch.unwrap(memo.catchall[1]));
341352
missingBindGroups.delete(layout);
342353
} else {
343-
const bindGroup = this._priors.bindGroupLayoutMap?.get(layout);
354+
const bindGroup = this.priors.bindGroupLayoutMap?.get(layout);
344355
if (bindGroup !== undefined) {
345356
missingBindGroups.delete(layout);
346357
pass.setBindGroup(idx, branch.unwrap(bindGroup));
347358
}
348359
}
349360
});
350361

351-
const missingVertexLayouts = new Set(this._core.usedVertexLayouts);
362+
const missingVertexLayouts = new Set(this.core.usedVertexLayouts);
352363

353-
const usedVertexLayouts = this._core.usedVertexLayouts;
364+
const usedVertexLayouts = this.core.usedVertexLayouts;
354365
usedVertexLayouts.forEach((vertexLayout, idx) => {
355-
const buffer = this._priors.vertexLayoutMap?.get(vertexLayout);
366+
const buffer = this.priors.vertexLayoutMap?.get(vertexLayout);
356367
if (buffer) {
357368
missingVertexLayouts.delete(vertexLayout);
358369
pass.setVertexBuffer(idx, branch.unwrap(buffer));

packages/typegpu/src/core/root/init.ts

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import type { AnyComputeBuiltin, OmitBuiltins } from '../../builtin';
2-
import type { AnyWgslData } from '../../data';
3-
import type { AnyData } from '../../data/dataTypes';
4-
import { invariant } from '../../errors';
2+
import type { AnyData, Disarray } from '../../data/dataTypes';
3+
import type { AnyWgslData, BaseData, WgslArray } from '../../data/wgslTypes';
4+
import {
5+
MissingBindGroupsError,
6+
MissingVertexBuffersError,
7+
invariant,
8+
} from '../../errors';
59
import type { JitTranspiler } from '../../jitTranspiler';
610
import { WeakMemo } from '../../memo';
711
import {
@@ -25,6 +29,7 @@ import {
2529
import {
2630
INTERNAL_createBuffer,
2731
type TgpuBuffer,
32+
type Vertex,
2833
isBuffer,
2934
} from '../buffer/buffer';
3035
import type {
@@ -47,9 +52,11 @@ import {
4752
} from '../pipeline/computePipeline';
4853
import {
4954
type AnyFragmentTargets,
55+
type INTERNAL_TgpuRenderPipeline,
5056
INTERNAL_createRenderPipeline,
5157
type RenderPipelineCoreOptions,
5258
type TgpuRenderPipeline,
59+
isRenderPipeline,
5360
} from '../pipeline/renderPipeline';
5461
import {
5562
type TgpuAccessor,
@@ -79,6 +86,7 @@ import type {
7986
CreateTextureOptions,
8087
CreateTextureResult,
8188
ExperimentalTgpuRoot,
89+
RenderPass,
8290
TgpuRoot,
8391
WithBinding,
8492
WithCompute,
@@ -325,6 +333,7 @@ class TgpuRootImpl
325333
}
326334

327335
unwrap(resource: TgpuComputePipeline): GPUComputePipeline;
336+
unwrap(resource: TgpuRenderPipeline): GPURenderPipeline;
328337
unwrap(resource: TgpuBindGroupLayout): GPUBindGroupLayout;
329338
unwrap(resource: TgpuBindGroup): GPUBindGroup;
330339
unwrap(resource: TgpuBuffer<AnyData>): GPUBuffer;
@@ -340,6 +349,7 @@ class TgpuRootImpl
340349
unwrap(
341350
resource:
342351
| TgpuComputePipeline
352+
| TgpuRenderPipeline
343353
| TgpuBindGroupLayout
344354
| TgpuBindGroup
345355
| TgpuBuffer<AnyData>
@@ -351,6 +361,7 @@ class TgpuRootImpl
351361
| TgpuVertexLayout,
352362
):
353363
| GPUComputePipeline
364+
| GPURenderPipeline
354365
| GPUBindGroupLayout
355366
| GPUBindGroup
356367
| GPUBuffer
@@ -361,6 +372,11 @@ class TgpuRootImpl
361372
return (resource as unknown as INTERNAL_TgpuComputePipeline).rawPipeline;
362373
}
363374

375+
if (isRenderPipeline(resource)) {
376+
return (resource as unknown as INTERNAL_TgpuRenderPipeline).core.unwrap()
377+
.pipeline;
378+
}
379+
364380
if (isBindGroupLayout(resource)) {
365381
return this._unwrappedBindGroupLayouts.getOrMake(resource);
366382
}
@@ -394,6 +410,160 @@ class TgpuRootImpl
394410
throw new Error(`Unknown resource type: ${resource}`);
395411
}
396412

413+
beginRenderPass(
414+
descriptor: GPURenderPassDescriptor,
415+
callback: (pass: RenderPass) => void,
416+
): void {
417+
const pass = this.commandEncoder.beginRenderPass(descriptor);
418+
419+
const bindGroups = new Map<
420+
TgpuBindGroupLayout,
421+
TgpuBindGroup | GPUBindGroup
422+
>();
423+
const vertexBuffers = new Map<
424+
TgpuVertexLayout,
425+
{
426+
buffer:
427+
| (TgpuBuffer<WgslArray<BaseData> | Disarray<BaseData>> & Vertex)
428+
| GPUBuffer;
429+
offset?: number | undefined;
430+
size?: number | undefined;
431+
}
432+
>();
433+
434+
let currentPipeline:
435+
| (TgpuRenderPipeline & INTERNAL_TgpuRenderPipeline)
436+
| undefined;
437+
438+
const setupPassBeforeDraw = () => {
439+
if (!currentPipeline) {
440+
throw new Error('Cannot draw without a call to pass.setPipeline');
441+
}
442+
443+
const { core, priors } = currentPipeline;
444+
const memo = core.unwrap();
445+
446+
pass.setPipeline(memo.pipeline);
447+
448+
const missingBindGroups = new Set(memo.bindGroupLayouts);
449+
memo.bindGroupLayouts.forEach((layout, idx) => {
450+
if (memo.catchall && idx === memo.catchall[0]) {
451+
// Catch-all
452+
pass.setBindGroup(idx, this.unwrap(memo.catchall[1]));
453+
missingBindGroups.delete(layout);
454+
} else {
455+
const bindGroup =
456+
priors.bindGroupLayoutMap?.get(layout) ?? bindGroups.get(layout);
457+
if (bindGroup !== undefined) {
458+
missingBindGroups.delete(layout);
459+
if (isBindGroup(bindGroup)) {
460+
pass.setBindGroup(idx, this.unwrap(bindGroup));
461+
} else {
462+
pass.setBindGroup(idx, bindGroup);
463+
}
464+
}
465+
}
466+
});
467+
468+
const missingVertexLayouts = new Set<TgpuVertexLayout>();
469+
core.usedVertexLayouts.forEach((vertexLayout, idx) => {
470+
const opts =
471+
{
472+
buffer: priors.vertexLayoutMap?.get(vertexLayout),
473+
offset: undefined,
474+
size: undefined,
475+
} ?? vertexBuffers.get(vertexLayout);
476+
477+
if (!opts || !opts.buffer) {
478+
missingVertexLayouts.add(vertexLayout);
479+
} else if (isBuffer(opts.buffer)) {
480+
pass.setVertexBuffer(
481+
idx,
482+
this.unwrap(opts.buffer),
483+
opts.offset,
484+
opts.size,
485+
);
486+
} else {
487+
pass.setVertexBuffer(idx, opts.buffer, opts.offset, opts.size);
488+
}
489+
});
490+
491+
if (missingBindGroups.size > 0) {
492+
throw new MissingBindGroupsError(missingBindGroups);
493+
}
494+
495+
if (missingVertexLayouts.size > 0) {
496+
throw new MissingVertexBuffersError(missingVertexLayouts);
497+
}
498+
};
499+
500+
callback({
501+
setViewport(...args) {
502+
pass.setViewport(...args);
503+
},
504+
setScissorRect(...args) {
505+
pass.setScissorRect(...args);
506+
},
507+
setBlendConstant(...args) {
508+
pass.setBlendConstant(...args);
509+
},
510+
setStencilReference(...args) {
511+
pass.setStencilReference(...args);
512+
},
513+
beginOcclusionQuery(...args) {
514+
pass.beginOcclusionQuery(...args);
515+
},
516+
endOcclusionQuery(...args) {
517+
pass.endOcclusionQuery(...args);
518+
},
519+
executeBundles(...args) {
520+
pass.executeBundles(...args);
521+
},
522+
setPipeline(pipeline) {
523+
currentPipeline = pipeline as TgpuRenderPipeline &
524+
INTERNAL_TgpuRenderPipeline;
525+
},
526+
527+
setIndexBuffer: (buffer, indexFormat, offset, size) => {
528+
if (isBuffer(buffer)) {
529+
pass.setIndexBuffer(this.unwrap(buffer), indexFormat, offset, size);
530+
} else {
531+
pass.setIndexBuffer(buffer, indexFormat, offset, size);
532+
}
533+
},
534+
535+
setVertexBuffer(vertexLayout, buffer, offset, size) {
536+
vertexBuffers.set(vertexLayout, { buffer, offset, size });
537+
},
538+
539+
setBindGroup(bindGroupLayout, bindGroup) {
540+
bindGroups.set(bindGroupLayout, bindGroup);
541+
},
542+
543+
draw(vertexCount, instanceCount, firstVertex, firstInstance) {
544+
setupPassBeforeDraw();
545+
pass.draw(vertexCount, instanceCount, firstVertex, firstInstance);
546+
},
547+
548+
drawIndexed(...args) {
549+
setupPassBeforeDraw();
550+
pass.drawIndexed(...args);
551+
},
552+
553+
drawIndirect(...args) {
554+
setupPassBeforeDraw();
555+
pass.drawIndirect(...args);
556+
},
557+
558+
drawIndexedIndirect(...args) {
559+
setupPassBeforeDraw();
560+
pass.drawIndexedIndirect(...args);
561+
},
562+
});
563+
564+
pass.end();
565+
}
566+
397567
flush() {
398568
if (!this._commandEncoder) {
399569
return;

0 commit comments

Comments
 (0)