reflect(wgsl): implement WGSL reflection

This commit is contained in:
chyyran 2023-12-13 02:30:01 -05:00 committed by Ronny Chan
parent 4dfcdf2725
commit 171c842c97
6 changed files with 903 additions and 50 deletions

8
.idea/.gitignore vendored Normal file
View file

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View file

@ -12,7 +12,15 @@
<cargoProject FILE="$PROJECT_DIR$/Cargo.toml" /> <cargoProject FILE="$PROJECT_DIR$/Cargo.toml" />
</component> </component>
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="02471831-07cd-4975-a00c-e042450023a1" name="Changes" comment="rt(wgpu): load shaders" /> <list default="true" id="02471831-07cd-4975-a00c-e042450023a1" name="Changes" comment="rt(wgpu): load shaders">
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/librashader-reflect/src/back/wgsl/mod.rs" beforeDir="false" afterPath="$PROJECT_DIR$/librashader-reflect/src/back/wgsl/mod.rs" afterDir="false" />
<change beforePath="$PROJECT_DIR$/librashader-reflect/src/error.rs" beforeDir="false" afterPath="$PROJECT_DIR$/librashader-reflect/src/error.rs" afterDir="false" />
<change beforePath="$PROJECT_DIR$/librashader-reflect/src/reflect/cross.rs" beforeDir="false" afterPath="$PROJECT_DIR$/librashader-reflect/src/reflect/cross.rs" afterDir="false" />
<change beforePath="$PROJECT_DIR$/librashader-reflect/src/reflect/naga.rs" beforeDir="false" afterPath="$PROJECT_DIR$/librashader-reflect/src/reflect/naga.rs" afterDir="false" />
<change beforePath="$PROJECT_DIR$/test/shaders_slang" beforeDir="false" afterPath="$PROJECT_DIR$/test/shaders_slang" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" /> <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
@ -29,7 +37,14 @@
</option> </option>
</component> </component>
<component name="Git.Settings"> <component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/test/shaders_slang" /> <option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="feat-wgpu-runtime" />
<entry key="$PROJECT_DIR$/test/shaders_slang" value="a6e11453ad8c62931c62eeb79d51c70887b40bba" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="ROOT_SYNC" value="DONT_SYNC" />
</component> </component>
<component name="MacroExpansionManager"> <component name="MacroExpansionManager">
<option name="directoryName" value="z2x7agij" /> <option name="directoryName" value="z2x7agij" />
@ -63,7 +78,7 @@
"RunOnceActivity.readMode.enableVisualFormatting": "true", "RunOnceActivity.readMode.enableVisualFormatting": "true",
"cf.first.check.clang-format": "false", "cf.first.check.clang-format": "false",
"cidr.known.project.marker": "true", "cidr.known.project.marker": "true",
"git-widget-placeholder": "a6e11453", "git-widget-placeholder": "86a88bb2",
"last_opened_file_path": "F:/coding/librashader", "last_opened_file_path": "F:/coding/librashader",
"node.js.detected.package.eslint": "true", "node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true", "node.js.detected.package.tslint": "true",
@ -213,10 +228,10 @@
<recent_temporary> <recent_temporary>
<list> <list>
<item itemvalue="Cargo.Test back::wgsl::test::test_into" /> <item itemvalue="Cargo.Test back::wgsl::test::test_into" />
<item itemvalue="Cargo.Test reflect::cross::test::test_into" /> <item itemvalue="Cargo.Test triangle_wgpu" />
<item itemvalue="Cargo.Test front::naga::test::naga_playground" /> <item itemvalue="Cargo.Test front::naga::test::naga_playground" />
<item itemvalue="Cargo.Test front::naga::test::naga_playground (1)" /> <item itemvalue="Cargo.Test front::naga::test::naga_playground (1)" />
<item itemvalue="Cargo.Test triangle_wgpu" /> <item itemvalue="Cargo.Test reflect::cross::test::test_into" />
</list> </list>
</recent_temporary> </recent_temporary>
</component> </component>
@ -246,6 +261,7 @@
<workItem from="1702100458574" duration="619000" /> <workItem from="1702100458574" duration="619000" />
<workItem from="1702163869988" duration="618000" /> <workItem from="1702163869988" duration="618000" />
<workItem from="1702419933111" duration="5007000" /> <workItem from="1702419933111" duration="5007000" />
<workItem from="1702425796345" duration="12240000" />
</task> </task>
<task id="LOCAL-00001" summary="rt(wgpu): basic triangle example"> <task id="LOCAL-00001" summary="rt(wgpu): basic triangle example">
<option name="closed" value="true" /> <option name="closed" value="true" />

