From 3c7615b3bb11099e982e5248228c2b77aca17352 Mon Sep 17 00:00:00 2001 From: Steve Wooster Date: Sat, 26 Feb 2022 15:03:54 -0800 Subject: [PATCH] Iterator simplification (#565) * Unnest iterators This hopefully makes the iterator definitions better resemble paths into the XML tree. * Use for-loop instead of .for_each() * Use elems.contains(x) instead of elems.iter().any(...) * Shrink commands-related .fold() Co-authored-by: Steve Wooster --- generator/src/lib.rs | 280 +++++++++++++++++-------------------------- 1 file changed, 108 insertions(+), 172 deletions(-) diff --git a/generator/src/lib.rs b/generator/src/lib.rs index 17930ca..cc37b5a 100644 --- a/generator/src/lib.rs +++ b/generator/src/lib.rs @@ -18,6 +18,21 @@ use std::fmt::Display; use std::path::Path; use syn::Ident; +macro_rules! get_variant { + ($variant:path) => { + |enum_| match enum_ { + $variant(inner) => Some(inner), + _ => None, + } + }; + ($variant:path { $($member:ident),+ }) => { + |enum_| match enum_ { + $variant { $($member),+, .. } => Some(( $($member),+ )), + _ => None, + } + }; +} + const BACKWARDS_COMPATIBLE_ALIAS_COMMENT: &str = "Backwards-compatible alias containing a typo"; pub trait ExtensionExt {} @@ -912,10 +927,7 @@ pub fn generate_extension_constants<'a>( ) -> TokenStream { let items = extension_items .iter() - .filter_map(|item| match item { - vk_parse::ExtensionChild::Require { items, .. } => Some(items.iter()), - _ => None, - }) + .filter_map(get_variant!(vk_parse::ExtensionChild::Require { items })) .flatten(); let mut extended_enums = BTreeMap::>::new(); @@ -984,29 +996,22 @@ pub fn generate_extension_commands<'a>( ) -> TokenStream { let mut commands = Vec::new(); let mut aliases = HashMap::new(); - items + let names = items .iter() - .filter_map(|ext_item| match ext_item { - vk_parse::ExtensionChild::Require { items, .. } => { - Some(items.iter().filter_map(|item| match item { - vk_parse::InterfaceItem::Command { ref name, .. } => Some(name), - _ => None, - })) - } - _ => None, - }) + .filter_map(get_variant!(vk_parse::ExtensionChild::Require { items })) .flatten() - .for_each(|name| { - if let Some(cmd) = cmd_map.get(name).copied() { - commands.push(cmd); - } else if let Some(cmd) = cmd_aliases - .get(name) - .and_then(|alias_name| cmd_map.get(alias_name).copied()) - { - aliases.insert(cmd.name.clone(), name.to_string()); - commands.push(cmd); - } - }); + .filter_map(get_variant!(vk_parse::InterfaceItem::Command { name })); + for name in names { + if let Some(cmd) = cmd_map.get(name).copied() { + commands.push(cmd); + } else if let Some(cmd) = cmd_aliases + .get(name) + .and_then(|alias_name| cmd_map.get(alias_name).copied()) + { + aliases.insert(cmd.name.clone(), name.to_string()); + commands.push(cmd); + } + } let ident = format_ident!( "{}Fn", @@ -1016,17 +1021,10 @@ pub fn generate_extension_commands<'a>( let spec_version = items .iter() - .find_map(|ext_item| match ext_item { - vk_parse::ExtensionChild::Require { items, .. } => { - items.iter().find_map(|item| match item { - vk_parse::InterfaceItem::Enum(ref e) if e.name.contains("SPEC_VERSION") => { - Some(e) - } - _ => None, - }) - } - _ => None, - }) + .filter_map(get_variant!(vk_parse::ExtensionChild::Require { items })) + .flatten() + .filter_map(get_variant!(vk_parse::InterfaceItem::Enum)) + .find(|e| e.name.contains("SPEC_VERSION")) .and_then(|e| { if let vk_parse::EnumSpec::Value { value, .. } = &e.spec { let v: u32 = str::parse(value).unwrap(); @@ -1283,10 +1281,7 @@ pub fn generate_enum<'a>( let constants = enum_ .children .iter() - .filter_map(|elem| match *elem { - vk_parse::EnumsChild::Enum(ref constant) => Some(constant), - _ => None, - }) + .filter_map(get_variant!(vk_parse::EnumsChild::Enum)) .filter(|constant| constant.notation() != Some(BACKWARDS_COMPATIBLE_ALIAS_COMMENT)) .collect_vec(); @@ -1407,10 +1402,10 @@ fn is_static_array(field: &vkxml::Field) -> bool { } pub fn derive_default(_struct: &vkxml::Struct) -> Option { let name = name_to_tokens(&_struct.name); - let members = _struct.elements.iter().filter_map(|elem| match *elem { - vkxml::StructElement::Member(ref field) => Some(field), - _ => None, - }); + let members = _struct + .elements + .iter() + .filter_map(get_variant!(vkxml::StructElement::Member)); let is_structure_type = |field: &vkxml::Field| field.basetype == "VkStructureType"; // This are also pointers, and therefor also don't implement Default. The spec @@ -1472,10 +1467,10 @@ pub fn derive_default(_struct: &vkxml::Struct) -> Option { } pub fn derive_debug(_struct: &vkxml::Struct, union_types: &HashSet<&str>) -> Option { let name = name_to_tokens(&_struct.name); - let members = _struct.elements.iter().filter_map(|elem| match *elem { - vkxml::StructElement::Member(ref field) => Some(field), - _ => None, - }); + let members = _struct + .elements + .iter() + .filter_map(get_variant!(vkxml::StructElement::Member)); let contains_pfn = members.clone().any(|field| { field .name @@ -1545,19 +1540,19 @@ pub fn derive_setters( let name = name_to_tokens(&struct_.name); let name_builder = name_to_tokens(&(struct_.name.clone() + "Builder")); - let members = struct_.elements.iter().filter_map(|elem| match *elem { - vkxml::StructElement::Member(ref field) => Some(field), - _ => None, - }); + let members = struct_ + .elements + .iter() + .filter_map(get_variant!(vkxml::StructElement::Member)); let next_field = members .clone() .find(|field| field.param_ident() == "p_next"); let nofilter_count_members = [ - "VkPipelineViewportStateCreateInfo.pViewports", - "VkPipelineViewportStateCreateInfo.pScissors", - "VkDescriptorSetLayoutBinding.pImmutableSamplers", + ("VkPipelineViewportStateCreateInfo", "pViewports"), + ("VkPipelineViewportStateCreateInfo", "pScissors"), + ("VkDescriptorSetLayoutBinding", "pImmutableSamplers"), ]; let filter_members: Vec = members .clone() @@ -1567,10 +1562,7 @@ pub fn derive_setters( // Associated _count members if field.array.is_some() { if let Some(ref array_size) = field.size { - if !nofilter_count_members - .iter() - .any(|&n| n == (struct_.name.clone() + "." + field_name)) - { + if !nofilter_count_members.contains(&(&struct_.name, field_name)) { return Some((*array_size).clone()); } } @@ -1935,10 +1927,10 @@ pub fn generate_struct( }; } - let members = _struct.elements.iter().filter_map(|elem| match *elem { - vkxml::StructElement::Member(ref field) => Some(field), - _ => None, - }); + let members = _struct + .elements + .iter() + .filter_map(get_variant!(vkxml::StructElement::Member)); let params = members.clone().map(|field| { let param_ident = field.param_ident(); @@ -2101,42 +2093,21 @@ pub fn generate_feature<'a>( let (static_commands, entry_commands, device_commands, instance_commands) = feature .elements .iter() - .flat_map(|feature| { - if let vkxml::FeatureElement::Require(ref spec) = feature { - spec.elements - .iter() - .filter_map(|feature_spec| { - if let vkxml::FeatureReference::CommandReference(ref cmd_ref) = feature_spec - { - Some(cmd_ref) - } else { - None - } - }) - .collect() - } else { - vec![] - } - }) + .filter_map(get_variant!(vkxml::FeatureElement::Require)) + .flat_map(|spec| &spec.elements) + .filter_map(get_variant!(vkxml::FeatureReference::CommandReference)) .filter_map(|cmd_ref| commands.get(&cmd_ref.name)) .fold( (Vec::new(), Vec::new(), Vec::new(), Vec::new()), - |mut acc, &cmd_ref| { - match cmd_ref.function_type() { - FunctionType::Static => { - acc.0.push(cmd_ref); - } - FunctionType::Entry => { - acc.1.push(cmd_ref); - } - FunctionType::Device => { - acc.2.push(cmd_ref); - } - FunctionType::Instance => { - acc.3.push(cmd_ref); - } - } - acc + |mut accs, &cmd_ref| { + let acc = match cmd_ref.function_type() { + FunctionType::Static => &mut accs.0, + FunctionType::Entry => &mut accs.1, + FunctionType::Device => &mut accs.2, + FunctionType::Instance => &mut accs.3, + }; + acc.push(cmd_ref); + accs }, ); let version = feature.version_string(); @@ -2206,16 +2177,19 @@ pub fn generate_feature_extension<'a>( const_cache: &mut HashSet<&'a str>, const_values: &mut BTreeMap, ) -> TokenStream { - let constants = registry.0.iter().filter_map(|item| match item { - vk_parse::RegistryChild::Feature(feature) => Some(generate_extension_constants( - &feature.name, - 0, - &feature.children, - const_cache, - const_values, - )), - _ => None, - }); + let constants = registry + .0 + .iter() + .filter_map(get_variant!(vk_parse::RegistryChild::Feature)) + .map(|feature| { + generate_extension_constants( + &feature.name, + 0, + &feature.children, + const_cache, + const_values, + ) + }); quote! { #(#constants)* } @@ -2317,16 +2291,9 @@ pub fn extract_native_types(registry: &vk_parse::Registry) -> (Vec<(String, Stri let types = registry .0 .iter() - .filter_map(|item| match item { - vk_parse::RegistryChild::Types(ref ty) => { - Some(ty.children.iter().filter_map(|child| match child { - vk_parse::TypesChild::Type(ty) => Some(ty), - _ => None, - })) - } - _ => None, - }) - .flatten(); + .filter_map(get_variant!(vk_parse::RegistryChild::Types)) + .flat_map(|ty| &ty.children) + .filter_map(get_variant!(vk_parse::TypesChild::Type)); for ty in types { match ty.category.as_deref() { @@ -2372,11 +2339,10 @@ pub fn generate_aliases_of_types( let aliases = types .children .iter() - .filter_map(|child| match child { - vk_parse::TypesChild::Type(ty) => Some((ty.name.as_ref()?, ty.alias.as_ref()?)), - _ => None, - }) - .filter_map(|(name, alias)| { + .filter_map(get_variant!(vk_parse::TypesChild::Type)) + .filter_map(|ty| { + let name = ty.name.as_ref()?; + let alias = ty.alias.as_ref()?; let name_ident = name_to_tokens(name); if !ty_cache.insert(name_ident.clone()) { return None; @@ -2399,80 +2365,54 @@ pub fn write_source_code>(vk_headers_dir: &Path, src_dir: P) { let extensions: &Vec = spec2 .0 .iter() - .find_map(|item| match item { - vk_parse::RegistryChild::Extensions(ref ext) => Some(&ext.children), - _ => None, - }) + .find_map(get_variant!(vk_parse::RegistryChild::Extensions)) + .map(|ext| &ext.children) .expect("extension"); let mut ty_cache = HashSet::new(); let aliases: Vec<_> = spec2 .0 .iter() - .filter_map(|item| match item { - vk_parse::RegistryChild::Types(ref ty) => { - Some(generate_aliases_of_types(ty, &mut ty_cache)) - } - _ => None, - }) + .filter_map(get_variant!(vk_parse::RegistryChild::Types)) + .map(|ty| generate_aliases_of_types(ty, &mut ty_cache)) .collect(); let spec = vk_parse::parse_file_as_vkxml(&vk_xml).expect("Invalid xml file."); let cmd_aliases: HashMap = spec2 .0 .iter() - .filter_map(|item| match item { - vk_parse::RegistryChild::Commands(cmds) => { - let cmd_tuple_iter = cmds.children.iter().filter_map(|cmd| match cmd { - vk_parse::Command::Alias { name, alias } => { - Some((name.to_string(), alias.to_string())) - } - _ => None, - }); - Some(cmd_tuple_iter) - } - _ => None, - }) - .flatten() + .filter_map(get_variant!(vk_parse::RegistryChild::Commands)) + .flat_map(|cmds| &cmds.children) + .filter_map(get_variant!(vk_parse::Command::Alias { name, alias })) + .map(|(name, alias)| (name.to_string(), alias.to_string())) .collect(); let commands: HashMap = spec .elements .iter() - .filter_map(|elem| match elem { - vkxml::RegistryElement::Commands(ref cmds) => Some(cmds), - _ => None, - }) - .flat_map(|cmds| cmds.elements.iter().map(|cmd| (cmd.name.clone(), cmd))) + .filter_map(get_variant!(vkxml::RegistryElement::Commands)) + .flat_map(|cmds| &cmds.elements) + .map(|cmd| (cmd.name.clone(), cmd)) .collect(); let features: Vec<&vkxml::Feature> = spec .elements .iter() - .filter_map(|elem| match elem { - vkxml::RegistryElement::Features(ref features) => Some(features), - _ => None, - }) - .flat_map(|features| features.elements.iter()) + .filter_map(get_variant!(vkxml::RegistryElement::Features)) + .flat_map(|features| &features.elements) .collect(); let definitions: Vec<&vkxml::DefinitionsElement> = spec .elements .iter() - .filter_map(|elem| match elem { - vkxml::RegistryElement::Definitions(ref definitions) => Some(definitions), - _ => None, - }) - .flat_map(|definitions| definitions.elements.iter()) + .filter_map(get_variant!(vkxml::RegistryElement::Definitions)) + .flat_map(|definitions| &definitions.elements) .collect(); let constants: Vec<&vkxml::Constant> = spec .elements .iter() - .filter_map(|elem| match elem { - vkxml::RegistryElement::Constants(ref constants) => Some(constants), - _ => None, - }) - .flat_map(|constants| constants.elements.iter()) + .filter_map(get_variant!(vkxml::RegistryElement::Constants)) + .flat_map(|constants| &constants.elements) .collect(); let mut fn_cache = HashSet::new(); @@ -2484,10 +2424,8 @@ pub fn write_source_code>(vk_headers_dir: &Path, src_dir: P) { let (enum_code, bitflags_code) = spec2 .0 .iter() - .filter_map(|item| match item { - vk_parse::RegistryChild::Enums(ref enums) if enums.kind.is_some() => Some(enums), - _ => None, - }) + .filter_map(get_variant!(vk_parse::RegistryChild::Enums)) + .filter(|enums| enums.kind.is_some()) .map(|e| generate_enum(e, &mut const_cache, &mut const_values, &mut bitflags_cache)) .fold((Vec::new(), Vec::new()), |mut acc, elem| { match elem { @@ -2520,10 +2458,8 @@ pub fn write_source_code>(vk_headers_dir: &Path, src_dir: P) { let union_types = definitions .iter() - .filter_map(|def| match def { - vkxml::DefinitionsElement::Union(ref union) => Some(union.name.as_str()), - _ => None, - }) + .filter_map(get_variant!(vkxml::DefinitionsElement::Union)) + .map(|union_| union_.name.as_str()) .collect::>(); let mut identifier_renames = BTreeMap::new();