Files
junk2jive-server/internal/router/router_test.go
rogueking d975bea218
Some checks failed
golangci-lint / lint (push) Failing after 20s
Run Go Tests / build (push) Failing after 24s
Build and Push Docker Image / Build and push image (push) Successful in 2m32s
build / Build (push) Successful in 22s
testing fixes
2025-05-06 20:56:12 -07:00

153 lines
5.1 KiB
Go

package router
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewRouter verifies that NewRouter returns a router whose chi router is
// set while its logger and rate limiter are still nil.
func TestNewRouter(t *testing.T) {
	r := NewRouter()

	// require (not assert) aborts the test here: the checks below read fields
	// of r and would panic instead of failing cleanly if r were a nil pointer.
	require.NotNil(t, r, "Router should not be nil")

	assert.NotNil(t, r.router, "Chi router should not be nil")
	assert.Nil(t, r.logger, "Logger should be nil initially")
	assert.Nil(t, r.rateLimiter, "RateLimiter should be nil initially")
}
// TestSetupRouter verifies that SetupRouter returns a router with its chi
// router, logger and rate limiter all populated.
func TestSetupRouter(t *testing.T) {
	origins := []string{"http://localhost:3000"}
	r := SetupRouter(origins)

	// require (not assert) aborts the test here: the checks below read fields
	// of r and would panic instead of failing cleanly if r were a nil pointer.
	require.NotNil(t, r, "Router should not be nil")

	assert.NotNil(t, r.router, "Chi router should not be nil")
	assert.NotNil(t, r.logger, "Logger should not be nil")
	assert.NotNil(t, r.rateLimiter, "RateLimiter should not be nil")
}
// TestGetRouter checks that the GetRouter accessor hands back the value stored
// in the router field of a freshly constructed Router.
func TestGetRouter(t *testing.T) {
	subject := NewRouter()

	assert.Equal(
		t,
		subject.router,
		subject.GetRouter(),
		"GetRouter should return the router field",
	)
}
// TestCoffeeEndpoint serves the configured router from an httptest server and
// verifies that GET /v1/api/coffee responds with 418 (I'm a teapot) and a JSON
// object whose "error" field holds the teapot message.
func TestCoffeeEndpoint(t *testing.T) {
	origins := []string{"http://localhost:8080"}
	r := SetupRouter(origins)

	// Create a test server
	ts := httptest.NewServer(r.GetRouter())
	defer ts.Close()

	// Make request to the coffee endpoint using the client that belongs to
	// the test server instead of the package-level default client.
	resp, err := ts.Client().Get(ts.URL + "/v1/api/coffee")
	require.NoError(t, err, "Error making request to coffee endpoint")
	defer func() {
		// Surface a failed Close rather than silently dropping its error.
		assert.NoError(t, resp.Body.Close(), "Error closing response body")
	}()

	// Check status code
	assert.Equal(t, http.StatusTeapot, resp.StatusCode, "Should return teapot status")

	// Check response body
	var responseBody map[string]string
	err = json.NewDecoder(resp.Body).Decode(&responseBody)
	require.NoError(t, err, "Error decoding response body")
	assert.Equal(t, "I'm A Teapot!", responseBody["error"], "Response should contain teapot message")
}
// TestCORSMiddleware sends a CORS preflight (OPTIONS) request for the coffee
// endpoint against routers configured with different allowed-origin lists and
// checks the CORS response headers.
//
// Each row states the Access-Control-Allow-Origin value it expects in
// expectedHeader; that value is asserted for every row. The
// Access-Control-Allow-Methods check only runs when the request origin is
// covered by the configured origins (exact match or "*" anywhere in the list).
func TestCORSMiddleware(t *testing.T) {
	testCases := []struct {
		name           string
		origins        []string
		requestOrigin  string
		expectedHeader string
	}{
		{
			name:           "Allowed origin",
			origins:        []string{"http://allowed-origin.com"},
			requestOrigin:  "http://allowed-origin.com",
			expectedHeader: "http://allowed-origin.com",
		},
		{
			name:           "Multiple allowed origins",
			origins:        []string{"http://origin1.com", "http://origin2.com"},
			requestOrigin:  "http://origin2.com",
			expectedHeader: "http://origin2.com",
		},
		{
			name:           "Wildcard origin",
			origins:        []string{"*"},
			requestOrigin:  "http://any-origin.com",
			expectedHeader: "*",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			r := SetupRouter(tc.origins)

			req := httptest.NewRequest(http.MethodOptions, "/v1/api/coffee", nil)
			req.Header.Set("Origin", tc.requestOrigin)
			req.Header.Set("Access-Control-Request-Method", http.MethodGet)

			rr := httptest.NewRecorder()
			r.GetRouter().ServeHTTP(rr, req)

			// Always compare against the table value so that no row can pass
			// without its origin header having been checked.
			assert.Equal(t, tc.expectedHeader, rr.Header().Get("Access-Control-Allow-Origin"),
				"CORS origin header should match expected value")

			// Searching the whole list avoids indexing origins[0] (which
			// panics on an empty list) and finds "*" in any position.
			originAllowed := contains(tc.origins, "*") || contains(tc.origins, tc.requestOrigin)
			if originAllowed {
				assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), http.MethodGet,
					"CORS methods should contain GET")
			}
		})
	}
}
// TestRouterMiddlewares is an indirect check of the router set-up: it does not
// exercise any middleware itself, it only confirms that the expected API
// routes show up when the configured chi router is walked.
func TestRouterMiddlewares(t *testing.T) {
	configured := SetupRouter([]string{"http://localhost:3000"})
	registered := getAllRoutes(configured.GetRouter())

	expectations := []struct {
		route   string
		message string
	}{
		{route: "/v1/api/coffee", message: "Coffee endpoint should be registered"},
		{route: "/v1/api/text", message: "Text endpoint should be registered"},
		{route: "/v1/api/visual", message: "Visual endpoint should be registered"},
	}

	for _, want := range expectations {
		assert.Contains(t, registered, want.route, want.message)
	}
}
// Helper functions

// contains reports whether item is equal to at least one element of slice.
// A nil or empty slice never contains anything.
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// getAllRoutes gathers route patterns from a chi router by walking it with
// chi.Walk and recording the route argument of every callback invocation, in
// the order the callbacks arrive. The result is nil when the callback is never
// invoked.
func getAllRoutes(router *chi.Mux) []string {
	var patterns []string

	// The callback itself never returns an error, so the value returned by
	// Walk is deliberately discarded.
	_ = chi.Walk(router, func(_ string, pattern string, _ http.Handler, _ ...func(http.Handler) http.Handler) error {
		patterns = append(patterns, pattern)
		return nil
	})

	return patterns
}