View file

@ -6,11 +6,10 @@ use crate::back::{CompileShader, CompilerBackend, FromCompilation, ShaderCompile
use crate::error::{ShaderCompileError, ShaderReflectError}; use crate::error::{ShaderCompileError, ShaderReflectError};
use crate::front::GlslangCompilation; use crate::front::GlslangCompilation;
use crate::reflect::naga::NagaReflect; use crate::reflect::naga::NagaReflect;
use crate::reflect::semantics::ShaderSemantics; use crate::reflect::ReflectShader;
use crate::reflect::{ReflectShader, ShaderReflection};
use naga::back::wgsl::WriterFlags; use naga::back::wgsl::WriterFlags;
use naga::valid::{Capabilities, ValidationFlags}; use naga::valid::{Capabilities, ValidationFlags};
use naga::Module; use naga::{AddressSpace, Module};
use rspirv::binary::Assemble; use rspirv::binary::Assemble;
use rspirv::dr::Builder; use rspirv::dr::Builder;
@ -20,9 +19,16 @@ pub struct NagaWgslContext {
pub vertex: Module, pub vertex: Module,
} }
/// Compiler options for WGSL
#[derive(Debug, Default, Clone)]
pub struct WgslCompileOptions {
pub write_pcb_as_ubo: bool,
pub sampler_bind_group: u32,
}
impl FromCompilation<GlslangCompilation> for WGSL { impl FromCompilation<GlslangCompilation> for WGSL {
type Target = WGSL; type Target = WGSL;
type Options = Option<()>; type Options = WgslCompileOptions;
type Context = NagaWgslContext; type Context = NagaWgslContext;
type Output = impl CompileShader<Self::Target, Options = Self::Options, Context = Self::Context> type Output = impl CompileShader<Self::Target, Options = Self::Options, Context = Self::Context>
+ ReflectShader; + ReflectShader;
@ -65,12 +71,12 @@ impl FromCompilation<GlslangCompilation> for WGSL {
} }
impl CompileShader<WGSL> for NagaReflect { impl CompileShader<WGSL> for NagaReflect {
type Options = Option<()>; type Options = WgslCompileOptions;
type Context = NagaWgslContext; type Context = NagaWgslContext;
fn compile( fn compile(
mut self, mut self,
_options: Self::Options, options: Self::Options,
) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> { ) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> {
fn write_wgsl(module: &Module) -> Result<String, ShaderCompileError> { fn write_wgsl(module: &Module) -> Result<String, ShaderCompileError> {
let mut valid = let mut valid =
@ -81,6 +87,19 @@ impl CompileShader<WGSL> for NagaReflect {
Ok(wgsl) Ok(wgsl)
} }
if options.write_pcb_as_ubo {
for (_, gv) in self.fragment.global_variables.iter_mut() {
if gv.space == AddressSpace::PushConstant {
gv.space = AddressSpace::Uniform;
}
}
for (_, gv) in self.vertex.global_variables.iter_mut() {
if gv.space == AddressSpace::PushConstant {
gv.space = AddressSpace::Uniform;
}
}
}
// Reassign shit. // Reassign shit.
let images = self let images = self
.fragment .fragment
@ -117,7 +136,7 @@ impl CompileShader<WGSL> for NagaReflect {
.for_each(|(_, gv)| { .for_each(|(_, gv)| {
if images.contains(&(gv.binding.clone(), gv.space)) { if images.contains(&(gv.binding.clone(), gv.space)) {
if let Some(binding) = &mut gv.binding { if let Some(binding) = &mut gv.binding {
binding.group = 1; binding.group = options.sampler_bind_group;
} }
} }
}); });
@ -138,21 +157,54 @@ impl CompileShader<WGSL> for NagaReflect {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::back::targets::WGSL; use crate::back::targets::WGSL;
use crate::back::wgsl::WgslCompileOptions;
use crate::back::{CompileShader, FromCompilation}; use crate::back::{CompileShader, FromCompilation};
use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
use crate::reflect::ReflectShader;
use librashader_preprocess::ShaderSource; use librashader_preprocess::ShaderSource;
use rustc_hash::FxHashMap;
#[test] #[test]
pub fn test_into() { pub fn test_into() {
let result = ShaderSource::load("../test/shaders_slang/crt/shaders/crt-royale/src/crt-royale-scanlines-horizontal-apply-mask.slang").unwrap(); // let result = ShaderSource::load("../test/shaders_slang/crt/shaders/crt-royale/src/crt-royale-scanlines-horizontal-apply-mask.slang").unwrap();
// let result = ShaderSource::load("../test/shaders_slang/crt/shaders/crt-royale/src/crt-royale-scanlines-horizontal-apply-mask.slang").unwrap();
let result = ShaderSource::load("../test/basic.slang").unwrap();
let mut uniform_semantics: FxHashMap<String, UniformSemantic> = Default::default();
for (_index, param) in result.parameters.iter().enumerate() {
uniform_semantics.insert(
param.1.id.clone(),
UniformSemantic::Unique(Semantic {
semantics: UniqueSemantics::FloatParameter,
index: (),
}),
);
}
let compilation = crate::front::GlslangCompilation::try_from(&result).unwrap(); let compilation = crate::front::GlslangCompilation::try_from(&result).unwrap();
let wgsl = WGSL::from_compilation(compilation).unwrap(); let mut wgsl = WGSL::from_compilation(compilation).unwrap();
let compiled = wgsl.compile(None).unwrap(); wgsl.reflect(
0,
&ShaderSemantics {
uniform_semantics,
texture_semantics: Default::default(),
},
)
.expect("");
let compiled = wgsl
.compile(WgslCompileOptions {
write_pcb_as_ubo: false,
sampler_bind_group: 1,
})
.unwrap();
println!("{}", compiled.vertex); println!("{}", compiled.vertex);
println!("{}", compiled.fragment); // println!("{}", compiled.fragment);
// let mut loader = rspirv::dr::Loader::new(); // let mut loader = rspirv::dr::Loader::new();
// rspirv::binary::parse_words(compilation.vertex.as_binary(), &mut loader).unwrap(); // rspirv::binary::parse_words(compilation.vertex.as_binary(), &mut loader).unwrap();
// let module = loader.module(); // let module = loader.module();

View file

@ -43,6 +43,8 @@ pub enum SemanticsErrorKind {
/// The number of uniform buffers was invalid. Only one UBO is permitted. /// The number of uniform buffers was invalid. Only one UBO is permitted.
InvalidUniformBufferCount(usize), InvalidUniformBufferCount(usize),
/// The number of push constant blocks was invalid. Only one push constant block is permitted. /// The number of push constant blocks was invalid. Only one push constant block is permitted.
InvalidPushBufferCount(usize),
/// The size of the push constant block was invalid.
InvalidPushBufferSize(u32), InvalidPushBufferSize(u32),
/// The location of a varying was invalid. /// The location of a varying was invalid.
InvalidLocation(u32), InvalidLocation(u32),
@ -52,12 +54,16 @@ pub enum SemanticsErrorKind {
InvalidInputCount(usize), InvalidInputCount(usize),
/// The number of outputs declared was invalid. /// The number of outputs declared was invalid.
InvalidOutputCount(usize), InvalidOutputCount(usize),
/// Expected a binding but there was none.
MissingBinding,
/// The declared binding point was invalid. /// The declared binding point was invalid.
InvalidBinding(u32), InvalidBinding(u32),
/// The declared resource type was invalid. /// The declared resource type was invalid.
InvalidResourceType, InvalidResourceType,
/// The range of a struct member was invalid. /// The range of a struct member was invalid.
InvalidRange(u32), InvalidRange(u32),
/// The number of entry points in the shader was invalid.
InvalidEntryPointCount(usize),
/// The requested uniform or texture name was not provided semantics. /// The requested uniform or texture name was not provided semantics.
UnknownSemantics(String), UnknownSemantics(String),
/// The type of the requested uniform was not compatible with the provided semantics. /// The type of the requested uniform was not compatible with the provided semantics.
@ -104,7 +110,6 @@ pub enum ShaderReflectError {
/// The binding number is already in use. /// The binding number is already in use.
#[error("the binding is already in use")] #[error("the binding is already in use")]
BindingInUse(u32), BindingInUse(u32),
/// Error when transpiling from naga /// Error when transpiling from naga
#[cfg(feature = "wgsl")] #[cfg(feature = "wgsl")]
#[error("naga-spv")] #[error("naga-spv")]

View file

@ -208,6 +208,7 @@ where
)); ));
} }
// Ensure that vertex attributes use location 0 and 1
let vert_mask = vertex_res.stage_inputs.iter().try_fold(0, |mask, input| { let vert_mask = vertex_res.stage_inputs.iter().try_fold(0, |mask, input| {
Ok::<u32, ErrorCode>( Ok::<u32, ErrorCode>(
mask | 1 << self.vertex.get_decoration(input.id, Decoration::Location)?, mask | 1 << self.vertex.get_decoration(input.id, Decoration::Location)?,
@ -227,9 +228,7 @@ where
if vertex_res.push_constant_buffers.len() > 1 { if vertex_res.push_constant_buffers.len() > 1 {
return Err(ShaderReflectError::VertexSemanticError( return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidUniformBufferCount( SemanticsErrorKind::InvalidPushBufferCount(vertex_res.push_constant_buffers.len()),
vertex_res.push_constant_buffers.len(),
),
)); ));
} }
@ -241,7 +240,7 @@ where
if fragment_res.push_constant_buffers.len() > 1 { if fragment_res.push_constant_buffers.len() > 1 {
return Err(ShaderReflectError::FragmentSemanticError( return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidUniformBufferCount( SemanticsErrorKind::InvalidPushBufferCount(
fragment_res.push_constant_buffers.len(), fragment_res.push_constant_buffers.len(),
), ),
)); ));

View file

@ -1,9 +1,18 @@
use crate::error::ShaderReflectError; use crate::error::{SemanticsErrorKind, ShaderReflectError};
use naga::Module; use naga::{
AddressSpace, Binding, GlobalVariable, Handle, ImageClass, Module, ResourceBinding, ScalarKind,
TypeInner, VectorSize,
};
use crate::reflect::semantics::ShaderSemantics; use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
use crate::reflect::{ReflectShader, ShaderReflection}; use crate::reflect::semantics::{
BindingMeta, BindingStage, MemberOffset, PushReflection, ShaderSemantics, TextureBinding,
TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UboReflection,
UniformMemberBlock, UniqueSemanticMap, UniqueSemantics, ValidateTypeSemantics, VariableMeta,
MAX_BINDINGS_COUNT, MAX_PUSH_BUFFER_SIZE,
};
use crate::reflect::{align_uniform_size, ReflectShader, ShaderReflection};
#[derive(Debug)] #[derive(Debug)]
pub struct NagaReflect { pub struct NagaReflect {
@ -11,33 +20,668 @@ pub struct NagaReflect {
pub(crate) fragment: Module, pub(crate) fragment: Module,
} }
// impl ValidateTypeSemantics<&TypeInner> for UniqueSemantics {
// struct UboData { fn validate_type(&self, ty: &&TypeInner) -> Option<TypeInfo> {
// // id: u32, let (TypeInner::Vector { .. } | TypeInner::Scalar { .. } | TypeInner::Matrix { .. }) = *ty
// // descriptor_set: u32, else {
// binding: u32, return None;
// size: u32, };
// }
// match self {
// struct Ubo { UniqueSemantics::MVP => {
// members: Vec<StructMember>, if matches!(ty, TypeInner::Matrix { columns, rows, width } if *columns == VectorSize::Quad
// span: u32, && *rows == VectorSize::Quad && *width == 4)
// } {
// return Some(TypeInfo {
// impl TryFrom<naga::Type> for Ubo { size: 4,
// type Error = Infallible; columns: 4,
// });
// fn try_from(value: Type) -> Result<Self, Infallible> { }
// match value.inner { }
// TypeInner::Struct { members, span } => Ok(Ubo { members, span }), UniqueSemantics::FrameCount => {
// // todo: make this programmer error // Uint32 == width 4
// _ => panic!(), if matches!(ty, TypeInner::Scalar { kind, width } if *kind == ScalarKind::Uint && *width == 4)
// } {
// } return Some(TypeInfo {
// } size: 1,
columns: 1,
});
}
}
UniqueSemantics::FrameDirection => {
// Uint32 == width 4
if matches!(ty, TypeInner::Scalar { kind, width } if *kind == ScalarKind::Sint && *width == 4)
{
return Some(TypeInfo {
size: 1,
columns: 1,
});
}
}
UniqueSemantics::FloatParameter => {
// Float32 == width 4
if matches!(ty, TypeInner::Scalar { kind, width } if *kind == ScalarKind::Float && *width == 4)
{
return Some(TypeInfo {
size: 1,
columns: 1,
});
}
}
_ => {
if matches!(ty, TypeInner::Vector { kind, width, size } if *kind == ScalarKind::Float && *width == 4 && *size == VectorSize::Quad)
{
return Some(TypeInfo {
size: 4,
columns: 1,
});
}
}
};
return None;
}
}
impl ValidateTypeSemantics<&TypeInner> for TextureSemantics {
fn validate_type(&self, ty: &&TypeInner) -> Option<TypeInfo> {
let TypeInner::Vector { size, kind, width } = ty else {
return None;
};
if *kind == ScalarKind::Float && *width == 4 && *size == VectorSize::Quad {
return Some(TypeInfo {
size: 4,
columns: 1,
});
}
None
}
}
impl NagaReflect { impl NagaReflect {
fn reflect_ubos(
&mut self,
vertex_ubo: Option<Handle<GlobalVariable>>,
fragment_ubo: Option<Handle<GlobalVariable>>,
) -> Result<Option<UboReflection>, ShaderReflectError> {
if let Some(vertex_ubo) = vertex_ubo {
let ubo = &mut self.vertex.global_variables[vertex_ubo];
ubo.binding = Some(ResourceBinding {
group: 0,
binding: 0,
})
}
if let Some(fragment_ubo) = fragment_ubo {
let ubo = &mut self.fragment.global_variables[fragment_ubo];
ubo.binding = Some(ResourceBinding {
group: 0,
binding: 0,
})
}
// todo: merge this with the spirv-cross code
match (vertex_ubo, fragment_ubo) {
(None, None) => Ok(None),
(Some(vertex_ubo), Some(fragment_ubo)) => {
let vertex_ubo = Self::get_ubo_data(
&self.vertex,
&self.vertex.global_variables[vertex_ubo],
SemanticErrorBlame::Vertex,
)?;
let fragment_ubo = Self::get_ubo_data(
&self.fragment,
&self.fragment.global_variables[fragment_ubo],
SemanticErrorBlame::Fragment,
)?;
if vertex_ubo.binding != fragment_ubo.binding {
return Err(ShaderReflectError::MismatchedUniformBuffer {
vertex: vertex_ubo.binding,
fragment: fragment_ubo.binding,
});
}
let size = std::cmp::max(vertex_ubo.size, fragment_ubo.size);
Ok(Some(UboReflection {
binding: vertex_ubo.binding,
size: align_uniform_size(size),
stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
}))
}
(Some(vertex_ubo), None) => {
let vertex_ubo = Self::get_ubo_data(
&self.vertex,
&self.vertex.global_variables[vertex_ubo],
SemanticErrorBlame::Vertex,
)?;
Ok(Some(UboReflection {
binding: vertex_ubo.binding,
size: align_uniform_size(vertex_ubo.size),
stage_mask: BindingStage::VERTEX,
}))
}
(None, Some(fragment_ubo)) => {
let fragment_ubo = Self::get_ubo_data(
&self.fragment,
&self.fragment.global_variables[fragment_ubo],
SemanticErrorBlame::Fragment,
)?;
Ok(Some(UboReflection {
binding: fragment_ubo.binding,
size: align_uniform_size(fragment_ubo.size),
stage_mask: BindingStage::FRAGMENT,
}))
}
}
}
fn get_ubo_data(
module: &Module,
ubo: &GlobalVariable,
blame: SemanticErrorBlame,
) -> Result<UboData, ShaderReflectError> {
let Some(binding) = &ubo.binding else {
return Err(blame.error(SemanticsErrorKind::MissingBinding));
};
if binding.binding >= MAX_BINDINGS_COUNT {
return Err(blame.error(SemanticsErrorKind::InvalidBinding(binding.binding)));
}
if binding.group != 0 {
return Err(blame.error(SemanticsErrorKind::InvalidDescriptorSet(binding.group)));
}
let ty = &module.types[ubo.ty];
let size = ty.inner.size(module.to_ctx());
Ok(UboData {
// descriptor_set,
// id: ubo.id,
binding: binding.binding,
size,
})
}
fn get_push_size(
module: &Module,
push: &GlobalVariable,
blame: SemanticErrorBlame,
) -> Result<u32, ShaderReflectError> {
let ty = &module.types[push.ty];
let size = ty.inner.size(module.to_ctx());
if size > MAX_PUSH_BUFFER_SIZE {
return Err(blame.error(SemanticsErrorKind::InvalidPushBufferSize(size)));
}
Ok(size)
}
fn reflect_push_constant_buffer(
&mut self,
vertex_pcb: Option<Handle<GlobalVariable>>,
fragment_pcb: Option<Handle<GlobalVariable>>,
) -> Result<Option<PushReflection>, ShaderReflectError> {
// Reassign to UBO later if we want during compilation.
if let Some(vertex_pcb) = vertex_pcb {
let ubo = &mut self.vertex.global_variables[vertex_pcb];
ubo.binding = Some(ResourceBinding {
group: 0,
binding: 1,
});
}
if let Some(fragment_pcb) = fragment_pcb {
let ubo = &mut self.fragment.global_variables[fragment_pcb];
ubo.binding = Some(ResourceBinding {
group: 0,
binding: 1,
});
};
match (vertex_pcb, fragment_pcb) {
(None, None) => Ok(None),
(Some(vertex_push), Some(fragment_push)) => {
let vertex_size = Self::get_push_size(
&self.vertex,
&self.vertex.global_variables[vertex_push],
SemanticErrorBlame::Vertex,
)?;
let fragment_size = Self::get_push_size(
&self.fragment,
&self.fragment.global_variables[fragment_push],
SemanticErrorBlame::Fragment,
)?;
let size = std::cmp::max(vertex_size, fragment_size);
Ok(Some(PushReflection {
size: align_uniform_size(size),
stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
}))
}
(Some(vertex_push), None) => {
let vertex_size = Self::get_push_size(
&self.vertex,
&self.vertex.global_variables[vertex_push],
SemanticErrorBlame::Vertex,
)?;
Ok(Some(PushReflection {
size: align_uniform_size(vertex_size),
stage_mask: BindingStage::VERTEX,
}))
}
(None, Some(fragment_push)) => {
let fragment_size = Self::get_push_size(
&self.fragment,
&self.fragment.global_variables[fragment_push],
SemanticErrorBlame::Fragment,
)?;
Ok(Some(PushReflection {
size: align_uniform_size(fragment_size),
stage_mask: BindingStage::FRAGMENT,
}))
}
}
}
fn validate(&self) -> Result<(), ShaderReflectError> {
// Verify types
if self.vertex.global_variables.iter().any(|(_, gv)| {
let ty = &self.vertex.types[gv.ty];
match ty.inner {
TypeInner::Scalar { .. }
| TypeInner::Vector { .. }
| TypeInner::Matrix { .. }
| TypeInner::Struct { .. } => false,
_ => true,
}
}) {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidResourceType,
));
}
if self.fragment.global_variables.iter().any(|(_, gv)| {
let ty = &self.fragment.types[gv.ty];
match ty.inner {
TypeInner::Scalar { .. }
| TypeInner::Vector { .. }
| TypeInner::Matrix { .. }
| TypeInner::Struct { .. }
| TypeInner::Image { .. }
| TypeInner::Sampler { .. } => false,
TypeInner::BindingArray { base, .. } => {
let ty = &self.fragment.types[base];
match ty.inner {
TypeInner::Image { class, .. }
if !matches!(class, ImageClass::Storage { .. }) =>
{
false
}
TypeInner::Sampler { .. } => false,
_ => true,
}
}
_ => true,
}
}) {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidResourceType,
));
}
// Verify Vertex inputs
'vertex: {
if self.vertex.entry_points.len() != 1 {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidEntryPointCount(self.vertex.entry_points.len()),
));
}
let vertex_entry_point = &self.vertex.entry_points[0];
let vert_inputs = vertex_entry_point.function.arguments.len();
if vert_inputs != 2 {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidInputCount(vert_inputs),
));
}
for input in &vertex_entry_point.function.arguments {
let &Some(Binding::Location { location, .. }) = &input.binding else {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::MissingBinding,
));
};
if location == 0 {
let pos_type = &self.vertex.types[input.ty];
if !matches!(pos_type.inner, TypeInner::Vector { size, ..} if size == VectorSize::Quad)
{
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidLocation(location),
));
}
break 'vertex;
}
if location == 1 {
let coord_type = &self.vertex.types[input.ty];
if !matches!(coord_type.inner, TypeInner::Vector { size, ..} if size == VectorSize::Bi)
{
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidLocation(location),
));
}
break 'vertex;
}
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidLocation(location),
));
}
let uniform_buffer_count = self
.vertex
.global_variables
.iter()
.filter(|(_, gv)| gv.space == AddressSpace::Uniform)
.count();
if uniform_buffer_count > 1 {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidUniformBufferCount(uniform_buffer_count),
));
}
let push_buffer_count = self
.vertex
.global_variables
.iter()
.filter(|(_, gv)| gv.space == AddressSpace::PushConstant)
.count();
if push_buffer_count > 1 {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::InvalidPushBufferCount(push_buffer_count),
));
}
}
{
if self.fragment.entry_points.len() != 1 {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidEntryPointCount(self.vertex.entry_points.len()),
));
}
let frag_entry_point = &self.fragment.entry_points[0];
let Some(frag_output) = &frag_entry_point.function.result else {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidOutputCount(0),
));
};
let &Some(Binding::Location { location, .. }) = &frag_output.binding else {
return Err(ShaderReflectError::VertexSemanticError(
SemanticsErrorKind::MissingBinding,
));
};
if location != 0 {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidLocation(location),
));
}
let uniform_buffer_count = self
.fragment
.global_variables
.iter()
.filter(|(_, gv)| gv.space == AddressSpace::Uniform)
.count();
if uniform_buffer_count > 1 {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidUniformBufferCount(uniform_buffer_count),
));
}
let push_buffer_count = self
.fragment
.global_variables
.iter()
.filter(|(_, gv)| gv.space == AddressSpace::PushConstant)
.count();
if push_buffer_count > 1 {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidPushBufferCount(push_buffer_count),
));
}
}
Ok(())
}
fn reflect_buffer_struct_members(
module: &Module,
resource: Handle<GlobalVariable>,
pass_number: usize,
semantics: &ShaderSemantics,
meta: &mut BindingMeta,
offset_type: UniformMemberBlock,
blame: SemanticErrorBlame,
) -> Result<(), ShaderReflectError> {
let resource = &module.global_variables[resource];
let TypeInner::Struct { members, .. } = &module.types[resource.ty].inner else {
return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
};
for member in members {
let Some(name) = member.name.clone() else {
return Err(blame.error(SemanticsErrorKind::InvalidRange(member.offset)));
};
let member_type = &module.types[member.ty].inner;
if let Some(parameter) = semantics.uniform_semantics.get_unique_semantic(&name) {
let Some(typeinfo) = parameter.semantics.validate_type(&member_type) else {
return Err(blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name)));
};
match &parameter.semantics {
UniqueSemantics::FloatParameter => {
let offset = member.offset;
if let Some(meta) = meta.parameter_meta.get_mut(&name) {
if let Some(expected) = meta.offset.offset(offset_type)
&& expected != offset as usize
{
return Err(ShaderReflectError::MismatchedOffset {
semantic: name,
expected,
received: offset as usize,
ty: offset_type,
pass: pass_number,
});
}
if meta.size != typeinfo.size {
return Err(ShaderReflectError::MismatchedSize {
semantic: name,
vertex: meta.size,
fragment: typeinfo.size,
pass: pass_number,
});
}
*meta.offset.offset_mut(offset_type) = Some(offset as usize);
} else {
meta.parameter_meta.insert(
name.clone(),
VariableMeta {
id: name,
offset: MemberOffset::new(offset as usize, offset_type),
size: typeinfo.size,
},
);
}
}
semantics => {
let offset = member.offset;
if let Some(meta) = meta.unique_meta.get_mut(semantics) {
if let Some(expected) = meta.offset.offset(offset_type)
&& expected != offset as usize
{
return Err(ShaderReflectError::MismatchedOffset {
semantic: name,
expected,
received: offset as usize,
ty: offset_type,
pass: pass_number,
});
}
if meta.size != typeinfo.size * typeinfo.columns {
return Err(ShaderReflectError::MismatchedSize {
semantic: name,
vertex: meta.size,
fragment: typeinfo.size,
pass: pass_number,
});
}
*meta.offset.offset_mut(offset_type) = Some(offset as usize);
} else {
meta.unique_meta.insert(
*semantics,
VariableMeta {
id: name,
offset: MemberOffset::new(offset as usize, offset_type),
size: typeinfo.size * typeinfo.columns,
},
);
}
}
}
} else if let Some(texture) = semantics.uniform_semantics.get_texture_semantic(&name) {
let Some(_typeinfo) = texture.semantics.validate_type(&member_type) else {
return Err(blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name)));
};
if let TextureSemantics::PassOutput = texture.semantics {
if texture.index >= pass_number {
return Err(ShaderReflectError::NonCausalFilterChain {
pass: pass_number,
target: texture.index,
});
}
}
let offset = member.offset;
if let Some(meta) = meta.texture_size_meta.get_mut(&texture) {
if let Some(expected) = meta.offset.offset(offset_type)
&& expected != offset as usize
{
return Err(ShaderReflectError::MismatchedOffset {
semantic: name,
expected,
received: offset as usize,
ty: offset_type,
pass: pass_number,
});
}
meta.stage_mask.insert(match blame {
SemanticErrorBlame::Vertex => BindingStage::VERTEX,
SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
});
*meta.offset.offset_mut(offset_type) = Some(offset as usize);
} else {
meta.texture_size_meta.insert(
texture,
TextureSizeMeta {
offset: MemberOffset::new(offset as usize, offset_type),
stage_mask: match blame {
SemanticErrorBlame::Vertex => BindingStage::VERTEX,
SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
},
id: name,
},
);
}
} else {
return Err(blame.error(SemanticsErrorKind::UnknownSemantics(name)));
}
}
Ok(())
}
fn reflect_texture<'a>(
&'a self,
texture: &'a GlobalVariable,
) -> Result<TextureData<'a>, ShaderReflectError> {
let Some(binding) = &texture.binding else {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::MissingBinding,
));
};
let Some(name) = texture.name.as_ref() else {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidBinding(binding.binding),
));
};
if binding.group != 0 {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidDescriptorSet(binding.group),
));
}
if binding.binding >= MAX_BINDINGS_COUNT {
return Err(ShaderReflectError::FragmentSemanticError(
SemanticsErrorKind::InvalidBinding(binding.binding),
));
}
Ok(TextureData {
// id: texture.id,
// descriptor_set,
name: &name,
binding: binding.binding,
})
}
// todo: share this with cross
fn reflect_texture_metas(
&self,
texture: TextureData,
pass_number: usize,
semantics: &ShaderSemantics,
meta: &mut BindingMeta,
) -> Result<(), ShaderReflectError> {
let Some(semantic) = semantics
.texture_semantics
.get_texture_semantic(texture.name)
else {
return Err(
SemanticErrorBlame::Fragment.error(SemanticsErrorKind::UnknownSemantics(
texture.name.to_string(),
)),
);
};
if semantic.semantics == TextureSemantics::PassOutput && semantic.index >= pass_number {
return Err(ShaderReflectError::NonCausalFilterChain {
pass: pass_number,
target: semantic.index,
});
}
meta.texture_meta.insert(
semantic,
TextureBinding {
binding: texture.binding,
},
);
Ok(())
}
} }
impl ReflectShader for NagaReflect { impl ReflectShader for NagaReflect {
@ -46,7 +690,136 @@ impl ReflectShader for NagaReflect {
pass_number: usize, pass_number: usize,
semantics: &ShaderSemantics, semantics: &ShaderSemantics,
) -> Result<ShaderReflection, ShaderReflectError> { ) -> Result<ShaderReflection, ShaderReflectError> {
todo!() self.validate()?;
// Validate verifies that there's only one uniform block.
let vertex_ubo = self
.vertex
.global_variables
.iter()
.find_map(|(handle, gv)| {
if gv.space == AddressSpace::Uniform {
Some(handle)
} else {
None
}
});
let fragment_ubo = self
.fragment
.global_variables
.iter()
.find_map(|(handle, gv)| {
if gv.space == AddressSpace::Uniform {
Some(handle)
} else {
None
}
});
let ubo = self.reflect_ubos(vertex_ubo, fragment_ubo)?;
let vertex_push = self
.vertex
.global_variables
.iter()
.find_map(|(handle, gv)| {
if gv.space == AddressSpace::PushConstant {
Some(handle)
} else {
None
}
});
let fragment_push = self
.fragment
.global_variables
.iter()
.find_map(|(handle, gv)| {
if gv.space == AddressSpace::PushConstant {
Some(handle)
} else {
None
}
});
let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
let mut meta = BindingMeta::default();
if let Some(ubo) = vertex_ubo {
Self::reflect_buffer_struct_members(
&self.vertex,
ubo,
pass_number,
semantics,
&mut meta,
UniformMemberBlock::Ubo,
SemanticErrorBlame::Vertex,
)?;
}
if let Some(ubo) = fragment_ubo {
Self::reflect_buffer_struct_members(
&self.fragment,
ubo,
pass_number,
semantics,
&mut meta,
UniformMemberBlock::Ubo,
SemanticErrorBlame::Fragment,
)?;
}
if let Some(push) = vertex_push {
Self::reflect_buffer_struct_members(
&self.vertex,
push,
pass_number,
semantics,
&mut meta,
UniformMemberBlock::PushConstant,
SemanticErrorBlame::Vertex,
)?;
}
if let Some(push) = fragment_push {
Self::reflect_buffer_struct_members(
&self.fragment,
push,
pass_number,
semantics,
&mut meta,
UniformMemberBlock::PushConstant,
SemanticErrorBlame::Fragment,
)?;
}
let mut ubo_bindings = 0u16;
if vertex_ubo.is_some() || fragment_ubo.is_some() {
ubo_bindings = 1 << ubo.as_ref().expect("UBOs should be present").binding;
}
let textures = self.fragment.global_variables.iter().filter(|(_, gv)| {
let ty = &self.fragment.types[gv.ty];
matches!(ty.inner, TypeInner::Image { .. })
});
for (_, texture) in textures {
let texture_data = self.reflect_texture(texture)?;
if ubo_bindings & (1 << texture_data.binding) != 0 {
return Err(ShaderReflectError::BindingInUse(texture_data.binding));
}
ubo_bindings |= 1 << texture_data.binding;
self.reflect_texture_metas(texture_data, pass_number, semantics, &mut meta)?;
}
Ok(ShaderReflection {
ubo,
push_constant,
meta,
})
} }
} }