package ensure import ( "fmt" "sync" ) type ( // Runner takes a list of ensurers and enforces their state. // The runner is not transactional. Runner struct { // Is is checked before enforcing all states. // The states will not be enforced when the function is returning true. Is Enforced // Parallel defines if the states are ensured in parallel. There is no // dependency management at work. The requirements must already be met // beforehand. // When parallel is set to true the processing will not be halted on the // first error. Instead all errors will be collected and returned. Parallel bool // States is the list of states to ensure when Ensure() is called. States []Ensurer // When parallel mode is used, the number of threads to spawn can be set. // When not set, every state will spawn a thread. Workers int } ) // Ensure will call Ensure() on every state. // When Parallel is true, all states will be ensured in parallel. The number of // threads that will be spawned can be controlled by the Workers attribute. // In case of an error the processing will be aborted. When Parallel is true // all errors are collected and returned as one and the processing will not be // halted. func (r *Runner) Ensure() error { if r.Is != nil && r.Is() { return nil } if r.Parallel { return r.ensureParallel() } return r.ensureSequence() } func (r *Runner) ensureSequence() error { for i, state := range r.States { if err := state.Ensure(); err != nil { return fmt.Errorf("could not ensure resource with index %d: %w", i, err) } } return nil } func (r *Runner) ensureParallel() error { if len(r.States) == 0 { return nil } threads := r.Workers if threads == 0 { threads = len(r.States) } work := make(chan Ensurer, threads) results := make(chan error, threads) var err error wg := &sync.WaitGroup{} wg.Add(threads) go func() { for _, state := range r.States { work <- state } }() go func() { c := 0 for result := range results { c++ if result != nil { err = fmt.Errorf("%s\n%s", err, result) } if c == len(r.States) { close(work) close(results) } } }() for i := 0; i < threads; i++ { go func() { for state := range work { err := state.Ensure() results <- err } wg.Done() }() } wg.Wait() return err }