aboutsummaryrefslogblamecommitdiff
path: root/cmd/moncheck/main.go
blob: 62a15517082982988cf1cf339f9e234c2c69ae2a (plain) (tree)
1
2
3
4
5
6
7
8
9





                      
                             

                       
             

                   
            
                 

                 












                                                                                      



                                                 
                                                 
         

                    








                                                            
                                                                  



                                                             



                                                                                 



                                                                    



                                                              





                                                                         




                                                                 
                                             
                                                                





                              
                                                                                          





                                                                                       
                                                                                





                                       







                                                                                 
                                      

                                   
                 

                                                                        
                                     
                                                
                                



                                                                    
                 
                                                                                 



                                                                           

                                                                        
                                
                                 
                                                                                 

                                                                           
                                 

                                      
                                                                                 

                                                                                  
                                         
                                
                                                           


                                
                                 
                 
                                      
 
                                                             

                                                                                                                                                              
                                                                                         

                                     
                 
                                                                                                                              






                                                                                                  


                                                                                                      



                           
























































                                                                                               














                                                           
package main

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/lib/pq"
)

// configPath is the -config command line flag: the location of the JSON
// config file that main reads at startup.
var configPath = flag.String("config", "moncheck.conf", "path to the config file")

// Config holds the settings decoded from the JSON config file.
type Config struct {
	// DB is handed to sql.Open as the postgres data source name.
	DB string `json:"db"`
	// Timeout is a time.ParseDuration string limiting the run time of one check.
	Timeout string `json:"timeout"`
	// Wait is a time.ParseDuration string; workers sleep this long when no
	// check is due.
	Wait string `json:"wait"`
	// Path lists the directories that are joined with ":" into PATH.
	Path []string `json:"path"`
	// Workers is the number of check goroutines main starts.
	Workers int `json:"workers"`
}

// States is the list of recorded states of a check. ToOK treats index 0 as
// the most recent entry.
type States []int

// main loads the JSON config named by the -config flag, replaces PATH with
// the configured directories, opens the postgres connection pool and starts
// config.Workers check goroutines. It terminates the process with a fatal
// log message on any setup error and once every worker has returned.
func main() {
	flag.Parse()

	raw, err := ioutil.ReadFile(*configPath)
	if err != nil {
		log.Fatalf("could not read config: %s", err)
	}
	// Defaults; values present in the config file override them.
	config := Config{Timeout: "30s", Wait: "30s", Workers: 25}
	if err := json.Unmarshal(raw, &config); err != nil {
		log.Fatalf("could not parse config: %s", err)
	}
	if config.Workers < 1 {
		// Without workers the process would idle forever doing nothing.
		log.Fatalf("invalid number of workers: %d", config.Workers)
	}

	if err := os.Setenv("PATH", strings.Join(config.Path, ":")); err != nil {
		log.Fatalf("could not set PATH: %s", err)
	}

	waitDuration, err := time.ParseDuration(config.Wait)
	if err != nil {
		log.Fatalf("could not parse wait duration: %s", err)
	}
	timeout, err := time.ParseDuration(config.Timeout)
	if err != nil {
		log.Fatalf("could not parse timeout: %s", err)
	}

	db, err := sql.Open("postgres", config.DB)
	if err != nil {
		log.Fatalf("could not open database connection: %s", err)
	}

	hostname, err := os.Hostname()
	if err != nil {
		log.Fatalf("could not resolve hostname: %s", err)
	}

	// Track every worker so that the process stops instead of blocking
	// forever once all of them have given up.
	var wg sync.WaitGroup
	wg.Add(config.Workers)
	for i := 0; i < config.Workers; i++ {
		go func(worker int) {
			defer wg.Done()
			check(worker, db, waitDuration, timeout, hostname)
		}(i)
	}
	wg.Wait()
	log.Fatal("all check workers exited")
}

// check is the loop run by every worker goroutine. Each iteration opens a
// transaction, selects the enabled check with the oldest due next_time
// (using "for update skip locked"), runs its command line, stores the
// resulting state and output on the row, creates the notifications and
// commits.
//
// thread is only used to prefix log messages. waitDuration is slept when no
// check is due or no transaction could be started. timeout bounds the run
// time of a single command. hostname is stored as check_host on the
// notifications.
//
// The function returns (ending the worker) when a selected row cannot be
// scanned.
func check(thread int, db *sql.DB, waitDuration, timeout time.Duration, hostname string) {
	for {
		tx, err := db.Begin()
		if err != nil {
			log.Printf("[%d] could not start transaction: %s", thread, err)
			// Back off instead of retrying in a hot loop while the
			// database is unavailable.
			time.Sleep(waitDuration)
			continue
		}

		var (
			id      int64
			cmdLine []string
			states  States
			mapID   int
		)
		// QueryRow itself reports no error; any failure surfaces in Scan.
		err = tx.QueryRow(`select check_id, cmdLine, states, mapping_id
		from active_checks
		where next_time < now()
			and enabled
		order by next_time
		for update skip locked
		limit 1;`).Scan(&id, pq.Array(&cmdLine), &states, &mapID)
		if err == sql.ErrNoRows {
			// Nothing is due; the rollback result is irrelevant here.
			_ = tx.Rollback()
			time.Sleep(waitDuration)
			continue
		}
		if err != nil {
			log.Printf("[%d] could not scan values: %s", thread, err)
			_ = tx.Rollback()
			return
		}

		state, msg := runCheckCommand(id, cmdLine, timeout)

		// NOTE(review): states still holds the history from before this run,
		// so ToOK is evaluated without the state just measured - confirm
		// that this is intended.
		if _, err := tx.Exec(`update active_checks ac
		set next_time = now() + intval, states = ARRAY[$2::int] || states[1:4], msg = $3, acknowledged = case when $4 then false else acknowledged end
where check_id = $1`, id, &state, &msg, states.ToOK()); err != nil {
			log.Printf("[%d] could not update row '%d': %s", thread, id, err)
			_ = tx.Rollback()
			continue
		}
		if _, err := tx.Exec(`insert into notifications(check_id, states, output, mapping_id, notifier_id, check_host)
			select $1, array_agg(ml.target), $2, $3, cn.notifier_id, $4
			from active_checks ac
			cross join lateral unnest(ac.states) s
			join checks_notify cn on ac.check_id = cn.check_id
			join mapping_level ml on ac.mapping_id = ml.mapping_id and s.s = ml.source
			where ac.check_id = $1
			group by cn.notifier_id;`, &id, &msg, &mapID, &hostname); err != nil {
			log.Printf("[%d] could not create notification for '%d': %s", thread, id, err)
			_ = tx.Rollback()
			continue
		}
		if err := tx.Commit(); err != nil {
			log.Printf("[%d] could not commit result for '%d': %s", thread, id, err)
		}
	}
}

