Files
junk2jive-server/internal/limiter/limiter_test.go
rogueking 84e150cb11
Some checks failed
Build and Push Docker Image / Build and push image (push) Failing after 43s
Run Go Tests / build (push) Failing after 0s
golangci-lint / lint (push) Failing after 2m0s
added workflows and fixed router
2025-05-06 14:49:14 -07:00

137 lines
3.2 KiB
Go

package limiter
import (
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
// TestGetIP pins the two GetIP expectations exercised here: a host:port
// RemoteAddr is expected to yield just the host, and a RemoteAddr without a
// port is expected to come back unchanged.
func TestGetIP(t *testing.T) {
	cases := []struct {
		remoteAddr string
		want       string
	}{
		{remoteAddr: "192.168.1.100:12345", want: "192.168.1.100"},
		{remoteAddr: "invalid-address", want: "invalid-address"},
	}

	for _, tc := range cases {
		got := GetIP(&http.Request{RemoteAddr: tc.remoteAddr})
		assert.Equal(t, tc.want, got)
	}
}
// TestGetLimiter checks how Get keys its limiters: asking twice for the same
// path and IP must return the same instance, while changing either the path
// or the IP must return a different instance.
func TestGetLimiter(t *testing.T) {
	const (
		basePath = "/test"
		baseIP   = "127.0.0.1"
	)
	store := New()

	base := store.Get(basePath, baseIP, 1, 1)
	require.NotNil(t, base)

	// Same path and IP: the identical pointer is expected back.
	repeat := store.Get(basePath, baseIP, 1, 1)
	assert.Same(t, base, repeat)

	// Different path, same IP.
	otherPath := store.Get("/another", baseIP, 1, 1)
	assert.NotSame(t, base, otherPath)

	// Same path, different IP.
	otherIP := store.Get(basePath, "192.168.0.1", 1, 1)
	assert.NotSame(t, base, otherIP)
}
// TestMiddleware sends the same request twice, back to back, through a
// middleware configured with rate 1 and burst 1. The first request must reach
// the wrapped handler and return 200; the second must be answered with 429
// and must not reach the wrapped handler.
func TestMiddleware(t *testing.T) {
	// reached is set to 1 by the wrapped handler whenever it runs.
	var reached int32
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		atomic.StoreInt32(&reached, 1)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK"))
	})

	store := New()
	limit := rate.Limit(1)
	maxBurst := 1
	wrapped := store.Middleware(limit, maxBurst)(next)

	request := httptest.NewRequest(http.MethodGet, "/test", nil)
	request.RemoteAddr = net.JoinHostPort("127.0.0.1", "12345")

	// send clears the flag, pushes the request through the middleware once,
	// and returns the recorded response.
	send := func() *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		atomic.StoreInt32(&reached, 0)
		wrapped.ServeHTTP(rec, request)
		return rec
	}

	// First request: allowed.
	first := send()
	assert.Equal(t, http.StatusOK, first.Code)
	assert.Equal(t, int32(1), atomic.LoadInt32(&reached))

	// Second request: rejected before the wrapped handler runs.
	second := send()
	assert.NotEqual(t, http.StatusOK, second.Code)
	assert.Equal(t, http.StatusTooManyRequests, second.Code)
	assert.Equal(t, int32(0), atomic.LoadInt32(&reached))
}
// TestMiddlewareIndependent sends one request each from two clients that
// differ in both path and IP through a middleware with rate 1 and burst 1.
// Both must be let through (200), and the wrapped handler must have run
// exactly twice.
func TestMiddlewareIndependent(t *testing.T) {
	var hits int32
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(&hits, 1)
		w.WriteHeader(http.StatusOK)
	})

	store := New()
	limit := rate.Limit(1)
	maxBurst := 1
	mw := store.Middleware(limit, maxBurst)

	clients := []struct {
		path string
		host string
		port string
	}{
		{path: "/path1", host: "127.0.0.1", port: "1111"},
		{path: "/path2", host: "127.0.0.2", port: "2222"},
	}

	for _, c := range clients {
		req := httptest.NewRequest(http.MethodGet, c.path, nil)
		req.RemoteAddr = net.JoinHostPort(c.host, c.port)

		rec := httptest.NewRecorder()
		// The middleware is applied afresh for each request.
		mw(next).ServeHTTP(rec, req)
		assert.Equal(t, http.StatusOK, rec.Code)
	}

	assert.Equal(t, int32(2), atomic.LoadInt32(&hits))
}
// TestLimiterConcurrency calls Get for a single path/IP pair from many
// goroutines at once and checks that every call was handed the same limiter
// instance.
func TestLimiterConcurrency(t *testing.T) {
	rl := New()
	path := "/concurrent"
	ip := "127.0.0.1"
	const goroutines = 50

	// The channel is buffered to the number of senders, so no goroutine can
	// block on its send. Receiving exactly `goroutines` values below therefore
	// also serves as the wait for all of them to finish.
	limiterCh := make(chan *rate.Limiter, goroutines)
	for i := 0; i < goroutines; i++ {
		go func() {
			limiterCh <- rl.Get(path, ip, 5, 10)
		}()
	}

	first := <-limiterCh
	// Guard against the degenerate case where every call returns nil: typed
	// nil pointers would otherwise compare as "the same".
	require.NotNil(t, first)

	for i := 1; i < goroutines; i++ {
		// Pointer identity is what matters here. assert.Equal compares the
		// pointed-to values, so two distinct limiters with identical settings
		// would pass it and a duplicate created under a race would go
		// unnoticed.
		assert.Same(t, first, <-limiterCh)
	}
}