Make derive_parser work with all optional arguments (#397)

The current implementation using `winnow`'s `opt` does not work for optional arguments that are in the middle of the command. For example, `bar{.cta}.red.op.pred   p, a{, b}, {!}c;`. This is because `opt` is greedy, and will always match `{, b}` instead of `, {!}c`. This change switches to using a custom combinator that handles this properly.
This commit is contained in:
Violet
2025-06-30 18:54:31 -07:00
committed by GitHub
parent d4ad17d75a
commit 1cf345329c
2 changed files with 105 additions and 10 deletions

View File

@ -1492,6 +1492,46 @@ pub struct TokenError(std::ops::Range<usize>);
impl std::error::Error for TokenError {}
/// Builds a parser for "`optional` (maybe) followed by `required`" that can
/// give up on `optional` after the fact.
///
/// The returned parser proceeds as follows:
///
/// 1. Run `optional`. If it fails with `ErrMode::Backtrack`, the input is
///    rewound to where it started and the optional part is treated as absent.
///    Any other error is returned unchanged.
/// 2. Run `required`. On success the result is `(optional_result, required_result)`.
/// 3. If `required` fails with `ErrMode::Backtrack` *and* `optional` had
///    matched, the input is rewound to the starting position and `required`
///    is run once more as if the optional part were absent, yielding
///    `(None, required_result)` or that attempt's error.
///
/// When `optional` was already absent, a backtracking failure of `required`
/// is returned directly: a second attempt would start from the very same
/// position, so it is skipped. This keeps nested chains of `first_optional`
/// from re-running their tails over and over on input that does not match.
/// (This assumes `required` gives the same answer when re-run at the same
/// position.)
fn first_optional<
    'a,
    'input,
    Input: Stream,
    OptionalOutput,
    RequiredOutput,
    Error,
    ParseOptional,
    ParseRequired,
>(
    mut optional: ParseOptional,
    mut required: ParseRequired,
) -> impl Parser<Input, (Option<OptionalOutput>, RequiredOutput), Error>
where
    ParseOptional: Parser<Input, OptionalOutput, Error>,
    ParseRequired: Parser<Input, RequiredOutput, Error>,
    Error: ParserError<Input>,
{
    move |input: &mut Input| -> Result<(Option<OptionalOutput>, RequiredOutput), ErrMode<Error>> {
        // Remember where we began so either attempt can be undone.
        let start = input.checkpoint();

        let parsed_optional = match optional.parse_next(input) {
            Ok(value) => Some(value),
            Err(ErrMode::Backtrack(_)) => {
                input.reset(&start);
                None
            }
            Err(e) => return Err(e),
        };

        // Captured up front because `parsed_optional` is moved in the success arm.
        let optional_matched = parsed_optional.is_some();

        match required.parse_next(input) {
            Ok(value) => return Ok((parsed_optional, value)),
            // Only worth retrying if `optional` consumed something that
            // `required` might have needed.
            Err(ErrMode::Backtrack(_)) if optional_matched => input.reset(&start),
            Err(e) => return Err(e),
        }

        // Second chance: pretend the optional part was never there.
        Ok((None, required.parse_next(input)?))
    }
}
// This macro is responsible for generating parser code for instruction parser.
// Instruction parsing is by far the most complex part of parsing PTX code:
// * There are tens of instruction kinds, each with slightly different parsing rules
@ -3413,6 +3453,7 @@ derive_parser!(
#[cfg(test)]
mod tests {
use crate::first_optional;
use crate::parse_module_checked;
use crate::PtxError;
@ -3423,6 +3464,55 @@ mod tests {
use logos::Span;
use winnow::prelude::*;
#[test]
fn first_optional_present() {
    // Input carries both the optional 'A' and the required 'B'.
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    let parsed = parser.parse("AB");
    assert_eq!(parsed, Ok((Some('A'), 'B')));
}
#[test]
fn first_optional_absent() {
    // Only the required 'B' is in the input; the optional 'A' is missing.
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    let parsed = parser.parse("B");
    assert_eq!(parsed, Ok((None, 'B')));
}
#[test]
fn first_optional_repeated_absent() {
    // Optional and required accept the same token; a single 'A' must be
    // attributed to the required parser.
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'A');
    let parsed = parser.parse("A");
    assert_eq!(parsed, Ok((None, 'A')));
}
#[test]
fn first_optional_repeated_present() {
    // Optional and required accept the same token; with two 'A's both match.
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'A');
    let parsed = parser.parse("AA");
    assert_eq!(parsed, Ok((Some('A'), 'A')));
}
#[test]
fn first_optional_sequence_absent() {
    // A leading 'A' is consumed by the outer sequence; the one remaining 'A'
    // goes to the required parser and the optional is reported absent.
    let mut parser = ('A', first_optional::<_, _, _, (), _, _>('A', 'A'));
    let parsed = parser.parse("AA");
    assert_eq!(parsed, Ok(('A', (None, 'A'))));
}
#[test]
fn first_optional_sequence_present() {
    // A leading 'A' is consumed by the outer sequence; the two remaining 'A's
    // satisfy the optional and the required parser respectively.
    let mut parser = ('A', first_optional::<_, _, _, (), _, _>('A', 'A'));
    let parsed = parser.parse("AAA");
    assert_eq!(parsed, Ok(('A', (Some('A'), 'A'))));
}
#[test]
fn first_optional_no_match() {
    // 'C' satisfies neither the optional 'A' nor the required 'B'.
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    assert!(parser.parse("C").is_err());
}
#[test]
fn sm_11() {
let text = ".target sm_11";

View File

@ -757,12 +757,13 @@ fn emit_definition_parser(
DotModifierRef::Direct { optional: true, .. }
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
});
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| {
let comma = if idx == 0 || arg.pre_pipe {
quote! { empty }
} else {
quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() }
};
let pre_bracket = if arg.pre_bracket {
quote! {
any.verify(|(t, _)| *t == #token_type::LBracket).void()
@ -833,16 +834,20 @@ fn emit_definition_parser(
#pattern.map(|(_, _, _, _, name, _, _)| name)
}
};
if arg.optional {
quote! {
let #arg_name = opt(#inner_parser).parse_next(stream)?;
}
let parser = if arg.optional {
quote! { first_optional(#inner_parser, #emitted_parser) }
} else {
quote! {
let #arg_name = #inner_parser.parse_next(stream)?;
}
}
quote! { (#inner_parser, #emitted_parser) }
};
let pattern = quote! { ( #arg_name, #emitted_pattern ) };
(pattern, parser)
});
let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; };
let fn_args = definition.function_arguments();
let fn_name = format_ident!("{}_{}", opcode, fn_idx);
let fn_call = quote! {
@ -863,7 +868,7 @@ fn emit_definition_parser(
}
}
#(#unordered_parse_validations)*
#(#arguments_parse)*
#arguments_parse
#fn_call
}
}