//! Generation of Rust derive functions for encoding. use quote::{format_ident, quote}; use crate::layout::{LayoutModule, LayoutTypeDef}; use crate::parse::{GpuScalar, GpuType}; pub fn gen_derive(module: &LayoutModule) -> proc_macro2::TokenStream { let mut ts = proc_macro2::TokenStream::new(); let module_name = format_ident!("{}", module.name); for name in &module.def_names { let def = module.defs.get(name).unwrap(); ts.extend(gen_derive_def(name, def.0.size, &def.1)); } quote! { mod #module_name { pub trait HalfToLeBytes { fn to_le_bytes(&self) -> [u8; 2]; } impl HalfToLeBytes for half::f16 { fn to_le_bytes(&self) -> [u8; 2] { self.to_bits().to_le_bytes() } } #ts } } } fn gen_derive_def(name: &str, size: usize, def: &LayoutTypeDef) -> proc_macro2::TokenStream { let name_id = format_ident!("{}", name); match def { LayoutTypeDef::Struct(fields) => { let mut gen_fields = proc_macro2::TokenStream::new(); let mut encode_fields = proc_macro2::TokenStream::new(); for (field_name, offset, ty) in fields { let field_name_id = format_ident!("{}", field_name); let gen_ty = gen_derive_ty(&ty.ty); let gen_field = quote! { pub #field_name_id: #gen_ty, }; gen_fields.extend(gen_field); encode_fields.extend(gen_encode_field(field_name, *offset, &ty.ty)); } quote! { #[derive(Clone)] pub struct #name_id { #gen_fields } impl crate::encoder::Encode for #name_id { fn fixed_size() -> usize { #size } fn encode_to(&self, buf: &mut [u8]) { #encode_fields } } } } LayoutTypeDef::Enum(variants) => { let mut gen_variants = proc_macro2::TokenStream::new(); let mut cases = proc_macro2::TokenStream::new(); for (variant_ix, (variant_name, payload)) in variants.iter().enumerate() { let variant_id = format_ident!("{}", variant_name); let field_tys = payload.iter().map(|(_offset, ty)| gen_derive_ty(&ty.ty)); let variant = quote! { #variant_id(#(#field_tys),*), }; gen_variants.extend(variant); let mut args = Vec::new(); let mut field_encoders = proc_macro2::TokenStream::new(); let mut tag_field = None; for (i, (offset, ty)) in payload.iter().enumerate() { let field_id = format_ident!("f{}", i); if matches!(ty.ty, GpuType::Scalar(GpuScalar::TagFlags)) { tag_field = Some(field_id.clone()); } else { let field_encoder = quote! { #field_id.encode_to(&mut buf[#offset..]); }; field_encoders.extend(field_encoder); } args.push(field_id); } let tag = variant_ix as u32; let tag_encode = match tag_field { None => quote! { buf[0..4].copy_from_slice(&#tag.to_le_bytes()); }, Some(tag_field) => quote! { buf[0..4].copy_from_slice(&(#tag | ((*#tag_field as u32) << 16)).to_le_bytes()); }, }; let case = quote! { #name_id::#variant_id(#(#args),*) => { #tag_encode #field_encoders } }; cases.extend(case); } quote! { #[derive(Clone)] pub enum #name_id { #gen_variants } impl crate::encoder::Encode for #name_id { fn fixed_size() -> usize { #size } fn encode_to(&self, buf: &mut [u8]) { match self { #cases } } } } } } } /// Generate a Rust type. fn gen_derive_ty(ty: &GpuType) -> proc_macro2::TokenStream { match ty { GpuType::Scalar(s) => gen_derive_scalar_ty(s), GpuType::Vector(s, len) => { let scalar = gen_derive_scalar_ty(s); quote! { [#scalar; #len] } } GpuType::InlineStruct(name) => { let name_id = format_ident!("{}", name); quote! { #name_id } } GpuType::Ref(ty) => { let gen_ty = gen_derive_ty(ty); quote! { crate::encoder::Ref<#gen_ty> } } } } fn gen_derive_scalar_ty(ty: &GpuScalar) -> proc_macro2::TokenStream { match ty { GpuScalar::F16 => quote!(half::f16), GpuScalar::F32 => quote!(f32), GpuScalar::I8 => quote!(i8), GpuScalar::I16 => quote!(i16), GpuScalar::I32 => quote!(i32), GpuScalar::U8 => quote!(u8), GpuScalar::U16 => quote!(u16), GpuScalar::U32 => quote!(u32), GpuScalar::TagFlags => quote!(u16), } } fn gen_encode_field(name: &str, offset: usize, ty: &GpuType) -> proc_macro2::TokenStream { let name_id = format_ident!("{}", name); match ty { // encoding of flags into tag word is handled elsewhere GpuType::Scalar(GpuScalar::TagFlags) => quote! {}, GpuType::Scalar(s) => { let end = offset + s.size(); quote! { buf[#offset..#end].copy_from_slice(&self.#name_id.to_le_bytes()); } } GpuType::Vector(s, len) => { let size = s.size(); quote! { for i in 0..#len { let offset = #offset + i * #size; buf[offset..offset + #size].copy_from_slice(&self.#name_id[i].to_le_bytes()); } } } GpuType::Ref(_) => { quote! { buf[#offset..#offset + 4].copy_from_slice(&self.#name_id.offset().to_le_bytes()); } } _ => { quote! { &self.#name_id.encode_to(&mut buf[#offset..]); } } } }