Finite state machines in rust; bendns fork to add types.
| -rw-r--r-- | README.md | 8 | ||||
| -rw-r--r-- | doc-example/src/lib.rs | 4 | ||||
| -rw-r--r-- | rust-fsm-dsl/Cargo.toml | 2 | ||||
| -rw-r--r-- | rust-fsm-dsl/src/lib.rs | 23 | ||||
| -rw-r--r-- | rust-fsm-dsl/src/parser.rs | 31 | ||||
| -rw-r--r-- | rust-fsm-dsl/src/variant.rs | 73 | ||||
| -rw-r--r-- | rust-fsm/src/lib.rs | 43 | ||||
| -rw-r--r-- | rust-fsm/tests/circuit_breaker.rs | 14 | ||||
| -rw-r--r-- | rust-fsm/tests/circuit_breaker_dsl.rs | 13 | ||||
| -rw-r--r-- | rust-fsm/tests/circuit_breaker_dsl_custom_types.rs | 15 | ||||
| -rw-r--r-- | rust-fsm/tests/simple.rs | 21 |
11 files changed, 171 insertions, 76 deletions
@@ -75,8 +75,8 @@ state_machine! { /// A Circuit Breaker state machine. circuit_breaker(Closed) - Closed(Unsuccessful) => Open [SetupTimer], - Open(TimerTriggered) => HalfOpen, + Closed => Unsuccessful => Open [SetupTimer], + Open => TimerTriggered => HalfOpen, HalfOpen => { Successful => Closed, Unsuccessful => Open [SetupTimer] @@ -103,10 +103,10 @@ This state machine can be used as follows: // Initialize the state machine. The state is `Closed` now. let mut machine = circuit_breaker::StateMachine::new(); // Consume the `Successful` input. No state transition is performed. -let _ = machine.consume(&circuit_breaker::Input::Successful); +let _ = machine.consume(circuit_breaker::Input::Successful); // Consume the `Unsuccesful` input. The machine is moved to the `Open` // state. The output is `SetupTimer`. -let output = machine.consume(&circuit_breaker::Input::Unsuccessful).unwrap(); +let output = machine.consume(circuit_breaker::Input::Unsuccessful).unwrap(); // Check the output if let Some(circuit_breaker::Output::SetupTimer) = output { // Set up the timer... diff --git a/doc-example/src/lib.rs b/doc-example/src/lib.rs index 18c0978..43691ee 100644 --- a/doc-example/src/lib.rs +++ b/doc-example/src/lib.rs @@ -6,8 +6,8 @@ state_machine! { /// https://martinfowler.com/bliki/CircuitBreaker.html pub circuit_breaker(Closed) - Closed(Unsuccessful) => Open [SetupTimer], - Open(TimerTriggered) => HalfOpen, + Closed => Unsuccessful => Open [SetupTimer], + Open => TimerTriggered => HalfOpen, HalfOpen => { Successful => Closed, Unsuccessful => Open [SetupTimer] diff --git a/rust-fsm-dsl/Cargo.toml b/rust-fsm-dsl/Cargo.toml index 33c7010..1513dfb 100644 --- a/rust-fsm-dsl/Cargo.toml +++ b/rust-fsm-dsl/Cargo.toml @@ -20,5 +20,5 @@ diagram = [] [dependencies] proc-macro2 = "1" -syn = "2" +syn = { version = "2", features = ["extra-traits", "full"] } quote = "1" diff --git a/rust-fsm-dsl/src/lib.rs b/rust-fsm-dsl/src/lib.rs index ea0400e..70cd5a2 100644 --- a/rust-fsm-dsl/src/lib.rs +++ b/rust-fsm-dsl/src/lib.rs @@ -7,17 +7,17 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use std::{collections::BTreeSet, iter::FromIterator}; -use syn::{parse_macro_input, Attribute, Ident}; - +use syn::*; mod parser; - +mod variant; +use variant::Variant; /// The full information about a state transition. Used to unify the /// represantion of the simple and the compact forms. struct Transition<'a> { - initial_state: &'a Ident, - input_value: &'a Ident, - final_state: &'a Ident, - output: &'a Option<Ident>, + initial_state: &'a Variant, + input_value: &'a Variant, + final_state: &'a Variant, + output: &'a Option<Variant>, } fn attrs_to_token_stream(attrs: Vec<Attribute>) -> proc_macro2::TokenStream { @@ -80,15 +80,16 @@ pub fn state_machine(tokens: TokenStream) -> TokenStream { "/// {initial_state} --> {final_state}: {input_value}" )); + let input_ = input_value.match_on(); transition_cases.push(quote! { - (Self::State::#initial_state, Self::Input::#input_value) => { + (Self::State::#initial_state, Self::Input::#input_) => { Some(Self::State::#final_state) } }); if let Some(output_value) = output { output_cases.push(quote! { - (Self::State::#initial_state, Self::Input::#input_value) => { + (Self::State::#initial_state, Self::Input::#input_) => { Some(Self::Output::#output_value) } }); @@ -191,14 +192,14 @@ pub fn state_machine(tokens: TokenStream) -> TokenStream { type Output = #output_type; const INITIAL_STATE: Self::State = Self::State::#initial_state_name; - fn transition(state: &Self::State, input: &Self::Input) -> Option<Self::State> { + fn transition(state: Self::State, input: Self::Input) -> Option<Self::State> { match (state, input) { #(#transition_cases)* _ => None, } } - fn output(state: &Self::State, input: &Self::Input) -> Option<Self::Output> { + fn output(state: Self::State, input: Self::Input) -> Option<Self::Output> { match (state, input) { #(#output_cases)* _ => None, diff --git a/rust-fsm-dsl/src/parser.rs b/rust-fsm-dsl/src/parser.rs index d4a9c2e..b0fe02e 100644 --- a/rust-fsm-dsl/src/parser.rs +++ b/rust-fsm-dsl/src/parser.rs @@ -1,12 +1,11 @@ +use super::variant::Variant; use syn::{ - braced, bracketed, parenthesized, parse::{Error, Parse, ParseStream, Result}, - token::{Bracket, Paren}, - Attribute, Ident, Path, Token, Visibility, + token::Bracket, + *, }; - /// The output of a state transition -pub struct Output(Option<Ident>); +pub struct Output(Option<Variant>); impl Parse for Output { fn parse(input: ParseStream) -> Result<Self> { @@ -20,7 +19,7 @@ impl Parse for Output { } } -impl From<Output> for Option<Ident> { +impl From<Output> for Option<Variant> { fn from(output: Output) -> Self { output.0 } @@ -29,9 +28,9 @@ impl From<Output> for Option<Ident> { /// Represents a part of state transition without the initial state. The `Parse` /// trait is implemented for the compact form. pub struct TransitionEntry { - pub input_value: Ident, - pub final_state: Ident, - pub output: Option<Ident>, + pub input_value: Variant, + pub final_state: Variant, + pub output: Option<Variant>, } impl Parse for TransitionEntry { @@ -50,19 +49,18 @@ impl Parse for TransitionEntry { /// Parses the transition in any of the possible formats. pub struct TransitionDef { - pub initial_state: Ident, + pub initial_state: Variant, pub transitions: Vec<TransitionEntry>, } impl Parse for TransitionDef { fn parse(input: ParseStream) -> Result<Self> { let initial_state = input.parse()?; + input.parse::<Token![=>]>()?; // Parse the transition in the simple format - // InitialState(Input) => ResultState [Output] - let transitions = if input.lookahead1().peek(Paren) { - let input_content; - parenthesized!(input_content in input); - let input_value = input_content.parse()?; + // InitialState => Input => ResultState + let transitions = if !input.lookahead1().peek(token::Brace) { + let input_value = input.parse()?; input.parse::<Token![=>]>()?; let final_state = input.parse()?; let output = input.parse::<Output>()?.into(); @@ -78,7 +76,6 @@ impl Parse for TransitionDef { // Input1 => State1, // Input2 => State2 [Output] // } - input.parse::<Token![=>]>()?; let entries_content; braced!(entries_content in input); @@ -120,7 +117,7 @@ pub struct StateMachineDef { /// The visibility modifier (applies to all generated items) pub visibility: Visibility, pub name: Ident, - pub initial_state: Ident, + pub initial_state: Variant, pub transitions: Vec<TransitionDef>, pub attributes: Vec<Attribute>, pub input_type: Option<Path>, diff --git a/rust-fsm-dsl/src/variant.rs b/rust-fsm-dsl/src/variant.rs new file mode 100644 index 0000000..417476b --- /dev/null +++ b/rust-fsm-dsl/src/variant.rs @@ -0,0 +1,73 @@ +use std::fmt::Display; + +use quote::ToTokens; +use syn::{parse::Parse, *}; +/// Variant with no discriminator +#[derive(Hash, Debug, PartialEq, Eq)] +pub struct Variant { + // attrs: Vec<Attribute>, + ident: Ident, + field: Option<(Type, Expr)>, +} + +impl Parse for Variant { + fn parse(input: parse::ParseStream) -> Result<Self> { + // let attrs = input.call(Attribute::parse_outer)?; + // let _visibility: Visibility = input.parse()?; + let ident: Ident = input.parse()?; + let field = if input.peek(token::Paren) { + let inp; + parenthesized!(inp in input); + let t = inp.parse()?; + inp.parse::<Token![=>]>()?; + Some((t, inp.parse()?)) + } else { + None + }; + Ok(Variant { + // attrs, + ident, + field, + }) + } +} + +impl PartialOrd for Variant { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + self.ident.partial_cmp(&other.ident) + } +} +impl Ord for Variant { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.ident.cmp(&other.ident) + } +} + +impl ToTokens for Variant { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.ident.to_tokens(tokens); + if let Some((t, _)) = &self.field { + tokens.extend(quote::quote! { (#t) }) + } + } +} + +impl Variant { + pub fn match_on(&self) -> proc_macro2::TokenStream { + if let Self { + ident, + field: Some((_, p)), + } = self + { + quote::quote! { #ident(#p) } + } else { + self.ident.to_token_stream() + } + } +} + +impl Display for Variant { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.match_on()) + } +} diff --git a/rust-fsm/src/lib.rs b/rust-fsm/src/lib.rs index 3324e94..5a44016 100644 --- a/rust-fsm/src/lib.rs +++ b/rust-fsm/src/lib.rs @@ -101,10 +101,10 @@ This state machine can be used as follows: // Initialize the state machine. The state is `Closed` now. let mut machine = circuit_breaker::StateMachine::new(); // Consume the `Successful` input. No state transition is performed. -let _ = machine.consume(&circuit_breaker::Input::Successful); +let _ = machine.consume(circuit_breaker::Input::Successful); // Consume the `Unsuccesful` input. The machine is moved to the `Open` // state. The output is `SetupTimer`. -let output = machine.consume(&circuit_breaker::Input::Unsuccessful).unwrap(); +let output = machine.consume(circuit_breaker::Input::Unsuccessful).unwrap(); // Check the output if let Some(circuit_breaker::Output::SetupTimer) = output { // Set up the timer... @@ -229,7 +229,7 @@ You can see an example of the Circuit Breaker state machine in the #![cfg_attr(not(feature = "std"), no_std)] -use core::fmt; +use core::{fmt, ops::Deref}; #[cfg(feature = "std")] use std::error::Error; @@ -257,11 +257,11 @@ pub trait StateMachineImpl { /// The transition fuction that outputs a new state based on the current /// state and the provided input. Outputs `None` when there is no transition /// for a given combination of the input and the state. - fn transition(state: &Self::State, input: &Self::Input) -> Option<Self::State>; + fn transition(state: Self::State, input: Self::Input) -> Option<Self::State>; /// The output function that outputs some value from the output alphabet /// based on the current state and the given input. Outputs `None` when /// there is no output for a given combination of the input and the state. - fn output(state: &Self::State, input: &Self::Input) -> Option<Self::Output>; + fn output(state: Self::State, input: Self::Input) -> Option<Self::Output>; } /// A convenience wrapper around the `StateMachine` trait that encapsulates the @@ -292,20 +292,27 @@ where Self { state } } + pub fn transition(self, input: T::Input) -> Result<T::State, TransitionImpossibleError> { + T::transition(self.state, input).ok_or(TransitionImpossibleError) + } + pub fn output(self, input: T::Input) -> Result<T::Output, TransitionImpossibleError> { + T::output(self.state, input).ok_or(TransitionImpossibleError) + } + /// Consumes the provided input, gives an output and performs a state /// transition. If a state transition with the current state and the /// provided input is not allowed, returns an error. pub fn consume( &mut self, - input: &T::Input, - ) -> Result<Option<T::Output>, TransitionImpossibleError> { - if let Some(state) = T::transition(&self.state, input) { - let output = T::output(&self.state, input); - self.state = state; - Ok(output) - } else { - Err(TransitionImpossibleError) - } + input: T::Input, + ) -> Result<Option<T::Output>, TransitionImpossibleError> + where + T::Input: Clone, + T::State: Clone, + { + T::transition(self.state.clone(), input.clone()) + .ok_or(TransitionImpossibleError) + .map(|state| T::output(std::mem::replace(&mut self.state, state), input)) } /// Returns the current state. @@ -314,6 +321,14 @@ where } } +impl<T: StateMachineImpl> Deref for StateMachine<T> { + type Target = T::State; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + impl<T> Default for StateMachine<T> where T: StateMachineImpl, diff --git a/rust-fsm/tests/circuit_breaker.rs b/rust-fsm/tests/circuit_breaker.rs index aa25a1e..d8e300d 100644 --- a/rust-fsm/tests/circuit_breaker.rs +++ b/rust-fsm/tests/circuit_breaker.rs @@ -5,7 +5,7 @@ use rust_fsm::*; use std::sync::{Arc, Mutex}; use std::time::Duration; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] enum CircuitBreakerInput { Successful, Unsuccessful, @@ -31,7 +31,7 @@ impl StateMachineImpl for CircuitBreakerMachine { type Output = CircuitBreakerOutputSetTimer; const INITIAL_STATE: Self::State = CircuitBreakerState::Closed; - fn transition(state: &Self::State, input: &Self::Input) -> Option<Self::State> { + fn transition(state: Self::State, input: Self::Input) -> Option<Self::State> { match (state, input) { (CircuitBreakerState::Closed, CircuitBreakerInput::Unsuccessful) => { Some(CircuitBreakerState::Open) @@ -49,7 +49,7 @@ impl StateMachineImpl for CircuitBreakerMachine { } } - fn output(state: &Self::State, input: &Self::Input) -> Option<Self::Output> { + fn output(state: Self::State, input: Self::Input) -> Option<Self::Output> { match (state, input) { (CircuitBreakerState::Closed, CircuitBreakerInput::Unsuccessful) => { Some(CircuitBreakerOutputSetTimer) @@ -70,7 +70,7 @@ fn circuit_breaker() { let machine = Arc::new(Mutex::new(machine)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&CircuitBreakerInput::Unsuccessful).unwrap(); + let res = lock.consume(CircuitBreakerInput::Unsuccessful).unwrap(); assert_eq!(res, Some(CircuitBreakerOutputSetTimer)); assert_eq!(lock.state(), &CircuitBreakerState::Open); } @@ -80,7 +80,7 @@ fn circuit_breaker() { std::thread::spawn(move || { std::thread::sleep(Duration::new(5, 0)); let mut lock = machine_wait.lock().unwrap(); - let res = lock.consume(&CircuitBreakerInput::TimerTriggered).unwrap(); + let res = lock.consume(CircuitBreakerInput::TimerTriggered).unwrap(); assert_eq!(res, None); assert_eq!(lock.state(), &CircuitBreakerState::HalfOpen); }); @@ -90,7 +90,7 @@ fn circuit_breaker() { std::thread::spawn(move || { std::thread::sleep(Duration::new(1, 0)); let mut lock = machine_try.lock().unwrap(); - let res = lock.consume(&CircuitBreakerInput::Successful); + let res = lock.consume(CircuitBreakerInput::Successful); assert!(matches!(res, Err(TransitionImpossibleError))); assert_eq!(lock.state(), &CircuitBreakerState::Open); }); @@ -99,7 +99,7 @@ fn circuit_breaker() { std::thread::sleep(Duration::new(7, 0)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&CircuitBreakerInput::Successful).unwrap(); + let res = lock.consume(CircuitBreakerInput::Successful).unwrap(); assert_eq!(res, None); assert_eq!(lock.state(), &CircuitBreakerState::Closed); } diff --git a/rust-fsm/tests/circuit_breaker_dsl.rs b/rust-fsm/tests/circuit_breaker_dsl.rs index 7645fdb..c7315c1 100644 --- a/rust-fsm/tests/circuit_breaker_dsl.rs +++ b/rust-fsm/tests/circuit_breaker_dsl.rs @@ -6,10 +6,11 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; state_machine! { + #[derive(Clone, Copy)] circuit_breaker(Closed) - Closed(Unsuccessful) => Open [SetupTimer], - Open(TimerTriggered) => HalfOpen, + Closed => Unsuccessful => Open [SetupTimer], + Open => TimerTriggered => HalfOpen, HalfOpen => { Successful => Closed, Unsuccessful => Open [SetupTimer] @@ -24,7 +25,7 @@ fn circit_breaker_dsl() { let machine = Arc::new(Mutex::new(machine)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&circuit_breaker::Input::Unsuccessful).unwrap(); + let res = lock.consume(circuit_breaker::Input::Unsuccessful).unwrap(); assert!(matches!(res, Some(circuit_breaker::Output::SetupTimer))); assert!(matches!(lock.state(), &circuit_breaker::State::Open)); } @@ -35,7 +36,7 @@ fn circit_breaker_dsl() { std::thread::sleep(Duration::new(5, 0)); let mut lock = machine_wait.lock().unwrap(); let res = lock - .consume(&circuit_breaker::Input::TimerTriggered) + .consume(circuit_breaker::Input::TimerTriggered) .unwrap(); assert!(matches!(res, None)); assert!(matches!(lock.state(), &circuit_breaker::State::HalfOpen)); @@ -46,7 +47,7 @@ fn circit_breaker_dsl() { std::thread::spawn(move || { std::thread::sleep(Duration::new(1, 0)); let mut lock = machine_try.lock().unwrap(); - let res = lock.consume(&circuit_breaker::Input::Successful); + let res = lock.consume(circuit_breaker::Input::Successful); assert!(matches!(res, Err(TransitionImpossibleError))); assert!(matches!(lock.state(), &circuit_breaker::State::Open)); }); @@ -55,7 +56,7 @@ fn circit_breaker_dsl() { std::thread::sleep(Duration::new(7, 0)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&circuit_breaker::Input::Successful).unwrap(); + let res = lock.consume(circuit_breaker::Input::Successful).unwrap(); assert!(matches!(res, None)); assert!(matches!(lock.state(), &circuit_breaker::State::Closed)); } diff --git a/rust-fsm/tests/circuit_breaker_dsl_custom_types.rs b/rust-fsm/tests/circuit_breaker_dsl_custom_types.rs index ca878b3..c17617b 100644 --- a/rust-fsm/tests/circuit_breaker_dsl_custom_types.rs +++ b/rust-fsm/tests/circuit_breaker_dsl_custom_types.rs @@ -5,18 +5,21 @@ use rust_fsm::*; use std::sync::{Arc, Mutex}; use std::time::Duration; +#[derive(Clone, Copy)] pub enum Input { Successful, Unsuccessful, TimerTriggered, } +#[derive(Clone, Copy)] pub enum State { Closed, HalfOpen, Open, } +#[derive(Clone, Copy)] pub enum Output { SetupTimer, } @@ -25,8 +28,8 @@ state_machine! { #[state_machine(input(crate::Input), state(crate::State), output(crate::Output))] circuit_breaker(Closed) - Closed(Unsuccessful) => Open [SetupTimer], - Open(TimerTriggered) => HalfOpen, + Closed => Unsuccessful => Open [SetupTimer], + Open => TimerTriggered => HalfOpen, HalfOpen => { Successful => Closed, Unsuccessful => Open [SetupTimer] @@ -41,7 +44,7 @@ fn circit_breaker_dsl() { let machine = Arc::new(Mutex::new(machine)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&Input::Unsuccessful).unwrap(); + let res = lock.consume(Input::Unsuccessful).unwrap(); assert!(matches!(res, Some(Output::SetupTimer))); assert!(matches!(lock.state(), &State::Open)); } @@ -51,7 +54,7 @@ fn circit_breaker_dsl() { std::thread::spawn(move || { std::thread::sleep(Duration::new(5, 0)); let mut lock = machine_wait.lock().unwrap(); - let res = lock.consume(&Input::TimerTriggered).unwrap(); + let res = lock.consume(Input::TimerTriggered).unwrap(); assert!(matches!(res, None)); assert!(matches!(lock.state(), &State::HalfOpen)); }); @@ -61,7 +64,7 @@ fn circit_breaker_dsl() { std::thread::spawn(move || { std::thread::sleep(Duration::new(1, 0)); let mut lock = machine_try.lock().unwrap(); - let res = lock.consume(&Input::Successful); + let res = lock.consume(Input::Successful); assert!(matches!(res, Err(TransitionImpossibleError))); assert!(matches!(lock.state(), &State::Open)); }); @@ -70,7 +73,7 @@ fn circit_breaker_dsl() { std::thread::sleep(Duration::new(7, 0)); { let mut lock = machine.lock().unwrap(); - let res = lock.consume(&Input::Successful).unwrap(); + let res = lock.consume(Input::Successful).unwrap(); assert!(matches!(res, None)); assert!(matches!(lock.state(), &State::Closed)); } diff --git a/rust-fsm/tests/simple.rs b/rust-fsm/tests/simple.rs index 8754b92..d7bdf7c 100644 --- a/rust-fsm/tests/simple.rs +++ b/rust-fsm/tests/simple.rs @@ -1,23 +1,28 @@ use rust_fsm::*; state_machine! { - #[derive(Debug)] + #[derive(Debug, Clone, Copy)] #[repr(C)] door(Open) - Open(Key) => Closed, - Closed(Key) => Open, - Open(Break) => Broken, - Closed(Break) => Broken, + Open => Key => Closed, + Closed => Key => Open, + Open => Break => Broken, + Closed => Break => Broken, + Open => Thing(u32 => 5) => Fine, + // Open(u32) => Key => Open [output], + // Open(u32) => { + // Key => Open, + // } } #[test] fn simple() { let mut machine = door::StateMachine::new(); - machine.consume(&door::Input::Key).unwrap(); + machine.consume(door::Input::Key).unwrap(); println!("{:?}", machine.state()); - machine.consume(&door::Input::Key).unwrap(); + machine.consume(door::Input::Key).unwrap(); println!("{:?}", machine.state()); - machine.consume(&door::Input::Break).unwrap(); + machine.consume(door::Input::Break).unwrap(); println!("{:?}", machine.state()); } |