diff --git a/piet-wgsl/src/shaders/preprocess.rs b/piet-wgsl/src/shaders/preprocess.rs index ce474de..8da2e28 100644 --- a/piet-wgsl/src/shaders/preprocess.rs +++ b/piet-wgsl/src/shaders/preprocess.rs @@ -41,25 +41,44 @@ pub fn preprocess( ) -> String { let mut output = String::with_capacity(input.len()); let mut stack = vec![]; + 'all_lines: for (line_number, mut line) in input.lines().enumerate() { + loop { + if line.is_empty() { + break; + } + let hash_index = line.find('#'); + let comment_index = line.find("//"); + let hash_index = match (hash_index, comment_index) { + (Some(hash_index), None) => hash_index, + (Some(hash_index), Some(comment_index)) if hash_index < comment_index => hash_index, + // Add this line to the output - all directives are commented out or there are no directives + _ => break, + }; + let directive_start = &line[hash_index + '#'.len_utf8()..]; + let directive_len = directive_start + // The first character which can't be part of the directive name marks the end of the directive + // In practise this should always be whitespace, but in theory a 'unit' directive + // could be added + .find(|c: char| !c.is_alphanumeric()) + .unwrap_or(directive_start.len()); + let directive = &directive_start[..directive_len]; + let directive_is_at_start = line.trim_start().starts_with('#'); - for (line_number, line) in input.lines().enumerate() { - let trimmed = line.trim(); - if trimmed.starts_with("#") { - let trimmed = &trimmed[1..]; - let val_idx = trimmed - .chars() - .take_while(|char| char.is_alphanumeric()) - .map(char::len_utf8) - .sum(); - let arg = trimmed[val_idx..].trim(); - match &trimmed[..val_idx] { - x @ ("ifdef" | "ifndef") => { - let exists = defines.contains(arg); - let mode = x == "ifdef"; + match directive { + if_item @ ("ifdef" | "ifndef" | "else" | "endif") if !directive_is_at_start => { + eprintln!("#{if_item} directives must be the first non_whitespace items on their line, ignoring (line {line_number})"); + break; + } + def_test @ ("ifdef" | "ifndef") => { + let def = directive_start[directive_len..].trim(); + let exists = defines.contains(def); + let mode = def_test == "ifdef"; stack.push(StackItem { active: mode == exists, else_passed: false, }); + // Don't add this line to the output; instead process the next line + continue 'all_lines; } "else" => { let item = stack.last_mut(); @@ -71,29 +90,62 @@ pub fn preprocess( item.active = !item.active; } } + let remainder = directive_start[directive_len..].trim(); + if !remainder.is_empty() { + eprintln!("#else directives don't take an argument. `{remainder}` will not be in output (line {line_number})"); + } + // Don't add this line to the output; it should be empty (see warning above) + continue 'all_lines; } "endif" => { if let None = stack.pop() { eprintln!("Mismatched endif (line {line_number})"); } + let remainder = directive_start[directive_len..].trim(); + if !remainder.is_empty() { + eprintln!("#endif directives don't take an argument. `{remainder}` will not be in output (line {line_number})"); + } + // Don't add this line to the output; it should be empty (see warning above) + continue 'all_lines; } "import" => { - let import = imports.get(arg); - if let Some(import) = import { - output.push_str(&preprocess(import, defines, imports)); + output.push_str(&line[..hash_index]); + let directive_end = &directive_start[directive_len..]; + let import_name_start = if let Some(import_name_start) = + directive_end.find(|c: char| !c.is_whitespace()) + { + import_name_start } else { - eprintln!("Unkown import `{arg}` (line {line_number})"); + eprintln!("#import needs a non_whitespace argument (line {line_number})"); + continue 'all_lines; + }; + let import_name_start = &directive_end[import_name_start..]; + let import_name_end_index = import_name_start + // The first character which can't be part of the import name marks the end of the import + .find(|c: char| !(c == '_' || c.is_alphanumeric())) + .unwrap_or(import_name_start.len()); + let import_name = &import_name_start[..import_name_end_index]; + line = &import_name_start[import_name_end_index..]; + let import = imports.get(import_name); + if let Some(import) = import { + // In theory, we can cache this until the top item of the stack changes + // However, in practise there will only ever be at most 2 stack items, so it's reasonable to just recompute it every time + if stack.iter().all(|item| item.active) { + output.push_str(&preprocess(import, defines, imports)); + } + } else { + eprintln!("Unknown import `{import_name}` (line {line_number})"); } + continue; } val => { - eprintln!("Unknown preprocessor directive `{val}` (line {line_number})") + eprintln!("Unknown preprocessor directive `{val}` (line {line_number})"); } } - } else { - if stack.last().map(|x| x.active).unwrap_or(true) { - output.push_str(line); - output.push('\n'); - } + } + if stack.iter().all(|item| item.active) { + output.push_str(line); + output.push('\n'); } } output