diff --git a/generator/src/lib.rs b/generator/src/lib.rs index ae7e00b..12158b3 100644 --- a/generator/src/lib.rs +++ b/generator/src/lib.rs @@ -92,12 +92,17 @@ pub fn define_handle_macro() -> Tokens { pub struct $name{ ptr: *mut u8 } + impl Default for $name{ + fn default() -> $name { + $name::null() + } + } unsafe impl Send for $name {} unsafe impl Sync for $name {} impl $name{ - pub unsafe fn null() -> Self{ + pub fn null() -> Self{ $name{ ptr: ::std::ptr::null_mut() } @@ -113,7 +118,7 @@ pub fn handle_nondispatchable_macro() -> Tokens { macro_rules! handle_nondispatchable { ($name: ident) => { #[repr(C)] - #[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Copy, Hash)] + #[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Copy, Hash, Default)] pub struct $name (uint64_t); impl $name{ @@ -757,11 +762,12 @@ impl<'a> ConstantExt for ExtensionConstant<'a> { } pub fn generate_extension_constants<'a>( - extension: &'a vk_parse::Extension, + extension_name: &str, + extension_number: i64, + extension_items: &'a [vk_parse::ExtensionItem], const_cache: &mut HashSet<&'a str>, ) -> quote::Tokens { - let items = extension - .items + let items = extension_items .iter() .filter_map(|item| match item { vk_parse::ExtensionItem::Require { items, .. } => Some(items.iter()), @@ -788,12 +794,13 @@ pub fn generate_extension_constants<'a>( } => { let ext_base = 1_000_000_000; let ext_block_size = 1000; - let extnumber = extnumber.unwrap_or_else(|| extension.number.expect("number")); + let extnumber = extnumber.unwrap_or_else(|| extension_number); let value = ext_base + (extnumber - 1) * ext_block_size + offset; Some((Constant::Number(value as i32), Some(extends.clone()))) } _ => None, }?; + let extends = extends?; let ext_constant = ExtensionConstant { name: &_enum.name, @@ -801,7 +808,7 @@ pub fn generate_extension_constants<'a>( }; let ident = name_to_tokens(&extends); let impl_block = bitflags_impl_block(ident, &extends, &[&ext_constant]); - let doc_string = format!("Generated from '{}'", extension.name); + let doc_string = format!("Generated from '{}'", extension_name); let q = quote!{ #[doc = #doc_string] #impl_block @@ -817,11 +824,11 @@ pub fn generate_extension_constants<'a>( } } pub fn generate_extension_commands( - extension: &vk_parse::Extension, + extension_name: &str, + items: &[vk_parse::ExtensionItem], cmd_map: &CommandMap, ) -> Tokens { - let commands = extension - .items + let commands = items .iter() .filter_map(|ext_item| match ext_item { vk_parse::ExtensionItem::Require { items, .. } => { @@ -834,7 +841,7 @@ pub fn generate_extension_commands( }) .flat_map(|iter| iter) .collect_vec(); - let name = format!("{}Fn", extension.name.to_camel_case()); + let name = format!("{}Fn", extension_name.to_camel_case()); let ident = Ident::from(&name[2..]); generate_function_pointers(ident, &commands) } @@ -843,12 +850,17 @@ pub fn generate_extension<'a>( cmd_map: &CommandMap, const_cache: &mut HashSet<&'a str>, ) -> Option { - let _ = extension - .supported - .as_ref() - .filter(|s| s.as_str() == "vulkan")?; - let extension_tokens = generate_extension_constants(extension, const_cache); - let fp = generate_extension_commands(extension, cmd_map); + // let _ = extension + // .supported + // .as_ref() + // .filter(|s| s.as_str() == "vulkan")?; + let extension_tokens = generate_extension_constants( + &extension.name, + extension.number.unwrap_or(0), + &extension.items, + const_cache, + ); + let fp = generate_extension_commands(&extension.name, &extension.items, cmd_map); let q = quote!{ #fp #extension_tokens @@ -983,7 +995,7 @@ pub fn generate_enum<'a>( } else { let impl_block = bitflags_impl_block(ident, &_enum.name, &constants); let enum_quote = quote!{ - #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] #[repr(C)] pub struct #ident(pub(crate) i32); #impl_block @@ -1050,6 +1062,70 @@ fn is_static_array(field: &vkxml::Field) -> bool { }) .unwrap_or(false) } +pub fn derive_default(_struct: &vkxml::Struct, union_types: &HashSet<&str>) -> Option { + let name = name_to_tokens(&_struct.name); + let members = _struct.elements.iter().filter_map(|elem| match *elem { + vkxml::StructElement::Member(ref field) => Some(field), + _ => None, + }); + let is_structure_type = |field: &vkxml::Field| field.basetype == "VkStructureType"; + let is_pfn = |field: &vkxml::Field| { + field + .name + .as_ref() + .map(|n| n.contains("pfn")) + .unwrap_or(false) + }; + let contains_pfn = members.clone().any(is_pfn); + + let contains_ptr = members.clone().any(|field| field.reference.is_some()); + let contains_strucutre_type = members.clone().any(is_structure_type); + let contains_static_array = members.clone().any(is_static_array); + if !(contains_ptr || contains_pfn || contains_strucutre_type || contains_static_array) { + return None; + }; + let default_fields = members.clone().map(|field| { + let param_ident = field.param_ident(); + if is_structure_type(field) { + let ty = field + .type_enums + .as_ref() + .and_then(|ty| ty.split(',').nth(0)); + if let Some(variant) = ty { + let variant_ident = variant_ident("VkStructureType", variant); + + quote!{ + #param_ident: StructureType::#variant_ident + } + } else { + quote!{ + #param_ident: unsafe { ::std::mem::zeroed() } + } + } + } else if field.reference.is_some() || is_static_array(field) || is_pfn(field) { + quote!{ + #param_ident: unsafe { ::std::mem::zeroed() } + } + } else { + let ty = field.type_tokens(); + quote!{ + #param_ident: #ty::default() + } + } + }); + let q = quote!{ + impl ::std::default::Default for #name { + fn default() -> #name { + #name { + #( + #default_fields + ),* + } + } + } + }; + Some(q) +} pub fn derive_debug(_struct: &vkxml::Struct, union_types: &HashSet<&str>) -> Option { let name = name_to_tokens(&_struct.name); let members = _struct.elements.iter().filter_map(|elem| match *elem { @@ -1120,14 +1196,25 @@ pub fn generate_struct(_struct: &vkxml::Struct, union_types: &HashSet<&str>) -> }); let debug_tokens = derive_debug(_struct, union_types); - let dbg_str = if debug_tokens.is_none() { quote!(Debug,) } else {quote!()}; + let default_tokens = derive_default(_struct, union_types); + let dbg_str = if debug_tokens.is_none() { + quote!(Debug,) + } else { + quote!() + }; + let default_str = if default_tokens.is_none() { + quote!(Default,) + } else { + quote!() + }; quote!{ #[repr(C)] - #[derive(Copy, Clone, #dbg_str)] + #[derive(Copy, Clone, #default_str #dbg_str)] pub struct #name { #(#params,)* } #debug_tokens + #default_tokens } } @@ -1184,6 +1271,11 @@ fn generate_union(union: &vkxml::Union) -> Tokens { pub union #name { #(#fields),* } + impl ::std::default::Default for #name { + fn default() -> #name { + unsafe { ::std::mem::zeroed() } + } + } } } pub fn generate_definition( @@ -1296,8 +1388,13 @@ pub fn write_source_code(path: &Path) { }) .nth(0) .expect("extension"); + spec2.0.iter().for_each(|item| match item { + vk_parse::RegistryItem::Enums { name, .. } => println!("{:?}", name), + _ => (), + }); let spec = vk_parse::parse_file_as_vkxml(path); + //println!("{:#?}", spec); let commands: HashMap = spec .elements .iter()