// runCheckCommand executes cmdLine (program followed by its arguments) with
// the given timeout and returns the state to record together with the
// combined stdout and stderr of the command.
//
// The state is 0 when the command succeeds, 2 when the timeout is exceeded,
// 3 when the command line is empty or the process could not be started, and
// otherwise the exit status of the process (2 if it cannot be determined).
// id is the check id and is only used in log messages.
func runCheckCommand(id int64, cmdLine []string, timeout time.Duration) (int, string) {
	if len(cmdLine) == 0 {
		// Indexing an empty command line would panic and take the whole
		// process down, so report it like a command that cannot be started.
		log.Printf("[%d] error running check: empty command line", id)
		return 3, "check has no command line"
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	// Release the timer on every path, including a failed start.
	defer cancel()

	output := &bytes.Buffer{}
	cmd := exec.CommandContext(ctx, cmdLine[0], cmdLine[1:]...)
	cmd.Stdout = output
	cmd.Stderr = output

	state := 0
	err := cmd.Run()
	switch {
	case err == nil:
		// state stays 0
	case ctx.Err() == context.DeadlineExceeded:
		state = 2
		fmt.Fprintf(output, "check took longer than %s", timeout)
	case cmd.ProcessState == nil:
		// The process never ran, so there is no exit status to report.
		log.Printf("[%d] error running check: %s", id, err)
		state = 3
	default:
		status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
		if !ok {
			log.Printf("[%d]error running check: %s", id, err)
			state = 2
		} else {
			state = status.ExitStatus()
		}
	}
	return state, output.String()
}

// Value renders the states as a brace-delimited, comma-separated list such
// as "{0,2,1}"; an empty list is rendered as "{}". The returned error is
// always nil.
func (s *States) Value() (driver.Value, error) {
	states := *s
	if len(states) == 0 {
		return "{}", nil
	}
	parts := make([]string, len(states))
	for i, state := range states {
		parts[i] = strconv.Itoa(state)
	}
	return "{" + strings.Join(parts, ",") + "}", nil
}

// Scan fills s from a database value in the brace-delimited form produced by
// Value, e.g. "{0,2,1}".
//
// Accepted sources are []byte and string. A nil source (SQL NULL) resets s
// to nil and an empty list ("{}") yields an empty States, both without an
// error. Any other source type, or an element that is not an integer,
// returns an error and leaves s untouched.
func (s *States) Scan(src interface{}) error {
	var raw string
	switch v := src.(type) {
	case nil:
		*s = nil
		return nil
	case []byte:
		raw = string(v)
	case string:
		raw = v
	default:
		return fmt.Errorf("could not convert %T to states", src)
	}

	raw = strings.Trim(raw, "{}")
	if raw == "" {
		// Splitting an empty string would yield one empty element that
		// cannot be parsed, so handle the empty list explicitly.
		*s = States{}
		return nil
	}

	fields := strings.Split(raw, ",")
	result := make(States, len(fields))
	for i, field := range fields {
		val, err := strconv.Atoi(field)
		if err != nil {
			return fmt.Errorf("could not parse element %s: %s", field, err)
		}
		result[i] = val
	}
	*s = result
	return nil
}

// Add puts state in front of the existing entries. At most the first five
// existing entries are kept, so the list never grows beyond six elements.
func (s *States) Add(state int) {
	old := *s
	keep := len(old)
	if keep > 5 {
		keep = 5
	}
	next := make([]int, 0, keep+1)
	next = append(next, state)
	next = append(next, old[:keep]...)
	*s = next
}

// ToOK reports whether the list shows a change to state 0: the first entry
// must be 0 and the second entry greater than 0. A list with a single entry
// counts when that entry is 0; an empty list never does.
func (s *States) ToOK() bool {
	switch history := *s; len(history) {
	case 0:
		return false
	case 1:
		return history[0] == 0
	default:
		return history[0] == 0 && history[1] > 0
	}
}