reflect(msl): naga msl implementation

This commit is contained in:
chyyran 2024-02-11 15:39:11 -05:00 committed by Ronny Chan
parent d0a5224c10
commit 4762055dc1
12 changed files with 254 additions and 59 deletions

View file

@ -6,9 +6,8 @@ use librashader::reflect::targets::SPIRV;
use librashader::reflect::{CompileShader, ReflectShader, ShaderCompilerOutput, ShaderReflection};
use librashader::{FilterMode, WrapMode};
use librashader::reflect::SpirvCompilation;
use librashader::reflect::f
use librashader::reflect::helper::image::{Image, UVDirection, RGBA8};
use librashader::reflect::SpirvCompilation;
pub(crate) struct LookupTexture {
wrap_mode: WrapMode,

View file

@ -25,7 +25,7 @@ librashader-presets = { path = "../librashader-presets", version = "0.2.0-beta.9
spirv_cross = { package = "librashader-spirv-cross", version = "0.24", optional = true }
naga = { version = "0.19.0", features = ["spv-in", "spv-out"], optional = true }
naga = { version = "0.19.0", optional = true }
rspirv = { version = "0.12.0", optional = true }
spirv = { version = "0.3.0", optional = true}
@ -39,10 +39,10 @@ version = "0.4"
optional = true
[features]
default = ["cross", "naga", "serialize"]
default = ["cross", "naga", "serialize", "wgsl", "msl"]
dxil = ["spirv_cross/hlsl", "spirv-to-dxil"]
wgsl = ["cross", "naga/wgsl-out", "spirv", "rspirv"]
cross = [ "spirv_cross", "spirv_cross/glsl", "spirv_cross/hlsl", "spirv_cross/msl" ]
naga = [ "wgsl" ]
naga = [ "rspirv", "spirv", "naga/spv-in", "naga/spv-out", "naga/wgsl-out", "naga/msl-out" ]
serialize = [ "serde" ]
msl = [ "spirv_cross/msl", "naga/msl-out" ]

View file

@ -126,7 +126,7 @@ where
#[cfg(test)]
mod test {
use crate::front::{Glslang, ShaderInputCompiler, SpirvCompilation};
use crate::front::{Glslang, ShaderInputCompiler};
use librashader_preprocess::ShaderSource;
pub fn test() {

View file

@ -6,6 +6,8 @@ use crate::reflect::cross::msl::MslReflect;
use crate::reflect::cross::{CompiledProgram, SpirvCross};
use crate::reflect::naga::{Naga, NagaReflect};
use crate::reflect::ReflectShader;
use naga::back::msl::TranslationInfo;
use naga::Module;
/// The HLSL shader model version to target.
pub use spirv_cross::msl::Version as MslVersion;
@ -25,7 +27,7 @@ pub struct CrossMslContext {
impl FromCompilation<SpirvCompilation, SpirvCross> for MSL {
type Target = MSL;
type Options = Option<spirv_cross::msl::Version>;
type Options = Option<self::MslVersion>;
type Context = CrossMslContext;
type Output = impl CompileShader<Self::Target, Options = Self::Options, Context = Self::Context>
+ ReflectShader;
@ -39,10 +41,21 @@ impl FromCompilation<SpirvCompilation, SpirvCross> for MSL {
}
}
/// The naga module for a shader after compilation
pub struct NagaMslModule {
pub translation_info: TranslationInfo,
pub module: Module,
}
pub struct NagaMslContext {
pub vertex: NagaMslModule,
pub fragment: NagaMslModule,
}
impl FromCompilation<SpirvCompilation, Naga> for MSL {
type Target = MSL;
type Options = ();
type Context = ();
type Options = Option<self::MslVersion>;
type Context = NagaMslContext;
type Output = impl CompileShader<Self::Target, Options = Self::Options, Context = Self::Context>
+ ReflectShader;

View file

@ -35,9 +35,9 @@ mod test {
use crate::reflect::naga::NagaLoweringOptions;
use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
use crate::reflect::ReflectShader;
use bitflags::Flags;
use librashader_preprocess::ShaderSource;
use rustc_hash::FxHashMap;
use bitflags::Flags;
#[test]
pub fn test_into() {

View file

@ -35,10 +35,15 @@ pub enum ShaderCompileError {
/// Error when transpiling from naga
#[cfg(feature = "naga")]
#[error("naga-spv")]
NagaGlslError(#[from] naga::back::spv::Error),
NagaSpvError(#[from] naga::back::spv::Error),
/// Error when transpiling from naga
#[cfg(feature = "wgsl")]
#[cfg(all(feature = "naga", feature = "msl"))]
#[error("naga-spv")]
NagaMslError(#[from] naga::back::msl::Error),
/// Error when transpiling from naga
#[cfg(any(feature = "naga", feature = "wgsl"))]
#[error("naga-wgsl")]
NagaValidationError(#[from] naga::WithSpan<naga::valid::ValidationError>),
}

View file

@ -1,5 +1,3 @@
use std::collections::BTreeMap;
use naga::{Module};
use crate::back::msl::CrossMslContext;
use crate::back::targets::MSL;
use crate::back::{CompileShader, ShaderCompilerOutput};
@ -8,6 +6,7 @@ use crate::reflect::cross::{CompiledAst, CompiledProgram, CrossReflect};
use spirv_cross::msl;
use spirv_cross::msl::{ResourceBinding, ResourceBindingLocation};
use spirv_cross::spirv::{Ast, Decoration, ExecutionModel};
use std::collections::BTreeMap;
pub(crate) type MslReflect = CrossReflect<spirv_cross::msl::Target>;
@ -26,7 +25,11 @@ impl CompileShader<MSL> for CrossReflect<spirv_cross::msl::Target> {
vert_options.version = version;
frag_options.version = version;
fn get_binding(ast: &Ast<msl::Target>, stage: ExecutionModel, binding_map: &mut BTreeMap<ResourceBindingLocation, ResourceBinding>) -> Result<(), ShaderCompileError>{
fn set_bindings(
ast: &Ast<msl::Target>,
stage: ExecutionModel,
binding_map: &mut BTreeMap<ResourceBindingLocation, ResourceBinding>,
) -> Result<(), ShaderCompileError> {
let resources = ast.get_shader_resources()?;
for resource in &resources.push_constant_buffers {
let location = ResourceBindingLocation {
@ -45,7 +48,11 @@ impl CompileShader<MSL> for CrossReflect<spirv_cross::msl::Target> {
binding_map.insert(location, overridden);
}
for resource in resources.uniform_buffers.iter().chain(resources.sampled_images.iter()) {
for resource in resources
.uniform_buffers
.iter()
.chain(resources.sampled_images.iter())
{
let binding = ast.get_decoration(resource.id, Decoration::Binding)?;
let location = ResourceBindingLocation {
stage,
@ -66,19 +73,18 @@ impl CompileShader<MSL> for CrossReflect<spirv_cross::msl::Target> {
Ok(())
}
get_binding(
set_bindings(
&self.vertex,
ExecutionModel::Vertex,
&mut vert_options.resource_binding_overrides
&mut vert_options.resource_binding_overrides,
)?;
get_binding(
set_bindings(
&self.fragment,
ExecutionModel::Fragment,
&mut frag_options.resource_binding_overrides
&mut frag_options.resource_binding_overrides,
)?;
eprintln!("{:?}", frag_options.resource_binding_overrides);
self.vertex.set_compiler_options(&vert_options)?;
self.fragment.set_compiler_options(&frag_options)?;
@ -97,17 +103,17 @@ impl CompileShader<MSL> for CrossReflect<spirv_cross::msl::Target> {
#[cfg(test)]
mod test {
use std::io::Write;
use crate::back::targets::{MSL, WGSL};
use crate::back::{CompileShader, FromCompilation};
use crate::reflect::cross::SpirvCross;
use crate::reflect::naga::{Naga, NagaLoweringOptions};
use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
use crate::reflect::ReflectShader;
use bitflags::Flags;
use librashader_preprocess::ShaderSource;
use rustc_hash::FxHashMap;
use bitflags::Flags;
use spirv_cross::msl;
use crate::reflect::cross::SpirvCross;
use std::io::Write;
#[test]
pub fn test_into() {
@ -129,7 +135,8 @@ mod test {
let compilation = crate::front::SpirvCompilation::try_from(&result).unwrap();
let mut msl = <MSL as FromCompilation<_, SpirvCross>>::from_compilation(compilation).unwrap();
let mut msl =
<MSL as FromCompilation<_, SpirvCross>>::from_compilation(compilation).unwrap();
msl.reflect(
0,
@ -138,25 +145,10 @@ mod test {
texture_semantics: Default::default(),
},
)
.expect("");
.expect("");
let compiled = msl
.compile(Some(msl::Version::V2_0))
.unwrap();
let compiled = msl.compile(Some(msl::Version::V2_0)).unwrap();
println!("{}", compiled.fragment);
// println!("{}", compiled.fragment);
// let mut loader = rspirv::dr::Loader::new();
// rspirv::binary::parse_words(compilation.vertex.as_binary(), &mut loader).unwrap();
// let module = loader.module();
//
// let outputs: Vec<&Instruction> = module
// .types_global_values
// .iter()
// .filter(|i| i.class.opcode == Op::Variable)
// .collect();
//
// println!("{outputs:#?}");
}
}
}

View file

@ -37,7 +37,10 @@ pub(crate) struct NagaReflect {
/// Options to lower samplers and pcbs
#[derive(Debug, Default, Clone)]
pub struct NagaLoweringOptions {
/// Whether to write the PCB as a UBO.
pub write_pcb_as_ubo: bool,
/// The bind group to assign samplers to. This is to ensure that samplers will
/// maintain the same bindings as textures.
pub sampler_bind_group: u32,
}
@ -360,13 +363,13 @@ impl NagaReflect {
let binding = self.get_next_binding(0);
// Reassign to UBO later if we want during compilation.
if let Some(vertex_pcb) = vertex_pcb {
let ubo = &mut self.vertex.global_variables[vertex_pcb];
ubo.binding = Some(ResourceBinding { group: 0, binding });
let pcb = &mut self.vertex.global_variables[vertex_pcb];
pcb.binding = Some(ResourceBinding { group: 0, binding });
}
if let Some(fragment_pcb) = fragment_pcb {
let ubo = &mut self.fragment.global_variables[fragment_pcb];
ubo.binding = Some(ResourceBinding { group: 0, binding });
let pcb = &mut self.fragment.global_variables[fragment_pcb];
pcb.binding = Some(ResourceBinding { group: 0, binding });
};
match (vertex_pcb, fragment_pcb) {

View file

@ -1,18 +1,201 @@
use naga::{Module, ResourceBinding};
use crate::back::msl::{MslVersion, NagaMslContext, NagaMslModule};
use crate::back::targets::MSL;
use crate::back::{CompileShader, ShaderCompilerOutput};
use crate::error::ShaderCompileError;
use crate::reflect::naga::NagaReflect;
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
use naga::back::msl::{
BindSamplerTarget, BindTarget, EntryPointResources, Options, PipelineOptions, TranslationInfo,
};
use naga::valid::{Capabilities, ValidationFlags};
use naga::{Module, TypeInner};
use spirv_cross::msl::Version;
fn msl_version_to_naga_msl(version: MslVersion) -> (u8, u8) {
match version {
Version::V1_0 => (1, 0),
Version::V1_1 => (1, 1),
Version::V1_2 => (1, 2),
Version::V2_0 => (2, 0),
Version::V2_1 => (2, 1),
Version::V2_2 => (2, 2),
Version::V2_3 => (2, 3),
_ => (0, 0),
}
}
impl CompileShader<MSL> for NagaReflect {
type Options = ();
type Context = ();
type Options = Option<crate::back::msl::MslVersion>;
type Context = NagaMslContext;
fn compile(
self,
mut self,
options: Self::Options,
) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> {
// https://github.com/libretro/RetroArch/blob/434e94c782af2e4d4277a24b7ed8e5fc54870088/gfx/drivers_shader/slang_process.cpp#L524
todo!()
let lang_version = msl_version_to_naga_msl(options.unwrap_or(MslVersion::V2_0));
let mut vert_options = Options {
lang_version,
per_entry_point_map: Default::default(),
inline_samplers: vec![],
spirv_cross_compatibility: true,
fake_missing_bindings: false,
bounds_check_policies: Default::default(),
zero_initialize_workgroup_memory: false,
};
let mut frag_options = vert_options.clone();
fn write_msl(
module: &Module,
options: Options,
) -> Result<(String, TranslationInfo), ShaderCompileError> {
let mut valid =
naga::valid::Validator::new(ValidationFlags::all(), Capabilities::empty());
let info = valid.validate(&module)?;
let pipeline_options = PipelineOptions {
allow_and_force_point_size: false,
};
let msl = naga::back::msl::write_string(&module, &info, &options, &pipeline_options)?;
Ok(msl)
}
fn generate_bindings(module: &Module) -> EntryPointResources {
let mut resources = EntryPointResources::default();
let binding_map = &mut resources.resources;
// Don't set PCB because they'll be gone after lowering..
// resources.push_constant_buffer = Some(1u8);
for (_, variable) in module.global_variables.iter() {
let Some(binding) = &variable.binding else {
continue;
};
let Ok(ty) = module.types.get_handle(variable.ty) else {
continue;
};
match ty.inner {
TypeInner::Sampler { .. } => {
binding_map.insert(
binding.clone(),
BindTarget {
buffer: None,
texture: None,
sampler: Some(BindSamplerTarget::Resource(binding.binding as u8)),
binding_array_size: None,
mutable: false,
},
);
}
TypeInner::Struct { .. } => {
binding_map.insert(
binding.clone(),
BindTarget {
buffer: Some(binding.binding as u8),
texture: None,
sampler: None,
binding_array_size: None,
mutable: false,
},
);
}
TypeInner::Image { .. } => {
binding_map.insert(
binding.clone(),
BindTarget {
buffer: None,
texture: Some(binding.binding as u8),
sampler: None,
binding_array_size: None,
mutable: false,
},
);
}
_ => continue,
}
}
resources
}
self.do_lowering(&NagaLoweringOptions {
write_pcb_as_ubo: true,
sampler_bind_group: 1,
});
frag_options
.per_entry_point_map
.insert(String::from("main"), generate_bindings(&self.fragment));
vert_options
.per_entry_point_map
.insert(String::from("main"), generate_bindings(&self.vertex));
let fragment = write_msl(&self.fragment, frag_options)?;
let vertex = write_msl(&self.vertex, vert_options)?;
Ok(ShaderCompilerOutput {
vertex: vertex.0,
fragment: fragment.0,
context: NagaMslContext {
fragment: NagaMslModule {
translation_info: fragment.1,
module: self.fragment,
},
vertex: NagaMslModule {
translation_info: vertex.1,
module: self.vertex,
},
},
})
}
}
#[cfg(test)]
mod test {
use crate::back::targets::MSL;
use crate::back::{CompileShader, FromCompilation};
use crate::reflect::naga::{Naga, NagaLoweringOptions};
use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
use crate::reflect::ReflectShader;
use bitflags::Flags;
use librashader_preprocess::ShaderSource;
use rustc_hash::FxHashMap;
use spirv_cross::msl;
#[test]
pub fn test_into() {
let result = ShaderSource::load("../test/basic.slang").unwrap();
let mut uniform_semantics: FxHashMap<String, UniformSemantic> = Default::default();
for (_index, param) in result.parameters.iter().enumerate() {
uniform_semantics.insert(
param.1.id.clone(),
UniformSemantic::Unique(Semantic {
semantics: UniqueSemantics::FloatParameter,
index: (),
}),
);
}
let compilation = crate::front::SpirvCompilation::try_from(&result).unwrap();
let mut msl = <MSL as FromCompilation<_, Naga>>::from_compilation(compilation).unwrap();
msl.reflect(
0,
&ShaderSemantics {
uniform_semantics,
texture_semantics: Default::default(),
},
)
.expect("");
let compiled = msl.compile(Some(msl::Version::V2_0)).unwrap();
println!("{}", compiled.fragment);
}
}

View file

@ -501,9 +501,9 @@ impl FilterChainD3D12 {
(dxil_reflection, graphics_pipeline)
} else {
let hlsl_reflection = hlsl.reflect(index, semantics)?;
let hlsl = hlsl.compile(
Some(librashader_reflect::back::hlsl::HlslShaderModel::V6_0)
)?;
let hlsl = hlsl.compile(Some(
librashader_reflect::back::hlsl::HlslShaderModel::V6_0,
))?;
let graphics_pipeline = D3D12GraphicsPipeline::new_from_hlsl(
device,

View file

@ -3,8 +3,8 @@ use crate::error::assume_d3d12_init;
use crate::error::FilterChainError::Direct3DOperationError;
use crate::{error, util};
use librashader_cache::{cache_pipeline, cache_shader_object};
use librashader_reflect::back::hlsl::CrossHlslContext;
use librashader_reflect::back::dxil::DxilObject;
use librashader_reflect::back::hlsl::CrossHlslContext;
use librashader_reflect::back::ShaderCompilerOutput;
use std::mem::ManuallyDrop;
use std::ops::Deref;

View file

@ -150,7 +150,7 @@ pub mod reflect {
FromCompilation, ShaderCompilerOutput,
};
pub use librashader_reflect::front::{SpirvCompilation, Glslang, ShaderReflectObject };
pub use librashader_reflect::front::{Glslang, ShaderReflectObject, SpirvCompilation};
/// Reflection via SPIRV-Cross.
#[cfg(feature = "reflect-cross")]
@ -196,8 +196,8 @@ pub mod reflect {
#[cfg(feature = "reflect-naga")]
#[doc(cfg(feature = "reflect-naga"))]
pub mod naga {
pub use librashader_reflect::reflect::naga::Naga;
pub use librashader_reflect::back::wgsl::NagaWgslContext;
pub use librashader_reflect::reflect::naga::Naga;
pub use librashader_reflect::reflect::naga::NagaLoweringOptions;
}