Generate generic builder setters for fields taking an objecttype (#724)

Generate templated builder setters for fields taking an `objecttype`

We already do this for hand-written extension functions but can now also
implement it for setters since `vk_parse` fields are available within
the builder generator code: when a field refers to another field for
setting its `objecttype`, that `VkObjectType` field setter is omitted
and instead assigned when the object is set, based on a type generic
that implements the `Handle` trait instead of an untyped `u64`.
This commit is contained in:
Marijn Suijten 2023-05-06 20:39:57 +02:00 committed by GitHub
parent f840977b72
commit aa8f600aa8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 112 deletions

View file

@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Inlined struct setters (#602) - Inlined struct setters (#602)
- Bumped MSRV from 1.59 to 1.60 (#709) - Bumped MSRV from 1.59 to 1.60 (#709)
- Replaced `const fn name()` with associated `NAME` constants (#715) - Replaced `const fn name()` with associated `NAME` constants (#715)
- Generic builders now automatically set `objecttype` to `<T as Handle>::ObjectType` (#724)
- extensions/khr: Take the remaining `p_next`-containing structs as `&mut` to allow chains (#744) - extensions/khr: Take the remaining `p_next`-containing structs as `&mut` to allow chains (#744)
- `AccelerationStructure::get_acceleration_structure_build_sizes()` - `AccelerationStructure::get_acceleration_structure_build_sizes()`
- `ExternalMemoryFd::get_memory_fd_properties()` - `ExternalMemoryFd::get_memory_fd_properties()`

View file

@ -17852,13 +17852,9 @@ unsafe impl<'a> TaggedStructure for DebugUtilsObjectNameInfoEXT<'a> {
unsafe impl ExtendsPipelineShaderStageCreateInfo for DebugUtilsObjectNameInfoEXT<'_> {} unsafe impl ExtendsPipelineShaderStageCreateInfo for DebugUtilsObjectNameInfoEXT<'_> {}
impl<'a> DebugUtilsObjectNameInfoEXT<'a> { impl<'a> DebugUtilsObjectNameInfoEXT<'a> {
#[inline] #[inline]
pub fn object_type(mut self, object_type: ObjectType) -> Self { pub fn object_handle<T: Handle>(mut self, object_handle: T) -> Self {
self.object_type = object_type; self.object_handle = object_handle.as_raw();
self self.object_type = T::TYPE;
}
#[inline]
pub fn object_handle(mut self, object_handle: u64) -> Self {
self.object_handle = object_handle;
self self
} }
#[inline] #[inline]
@ -17901,13 +17897,9 @@ unsafe impl<'a> TaggedStructure for DebugUtilsObjectTagInfoEXT<'a> {
} }
impl<'a> DebugUtilsObjectTagInfoEXT<'a> { impl<'a> DebugUtilsObjectTagInfoEXT<'a> {
#[inline] #[inline]
pub fn object_type(mut self, object_type: ObjectType) -> Self { pub fn object_handle<T: Handle>(mut self, object_handle: T) -> Self {
self.object_type = object_type; self.object_handle = object_handle.as_raw();
self self.object_type = T::TYPE;
}
#[inline]
pub fn object_handle(mut self, object_handle: u64) -> Self {
self.object_handle = object_handle;
self self
} }
#[inline] #[inline]
@ -18295,13 +18287,9 @@ impl<'a> DeviceMemoryReportCallbackDataEXT<'a> {
self self
} }
#[inline] #[inline]
pub fn object_type(mut self, object_type: ObjectType) -> Self { pub fn object_handle<T: Handle>(mut self, object_handle: T) -> Self {
self.object_type = object_type; self.object_handle = object_handle.as_raw();
self self.object_type = T::TYPE;
}
#[inline]
pub fn object_handle(mut self, object_handle: u64) -> Self {
self.object_handle = object_handle;
self self
} }
#[inline] #[inline]

View file

@ -734,7 +734,7 @@ fn discard_outmost_delimiter(stream: TokenStream) -> TokenStream {
impl FieldExt for vkxml::Field { impl FieldExt for vkxml::Field {
fn param_ident(&self) -> Ident { fn param_ident(&self) -> Ident {
let name = self.name.as_deref().unwrap_or("field"); let name = self.name.as_deref().unwrap();
let name_corrected = match name { let name_corrected = match name {
"type" => "ty", "type" => "ty",
_ => name, _ => name,
@ -1637,7 +1637,7 @@ pub fn generate_enum<'a>(
} }
} }
pub fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream { fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream {
let notation = enum_.children.iter().filter_map(|elem| { let notation = enum_.children.iter().filter_map(|elem| {
let (variant_name, notation) = match elem { let (variant_name, notation) = match elem {
vk_parse::EnumsChild::Enum(constant) => ( vk_parse::EnumsChild::Enum(constant) => (
@ -1678,9 +1678,10 @@ pub fn generate_result(ident: Ident, enum_: &vk_parse::Enums) -> TokenStream {
fn is_static_array(field: &vkxml::Field) -> bool { fn is_static_array(field: &vkxml::Field) -> bool {
matches!(field.array, Some(vkxml::ArrayType::Static)) matches!(field.array, Some(vkxml::ArrayType::Static))
} }
pub fn derive_default(
fn derive_default(
struct_: &vkxml::Struct, struct_: &vkxml::Struct,
members: &[(&vkxml::Field, Option<TokenStream>)], members: &[PreprocessedMember],
has_lifetime: bool, has_lifetime: bool,
) -> Option<TokenStream> { ) -> Option<TokenStream> {
let name = name_to_tokens(&struct_.name); let name = name_to_tokens(&struct_.name);
@ -1703,42 +1704,39 @@ pub fn derive_default(
]; ];
let contains_ptr = members let contains_ptr = members
.iter() .iter()
.cloned() .any(|member| member.vkxml_field.reference.is_some());
.any(|(field, _)| field.reference.is_some()); let contains_structure_type = members
let contains_structure_type = members.iter().map(|(f, _)| *f).any(is_structure_type); .iter()
let contains_static_array = members.iter().map(|(f, _)| *f).any(is_static_array); .any(|member| is_structure_type(member.vkxml_field));
let contains_deprecated = members.iter().any(|(_, d)| d.is_some()); let contains_static_array = members
.iter()
.any(|member| is_static_array(member.vkxml_field));
let contains_deprecated = members.iter().any(|member| member.deprecated.is_some());
let allow_deprecated = contains_deprecated.then(|| quote!(#[allow(deprecated)])); let allow_deprecated = contains_deprecated.then(|| quote!(#[allow(deprecated)]));
if !(contains_ptr || contains_structure_type || contains_static_array) { if !(contains_ptr || contains_structure_type || contains_static_array) {
return None; return None;
}; };
let default_fields = members.iter().map(|(field, _)| { let default_fields = members.iter().map(|member| {
let param_ident = field.param_ident(); let param_ident = member.vkxml_field.param_ident();
if is_structure_type(field) { if is_structure_type(member.vkxml_field) {
if field.type_enums.is_some() { if member.vkxml_field.type_enums.is_some() {
quote! { quote!(#param_ident: Self::STRUCTURE_TYPE)
#param_ident: Self::STRUCTURE_TYPE
}
} else { } else {
quote! { quote!(#param_ident: unsafe { ::std::mem::zeroed() })
#param_ident: unsafe { ::std::mem::zeroed() }
} }
} } else if member.vkxml_field.reference.is_some() {
} else if field.reference.is_some() { if member.vkxml_field.is_const {
if field.is_const {
quote!(#param_ident: ::std::ptr::null()) quote!(#param_ident: ::std::ptr::null())
} else { } else {
quote!(#param_ident: ::std::ptr::null_mut()) quote!(#param_ident: ::std::ptr::null_mut())
} }
} else if is_static_array(field) || handles.contains(&field.basetype.as_str()) { } else if is_static_array(member.vkxml_field)
quote! { || handles.contains(&member.vkxml_field.basetype.as_str())
#param_ident: unsafe { ::std::mem::zeroed() } {
} quote!(#param_ident: unsafe { ::std::mem::zeroed() })
} else { } else {
let ty = field.type_tokens(false); let ty = member.vkxml_field.type_tokens(false);
quote! { quote!(#param_ident: #ty::default())
#param_ident: #ty::default()
}
} }
}); });
let lifetime = has_lifetime.then(|| quote!(<'_>)); let lifetime = has_lifetime.then(|| quote!(<'_>));
@ -1759,15 +1757,17 @@ pub fn derive_default(
}; };
Some(q) Some(q)
} }
pub fn derive_debug(
fn derive_debug(
struct_: &vkxml::Struct, struct_: &vkxml::Struct,
members: &[(&vkxml::Field, Option<TokenStream>)], members: &[PreprocessedMember],
union_types: &HashSet<&str>, union_types: &HashSet<&str>,
has_lifetime: bool, has_lifetime: bool,
) -> Option<TokenStream> { ) -> Option<TokenStream> {
let name = name_to_tokens(&struct_.name); let name = name_to_tokens(&struct_.name);
let contains_pfn = members.iter().any(|(field, _)| { let contains_pfn = members.iter().any(|member| {
field member
.vkxml_field
.name .name
.as_ref() .as_ref()
.map(|n| n.contains("pfn")) .map(|n| n.contains("pfn"))
@ -1775,14 +1775,15 @@ pub fn derive_debug(
}); });
let contains_static_array = members let contains_static_array = members
.iter() .iter()
.any(|(x, _)| is_static_array(x) && x.basetype == "char"); .any(|member| is_static_array(member.vkxml_field) && member.vkxml_field.basetype == "char");
let contains_union = members let contains_union = members
.iter() .iter()
.any(|(field, _)| union_types.contains(field.basetype.as_str())); .any(|member| union_types.contains(member.vkxml_field.basetype.as_str()));
if !(contains_union || contains_static_array || contains_pfn) { if !(contains_union || contains_static_array || contains_pfn) {
return None; return None;
} }
let debug_fields = members.iter().map(|(field, _)| { let debug_fields = members.iter().map(|member| {
let field = &member.vkxml_field;
let param_ident = field.param_ident(); let param_ident = field.param_ident();
let param_str = param_ident.to_string(); let param_str = param_ident.to_string();
let debug_value = if is_static_array(field) && field.basetype == "char" { let debug_value = if is_static_array(field) && field.basetype == "char" {
@ -1821,9 +1822,9 @@ pub fn derive_debug(
Some(q) Some(q)
} }
pub fn derive_setters( fn derive_setters(
struct_: &vkxml::Struct, struct_: &vkxml::Struct,
members: &[(&vkxml::Field, Option<TokenStream>)], members: &[PreprocessedMember],
root_structs: &HashSet<Ident>, root_structs: &HashSet<Ident>,
has_lifetimes: &HashSet<Ident>, has_lifetimes: &HashSet<Ident>,
) -> Option<TokenStream> { ) -> Option<TokenStream> {
@ -1839,11 +1840,11 @@ pub fn derive_setters(
let next_field = members let next_field = members
.iter() .iter()
.find(|(field, _)| field.param_ident() == "p_next"); .find(|member| member.vkxml_field.param_ident() == "p_next");
let structure_type_field = members let structure_type_field = members
.iter() .iter()
.find(|(field, _)| field.param_ident() == "s_type"); .find(|member| member.vkxml_field.param_ident() == "s_type");
// Must either have both, or none: // Must either have both, or none:
assert_eq!(next_field.is_some(), structure_type_field.is_some()); assert_eq!(next_field.is_some(), structure_type_field.is_some());
@ -1858,9 +1859,11 @@ pub fn derive_setters(
// No ImageView attachments when VK_FRAMEBUFFER_CREATE_IMAGELESS_BIT is set // No ImageView attachments when VK_FRAMEBUFFER_CREATE_IMAGELESS_BIT is set
("VkFramebufferCreateInfo", "attachmentCount"), ("VkFramebufferCreateInfo", "attachmentCount"),
]; ];
let filter_members = members let skip_members = members
.iter() .iter()
.filter_map(|(field, _)| { .filter_map(|member| {
let field = &member.vkxml_field;
// Associated _count members // Associated _count members
if field.array.is_some() { if field.array.is_some() {
if let Some(array_size) = &field.size { if let Some(array_size) = &field.size {
@ -1870,12 +1873,32 @@ pub fn derive_setters(
} }
} }
if let Some(objecttype) = &member.vk_parse_type_member.objecttype {
let objecttype_field = members
.iter()
.find(|m| m.vkxml_field.name.as_ref().unwrap() == objecttype)
.unwrap();
// Extensions using this type are deprecated exactly because of the existence of VkObjectType, hence
// there won't be an additional ash trait to support VkDebugReportObjectTypeEXT.
// See also https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_EXT_debug_utils.html#_description
if objecttype_field.vkxml_field.basetype != "VkDebugReportObjectTypeEXT" {
return Some(objecttype);
}
}
None None
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let setters = members.iter().filter_map(|(field, deprecated)| { let setters = members.iter().filter_map(|member| {
let deprecated = deprecated.as_ref().map(|d| quote!(#d #[allow(deprecated)])); let field = &member.vkxml_field;
let name = field.name.as_ref().unwrap();
if skip_members.contains(&name) {
return None;
}
let deprecated = member.deprecated.as_ref().map(|d| quote!(#d #[allow(deprecated)]));
let param_ident = field.param_ident(); let param_ident = field.param_ident();
let param_ty_tokens = field.safe_type_tokens(quote!('a), None); let param_ty_tokens = field.safe_type_tokens(quote!('a), None);
@ -1890,11 +1913,6 @@ pub fn derive_setters(
.unwrap_or(&param_ident_string); .unwrap_or(&param_ident_string);
let mut param_ident_short = format_ident!("{}", param_ident_short); let mut param_ident_short = format_ident!("{}", param_ident_short);
if let Some(name) = field.name.as_ref() {
if filter_members.contains(&name) {
return None;
}
// Unique cases // Unique cases
if struct_.name == "VkShaderModuleCreateInfo" && name == "codeSize" { if struct_.name == "VkShaderModuleCreateInfo" && name == "codeSize" {
return None; return None;
@ -1929,7 +1947,6 @@ pub fn derive_setters(
} }
}); });
} }
}
// TODO: Improve in future when https://github.com/rust-lang/rust/issues/53667 is merged id:6 // TODO: Improve in future when https://github.com/rust-lang/rust/issues/53667 is merged id:6
if field.reference.is_some() { if field.reference.is_some() {
@ -1985,9 +2002,9 @@ pub fn derive_setters(
let array_size_ident = format_ident!("{}", array_size.to_snake_case()); let array_size_ident = format_ident!("{}", array_size.to_snake_case());
let size_field = members.iter().map(|(m, _)| m).find(|m| m.name.as_deref() == Some(array_size)).unwrap(); let size_field = members.iter().find(|member| member.vkxml_field.name.as_deref() == Some(array_size)).unwrap();
let cast = if size_field.basetype == "size_t" { let cast = if size_field.vkxml_field.basetype == "size_t" {
quote!() quote!()
} else { } else {
quote!(as _) quote!(as _)
@ -2022,6 +2039,29 @@ pub fn derive_setters(
}); });
} }
if let Some(objecttype) = &member.vk_parse_type_member.objecttype {
let objecttype_field = members
.iter()
.find(|m| m.vkxml_field.name.as_ref().unwrap() == objecttype)
.unwrap();
// Extensions using this type are deprecated exactly because of the existence of VkObjectType, hence
// there won't be an additional ash trait to support VkDebugReportObjectTypeEXT.
if objecttype_field.vkxml_field.basetype != "VkDebugReportObjectTypeEXT" {
let objecttype_ident = format_ident!("{}", objecttype.to_snake_case());
return Some(quote!{
#[inline]
#deprecated
pub fn #param_ident_short<T: Handle>(mut self, #param_ident_short: T) -> Self {
self.#param_ident = #param_ident_short.as_raw();
self.#objecttype_ident = T::TYPE;
self
}
});
}
};
let param_ty_tokens = if is_opaque_type(&field.basetype) { let param_ty_tokens = if is_opaque_type(&field.basetype) {
// Use raw pointers for void/opaque types // Use raw pointers for void/opaque types
field.type_tokens(false) field.type_tokens(false)
@ -2048,8 +2088,9 @@ pub fn derive_setters(
// The `p_next` field should only be considered if this struct is also a root struct // The `p_next` field should only be considered if this struct is also a root struct
let root_struct_next_field = next_field.filter(|_| root_structs.contains(&name)); let root_struct_next_field = next_field.filter(|_| root_structs.contains(&name));
// We only implement a next methods for root structs with a `pnext` field. // We only implement a next method for root structs with a `pnext` field.
let next_function = if let Some((next_field, _)) = root_struct_next_field { let next_function = if let Some(next_member) = root_struct_next_field {
let next_field = &next_member.vkxml_field;
assert_eq!(next_field.basetype, "void"); assert_eq!(next_field.basetype, "void");
let mutability = if next_field.is_const { let mutability = if next_field.is_const {
quote!(const) quote!(const)
@ -2106,8 +2147,9 @@ pub fn derive_setters(
quote!(unsafe impl #extends for #name<'_> {}) quote!(unsafe impl #extends for #name<'_> {})
}); });
let impl_structure_type_trait = structure_type_field.map(|(s_type, _)| { let impl_structure_type_trait = structure_type_field.map(|member| {
let value = s_type let value = member
.vkxml_field
.type_enums .type_enums
.as_deref() .as_deref()
.expect("s_type field must have a value in `vk.xml`"); .expect("s_type field must have a value in `vk.xml`");
@ -2148,6 +2190,13 @@ pub fn manual_derives(struct_: &vkxml::Struct) -> TokenStream {
_ => quote! {}, _ => quote! {},
} }
} }
struct PreprocessedMember<'a> {
vkxml_field: &'a vkxml::Field,
vk_parse_type_member: &'a vk_parse::TypeMemberDefinition,
deprecated: Option<TokenStream>,
}
pub fn generate_struct( pub fn generate_struct(
struct_: &vkxml::Struct, struct_: &vkxml::Struct,
vk_parse_types: &HashMap<String, &vk_parse::Type>, vk_parse_types: &HashMap<String, &vk_parse::Type>,
@ -2240,7 +2289,7 @@ pub fn generate_struct(
matches!(vk_parse_field.api.as_deref(), None | Some(DESIRED_API)) matches!(vk_parse_field.api.as_deref(), None | Some(DESIRED_API))
}) })
.map(|(field, vk_parse_field)| { .map(|(field, vk_parse_field)| {
let deprecation = vk_parse_field let deprecated = vk_parse_field
.deprecated .deprecated
.as_ref() .as_ref()
.map(|deprecated| match deprecated.as_str() { .map(|deprecated| match deprecated.as_str() {
@ -2250,11 +2299,17 @@ pub fn generate_struct(
} }
x => panic!("Unknown deprecation reason {}", x), x => panic!("Unknown deprecation reason {}", x),
}); });
(field, deprecation) PreprocessedMember {
vkxml_field: field,
vk_parse_type_member: vk_parse_field,
deprecated,
}
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let params = members.iter().map(|(field, deprecation)| { let params = members.iter().map(|member| {
let field = &member.vkxml_field;
let deprecated = &member.deprecated;
let param_ident = field.param_ident(); let param_ident = field.param_ident();
let param_ty_tokens = if field.basetype == struct_.name { let param_ty_tokens = if field.basetype == struct_.name {
let pointer = field let pointer = field
@ -2270,7 +2325,7 @@ pub fn generate_struct(
quote!(#ty #lifetime) quote!(#ty #lifetime)
}; };
quote!(#deprecation pub #param_ident: #param_ty_tokens) quote!(#deprecated pub #param_ident: #param_ty_tokens)
}); });
let has_lifetime = has_lifetimes.contains(&name); let has_lifetime = has_lifetimes.contains(&name);