From aa8f600aa8f3a34b9e5bf2954d18d9ac683875dc Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Sat, 6 May 2023 20:39:57 +0200 Subject: [PATCH] Generate generic builder setters for fields taking an `objecttype` (#724) Generate templated builder setters for fields taking an `objecttype` We already do this for hand-written extension functions but can now also implement it for setters since `vk_parse` fields are available within the builder generator code: when a field refers to another field for setting its `objecttype`, that `VkObjectType` field setter is omitted and instead assigned when the object is set, based on a type generic that implements the `Handle` trait instead of an untyped `u64`. --- Changelog.md | 1 + ash/src/vk/definitions.rs | 30 ++--- generator/src/lib.rs | 237 +++++++++++++++++++++++--------------- 3 files changed, 156 insertions(+), 112 deletions(-) diff --git a/Changelog.md b/Changelog.md index 387c5bd..ee6d071 100644 --- a/Changelog.md +++ b/Changelog.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Inlined struct setters (#602) - Bumped MSRV from 1.59 to 1.60 (#709) - Replaced `const fn name()` with associated `NAME` constants (#715) +- Generic builders now automatically set `objecttype` to `::ObjectType` (#724) - extensions/khr: Take the remaining `p_next`-containing structs as `&mut` to allow chains (#744) - `AccelerationStructure::get_acceleration_structure_build_sizes()` - `ExternalMemoryFd::get_memory_fd_properties()` diff --git a/ash/src/vk/definitions.rs b/ash/src/vk/definitions.rs index b3845bb..568b994 100644 --- a/ash/src/vk/definitions.rs +++ b/ash/src/vk/definitions.rs @@ -17852,13 +17852,9 @@ unsafe impl<'a> TaggedStructure for DebugUtilsObjectNameInfoEXT<'a> { unsafe impl ExtendsPipelineShaderStageCreateInfo for DebugUtilsObjectNameInfoEXT<'_> {} impl<'a> DebugUtilsObjectNameInfoEXT<'a> { #[inline] - pub fn object_type(mut self, object_type: ObjectType) -> Self { - self.object_type = object_type; - self - } - #[inline] - pub fn object_handle(mut self, object_handle: u64) -> Self { - self.object_handle = object_handle; + pub fn object_handle(mut self, object_handle: T) -> Self { + self.object_handle = object_handle.as_raw(); + self.object_type = T::TYPE; self } #[inline] @@ -17901,13 +17897,9 @@ unsafe impl<'a> TaggedStructure for DebugUtilsObjectTagInfoEXT<'a> { } impl<'a> DebugUtilsObjectTagInfoEXT<'a> { #[inline] - pub fn object_type(mut self, object_type: ObjectType) -> Self { - self.object_type = object_type; - self - } - #[inline] - pub fn object_handle(mut self, object_handle: u64) -> Self { - self.object_handle = object_handle; + pub fn object_handle(mut self, object_handle: T) -> Self { + self.object_handle = object_handle.as_raw(); + self.object_type = T::TYPE; self } #[inline] @@ -18295,13 +18287,9 @@ impl<'a> DeviceMemoryReportCallbackDataEXT<'a> { self } #[inline] - pub fn object_type(mut self, object_type: ObjectType) -> Self { - self.object_type = object_type; - self - } - #[inline] - pub fn object_handle(mut self, object_handle: u64) -> Self { - self.object_handle = object_handle; + pub fn object_handle(mut self, object_handle: T) -> Self { + self.object_handle = object_handle.as_raw(); + self.object_type = T::TYPE; self } #[inline] diff --git a/generator/src/lib.rs b/generator/src/lib.rs index bebc213..e7015f3 100644 --- a/generator/src/lib.rs +++ b/generator/src/lib.rs @@ -734,7 +734,7 @@ fn discard_outmost_delimiter(stream: TokenStream) -> TokenStream { impl FieldExt for vkxml::Field { fn param_ident(&self) -> Ident { - let name = self.name.as_deref().unwrap_or("field"); + let name = self.name.as_deref().unwrap(); let name_corrected = match name { "type" => "ty", _ => name, @@ -1637,7 +1637,7 @@ pub fn generate_enum<'a>( } } -pub fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream { +fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream { let notation = enum_.children.iter().filter_map(|elem| { let (variant_name, notation) = match elem { vk_parse::EnumsChild::Enum(constant) => ( @@ -1678,9 +1678,10 @@ pub fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream { fn is_static_array(field: &vkxml::Field) -> bool { matches!(field.array, Some(vkxml::ArrayType::Static)) } -pub fn derive_default( + +fn derive_default( struct_: &vkxml::Struct, - members: &[(&vkxml::Field, Option)], + members: &[PreprocessedMember], has_lifetime: bool, ) -> Option { let name = name_to_tokens(&struct_.name); @@ -1703,42 +1704,39 @@ pub fn derive_default( ]; let contains_ptr = members .iter() - .cloned() - .any(|(field, _)| field.reference.is_some()); - let contains_structure_type = members.iter().map(|(f, _)| *f).any(is_structure_type); - let contains_static_array = members.iter().map(|(f, _)| *f).any(is_static_array); - let contains_deprecated = members.iter().any(|(_, d)| d.is_some()); + .any(|member| member.vkxml_field.reference.is_some()); + let contains_structure_type = members + .iter() + .any(|member| is_structure_type(member.vkxml_field)); + let contains_static_array = members + .iter() + .any(|member| is_static_array(member.vkxml_field)); + let contains_deprecated = members.iter().any(|member| member.deprecated.is_some()); let allow_deprecated = contains_deprecated.then(|| quote!(#[allow(deprecated)])); if !(contains_ptr || contains_structure_type || contains_static_array) { return None; }; - let default_fields = members.iter().map(|(field, _)| { - let param_ident = field.param_ident(); - if is_structure_type(field) { - if field.type_enums.is_some() { - quote! { - #param_ident: Self::STRUCTURE_TYPE - } + let default_fields = members.iter().map(|member| { + let param_ident = member.vkxml_field.param_ident(); + if is_structure_type(member.vkxml_field) { + if member.vkxml_field.type_enums.is_some() { + quote!(#param_ident: Self::STRUCTURE_TYPE) } else { - quote! { - #param_ident: unsafe { ::std::mem::zeroed() } - } + quote!(#param_ident: unsafe { ::std::mem::zeroed() }) } - } else if field.reference.is_some() { - if field.is_const { + } else if member.vkxml_field.reference.is_some() { + if member.vkxml_field.is_const { quote!(#param_ident: ::std::ptr::null()) } else { quote!(#param_ident: ::std::ptr::null_mut()) } - } else if is_static_array(field) || handles.contains(&field.basetype.as_str()) { - quote! { - #param_ident: unsafe { ::std::mem::zeroed() } - } + } else if is_static_array(member.vkxml_field) + || handles.contains(&member.vkxml_field.basetype.as_str()) + { + quote!(#param_ident: unsafe { ::std::mem::zeroed() }) } else { - let ty = field.type_tokens(false); - quote! { - #param_ident: #ty::default() - } + let ty = member.vkxml_field.type_tokens(false); + quote!(#param_ident: #ty::default()) } }); let lifetime = has_lifetime.then(|| quote!(<'_>)); @@ -1759,15 +1757,17 @@ pub fn derive_default( }; Some(q) } -pub fn derive_debug( + +fn derive_debug( struct_: &vkxml::Struct, - members: &[(&vkxml::Field, Option)], + members: &[PreprocessedMember], union_types: &HashSet<&str>, has_lifetime: bool, ) -> Option { let name = name_to_tokens(&struct_.name); - let contains_pfn = members.iter().any(|(field, _)| { - field + let contains_pfn = members.iter().any(|member| { + member + .vkxml_field .name .as_ref() .map(|n| n.contains("pfn")) @@ -1775,14 +1775,15 @@ pub fn derive_debug( }); let contains_static_array = members .iter() - .any(|(x, _)| is_static_array(x) && x.basetype == "char"); + .any(|member| is_static_array(member.vkxml_field) && member.vkxml_field.basetype == "char"); let contains_union = members .iter() - .any(|(field, _)| union_types.contains(field.basetype.as_str())); + .any(|member| union_types.contains(member.vkxml_field.basetype.as_str())); if !(contains_union || contains_static_array || contains_pfn) { return None; } - let debug_fields = members.iter().map(|(field, _)| { + let debug_fields = members.iter().map(|member| { + let field = &member.vkxml_field; let param_ident = field.param_ident(); let param_str = param_ident.to_string(); let debug_value = if is_static_array(field) && field.basetype == "char" { @@ -1821,9 +1822,9 @@ pub fn derive_debug( Some(q) } -pub fn derive_setters( +fn derive_setters( struct_: &vkxml::Struct, - members: &[(&vkxml::Field, Option)], + members: &[PreprocessedMember], root_structs: &HashSet, has_lifetimes: &HashSet, ) -> Option { @@ -1839,11 +1840,11 @@ pub fn derive_setters( let next_field = members .iter() - .find(|(field, _)| field.param_ident() == "p_next"); + .find(|member| member.vkxml_field.param_ident() == "p_next"); let structure_type_field = members .iter() - .find(|(field, _)| field.param_ident() == "s_type"); + .find(|member| member.vkxml_field.param_ident() == "s_type"); // Must either have both, or none: assert_eq!(next_field.is_some(), structure_type_field.is_some()); @@ -1858,9 +1859,11 @@ pub fn derive_setters( // No ImageView attachments when VK_FRAMEBUFFER_CREATE_IMAGELESS_BIT is set ("VkFramebufferCreateInfo", "attachmentCount"), ]; - let filter_members = members + let skip_members = members .iter() - .filter_map(|(field, _)| { + .filter_map(|member| { + let field = &member.vkxml_field; + // Associated _count members if field.array.is_some() { if let Some(array_size) = &field.size { @@ -1870,12 +1873,32 @@ pub fn derive_setters( } } + if let Some(objecttype) = &member.vk_parse_type_member.objecttype { + let objecttype_field = members + .iter() + .find(|m| m.vkxml_field.name.as_ref().unwrap() == objecttype) + .unwrap(); + // Extensions using this type are deprecated exactly because of the existence of VkObjectType, hence + // there won't be an additional ash trait to support VkDebugReportObjectTypeEXT. + // See also https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_EXT_debug_utils.html#_description + if objecttype_field.vkxml_field.basetype != "VkDebugReportObjectTypeEXT" { + return Some(objecttype); + } + } + None }) .collect::>(); - let setters = members.iter().filter_map(|(field, deprecated)| { - let deprecated = deprecated.as_ref().map(|d| quote!(#d #[allow(deprecated)])); + let setters = members.iter().filter_map(|member| { + let field = &member.vkxml_field; + + let name = field.name.as_ref().unwrap(); + if skip_members.contains(&name) { + return None; + } + + let deprecated = member.deprecated.as_ref().map(|d| quote!(#d #[allow(deprecated)])); let param_ident = field.param_ident(); let param_ty_tokens = field.safe_type_tokens(quote!('a), None); @@ -1890,45 +1913,39 @@ pub fn derive_setters( .unwrap_or(¶m_ident_string); let mut param_ident_short = format_ident!("{}", param_ident_short); - if let Some(name) = field.name.as_ref() { - if filter_members.contains(&name) { - return None; - } + // Unique cases + if struct_.name == "VkShaderModuleCreateInfo" && name == "codeSize" { + return None; + } - // Unique cases - if struct_.name == "VkShaderModuleCreateInfo" && name == "codeSize" { - return None; - } + if struct_.name == "VkShaderModuleCreateInfo" && name == "pCode" { + return Some(quote! { + #[inline] + pub fn code(mut self, code: &'a [u32]) -> Self { + self.code_size = code.len() * 4; + self.p_code = code.as_ptr(); + self + } + }); + } - if struct_.name == "VkShaderModuleCreateInfo" && name == "pCode" { - return Some(quote!{ - #[inline] - pub fn code(mut self, code: &'a [u32]) -> Self { - self.code_size = code.len() * 4; - self.p_code = code.as_ptr(); - self - } - }); - } - - if name == "pSampleMask" { - return Some(quote!{ - /// Sets `p_sample_mask` to `null` if the slice is empty. The mask will - /// be treated as if it has all bits set to `1`. - /// - /// See - /// for more details. - #[inline] - pub fn sample_mask(mut self, sample_mask: &'a [SampleMask]) -> Self { - self.p_sample_mask = if sample_mask.is_empty() { - std::ptr::null() - } else { - sample_mask.as_ptr() - }; - self - } - }); - } + if name == "pSampleMask" { + return Some(quote! { + /// Sets `p_sample_mask` to `null` if the slice is empty. The mask will + /// be treated as if it has all bits set to `1`. + /// + /// See + /// for more details. + #[inline] + pub fn sample_mask(mut self, sample_mask: &'a [SampleMask]) -> Self { + self.p_sample_mask = if sample_mask.is_empty() { + std::ptr::null() + } else { + sample_mask.as_ptr() + }; + self + } + }); } // TODO: Improve in future when https://github.com/rust-lang/rust/issues/53667 is merged id:6 @@ -1985,9 +2002,9 @@ pub fn derive_setters( let array_size_ident = format_ident!("{}", array_size.to_snake_case()); - let size_field = members.iter().map(|(m, _)| m).find(|m| m.name.as_deref() == Some(array_size)).unwrap(); + let size_field = members.iter().find(|member| member.vkxml_field.name.as_deref() == Some(array_size)).unwrap(); - let cast = if size_field.basetype == "size_t" { + let cast = if size_field.vkxml_field.basetype == "size_t" { quote!() } else { quote!(as _) @@ -2022,6 +2039,29 @@ pub fn derive_setters( }); } + if let Some(objecttype) = &member.vk_parse_type_member.objecttype { + let objecttype_field = members + .iter() + .find(|m| m.vkxml_field.name.as_ref().unwrap() == objecttype) + .unwrap(); + + // Extensions using this type are deprecated exactly because of the existence of VkObjectType, hence + // there won't be an additional ash trait to support VkDebugReportObjectTypeEXT. + if objecttype_field.vkxml_field.basetype != "VkDebugReportObjectTypeEXT" { + let objecttype_ident = format_ident!("{}", objecttype.to_snake_case()); + + return Some(quote!{ + #[inline] + #deprecated + pub fn #param_ident_short(mut self, #param_ident_short: T) -> Self { + self.#param_ident = #param_ident_short.as_raw(); + self.#objecttype_ident = T::TYPE; + self + } + }); + } + }; + let param_ty_tokens = if is_opaque_type(&field.basetype) { // Use raw pointers for void/opaque types field.type_tokens(false) @@ -2048,8 +2088,9 @@ pub fn derive_setters( // The `p_next` field should only be considered if this struct is also a root struct let root_struct_next_field = next_field.filter(|_| root_structs.contains(&name)); - // We only implement a next methods for root structs with a `pnext` field. - let next_function = if let Some((next_field, _)) = root_struct_next_field { + // We only implement a next method for root structs with a `pnext` field. + let next_function = if let Some(next_member) = root_struct_next_field { + let next_field = &next_member.vkxml_field; assert_eq!(next_field.basetype, "void"); let mutability = if next_field.is_const { quote!(const) @@ -2106,8 +2147,9 @@ pub fn derive_setters( quote!(unsafe impl #extends for #name<'_> {}) }); - let impl_structure_type_trait = structure_type_field.map(|(s_type, _)| { - let value = s_type + let impl_structure_type_trait = structure_type_field.map(|member| { + let value = member + .vkxml_field .type_enums .as_deref() .expect("s_type field must have a value in `vk.xml`"); @@ -2148,6 +2190,13 @@ pub fn manual_derives(struct_: &vkxml::Struct) -> TokenStream { _ => quote! {}, } } + +struct PreprocessedMember<'a> { + vkxml_field: &'a vkxml::Field, + vk_parse_type_member: &'a vk_parse::TypeMemberDefinition, + deprecated: Option, +} + pub fn generate_struct( struct_: &vkxml::Struct, vk_parse_types: &HashMap, @@ -2240,7 +2289,7 @@ pub fn generate_struct( matches!(vk_parse_field.api.as_deref(), None | Some(DESIRED_API)) }) .map(|(field, vk_parse_field)| { - let deprecation = vk_parse_field + let deprecated = vk_parse_field .deprecated .as_ref() .map(|deprecated| match deprecated.as_str() { @@ -2250,11 +2299,17 @@ pub fn generate_struct( } x => panic!("Unknown deprecation reason {}", x), }); - (field, deprecation) + PreprocessedMember { + vkxml_field: field, + vk_parse_type_member: vk_parse_field, + deprecated, + } }) .collect::>(); - let params = members.iter().map(|(field, deprecation)| { + let params = members.iter().map(|member| { + let field = &member.vkxml_field; + let deprecated = &member.deprecated; let param_ident = field.param_ident(); let param_ty_tokens = if field.basetype == struct_.name { let pointer = field @@ -2270,7 +2325,7 @@ pub fn generate_struct( quote!(#ty #lifetime) }; - quote!(#deprecation pub #param_ident: #param_ty_tokens) + quote!(#deprecated pub #param_ident: #param_ty_tokens) }); let has_lifetime = has_lifetimes.contains(&name);