aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go
blob: dec83f47b9ab51e66082cec44e3a6b9685543a4a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package stmtcache

import (
	"container/list"

	"github.com/jackc/pgx/v5/pgconn"
)

// LRUCache implements Cache with a Least Recently Used (LRU) cache.
//
// Fields:
//   - cap is the capacity given to NewLRUCache; Put evicts the oldest entry when the list length equals it.
//   - m indexes the cached entries by their SQL text and points at the matching element of l.
//   - l holds the cached statement descriptions ordered by recency: Put pushes new entries to the front, Get
//     moves a hit to the front, and eviction removes from the back.
//   - invalidStmts accumulates descriptions dropped by Invalidate, InvalidateAll, or eviction until
//     RemoveInvalidated clears it.
//
// None of the methods in this file take a lock. NOTE(review): callers presumably serialize access — confirm.
type LRUCache struct {
	cap          int
	m            map[string]*list.Element
	l            *list.List
	invalidStmts []*pgconn.StatementDescription
}

// NewLRUCache returns an empty LRUCache that holds at most cap statement descriptions.
//
// cap is stored without validation. NOTE(review): Put evicts only when the current length equals cap, so a cap
// of 0 would make the first Put evict from an empty list; callers presumably always pass a positive value —
// confirm.
func NewLRUCache(cap int) *LRUCache {
	c := new(LRUCache)
	c.cap = cap
	c.m = make(map[string]*list.Element)
	c.l = list.New()
	return c
}

// Get looks up the statement description cached under key, which is the statement's SQL text. A hit is marked
// as the most recently used entry. Get returns nil when key is not in the cache.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
	el, found := c.m[key]
	if !found {
		return nil
	}

	// A lookup counts as a use, so promote the entry before handing it back.
	c.l.MoveToFront(el)
	sd := el.Value.(*pgconn.StatementDescription)
	return sd
}

// Put adds sd to the cache as the most recently used entry, evicting the least recently used entry first when
// the cache is full. Put panics if sd.SQL is "". Put is a no-op if sd.SQL is already cached, or if sd.SQL has
// been invalidated and RemoveInvalidated has not been called since.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
	sql := sd.SQL
	if sql == "" {
		panic("cannot store statement description with empty SQL")
	}

	if _, cached := c.m[sql]; cached {
		return
	}

	// A statement whose invalidation is still pending must not be re-added to the cache.
	for i := range c.invalidStmts {
		if c.invalidStmts[i].SQL == sql {
			return
		}
	}

	// Make room before inserting when the cache is at capacity.
	if c.cap == c.l.Len() {
		c.invalidateOldest()
	}

	c.m[sql] = c.l.PushFront(sd)
}

// Invalidate removes the statement description cached for sql and records it as invalidated so that it shows
// up in GetInvalidated. It does nothing if sql is not cached.
func (c *LRUCache) Invalidate(sql string) {
	el, found := c.m[sql]
	if !found {
		return
	}

	delete(c.m, sql)
	sd := el.Value.(*pgconn.StatementDescription)
	c.invalidStmts = append(c.invalidStmts, sd)
	c.l.Remove(el)
}

// InvalidateAll empties the cache. Every cached statement description is appended to the invalidated list,
// most recently used first, and the lookup map and recency list are replaced with fresh empty ones.
func (c *LRUCache) InvalidateAll() {
	for el := c.l.Front(); el != nil; el = el.Next() {
		sd := el.Value.(*pgconn.StatementDescription)
		c.invalidStmts = append(c.invalidStmts, sd)
	}

	c.l = list.New()
	c.m = make(map[string]*list.Element)
}

// GetInvalidated returns every statement description that has been invalidated since RemoveInvalidated was
// last called. The cache's own slice is returned rather than a copy.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
	pending := c.invalidStmts
	return pending
}

// RemoveInvalidated discards the list of invalidated statement descriptions. Callers must not make any other
// call on the cache between GetInvalidated and RemoveInvalidated; otherwise RemoveInvalidated may discard
// descriptions that the preceding GetInvalidated never returned.
func (c *LRUCache) RemoveInvalidated() {
	// Drop the whole slice; later invalidations start a new one via append.
	c.invalidStmts = nil
}

// Len reports how many statement descriptions are currently cached. Invalidated descriptions are not counted.
func (c *LRUCache) Len() int {
	n := c.l.Len()
	return n
}

// Cap reports the capacity the cache was created with, i.e. the most statement descriptions it will hold
// before Put starts evicting.
func (c *LRUCache) Cap() int {
	limit := c.cap
	return limit
}

// invalidateOldest evicts the least recently used entry (the back of the list): it is removed from the lookup
// map and the list and appended to the invalidated list. The cache must be non-empty; on an empty list Back
// returns nil and reading its Value panics.
func (c *LRUCache) invalidateOldest() {
	tail := c.l.Back()
	evicted := tail.Value.(*pgconn.StatementDescription)
	delete(c.m, evicted.SQL)
	c.l.Remove(tail)
	c.invalidStmts = append(c.invalidStmts, evicted)
}