261 lines
6.4 KiB
Go
261 lines
6.4 KiB
Go
package roboflow
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewService(t *testing.T) {
|
|
// Save original env and restore after test
|
|
originalAPIKey := os.Getenv("ROBOFLOW_API_KEY")
|
|
defer os.Setenv("ROBOFLOW_API_KEY", originalAPIKey)
|
|
|
|
// Set test API key
|
|
testAPIKey := "test-api-key"
|
|
os.Setenv("ROBOFLOW_API_KEY", testAPIKey)
|
|
|
|
// Create new service
|
|
service := NewService()
|
|
|
|
// Check if API key is set correctly
|
|
if service.apiKey != testAPIKey {
|
|
t.Errorf("Expected apiKey to be %s, got %s", testAPIKey, service.apiKey)
|
|
}
|
|
}
|
|
|
|
func TestGetDetectedObjects(t *testing.T) {
|
|
// Create test response
|
|
response := &RoboflowResponse{
|
|
Predictions: []struct {
|
|
Class string `json:"class"`
|
|
Confidence float64 `json:"confidence"`
|
|
}{
|
|
{Class: "bottle", Confidence: 0.95},
|
|
{Class: "cup", Confidence: 0.85},
|
|
},
|
|
}
|
|
|
|
// Get detected objects
|
|
objects := GetDetectedObjects(response)
|
|
|
|
// Check if objects are extracted correctly
|
|
expected := []string{"bottle", "cup"}
|
|
if !reflect.DeepEqual(objects, expected) {
|
|
t.Errorf("Expected objects %v, got %v", expected, objects)
|
|
}
|
|
}
|
|
|
|
// MockHTTPClient is a test double for an HTTP client. Every call to Do is
// forwarded to DoFunc, which lets a test script the response and error.
type MockHTTPClient struct {
	// DoFunc is called by Do with the request Do received; its results are
	// returned to the caller unchanged.
	DoFunc func(req *http.Request) (*http.Response, error)
}

// mockHTTPClientError is the error type for failures raised by the mock
// itself rather than by a scripted DoFunc.
type mockHTTPClientError string

// Error implements the error interface.
func (e mockHTTPClientError) Error() string { return string(e) }

// errMockDoFuncUnset is returned by Do when no DoFunc has been configured.
const errMockDoFuncUnset = mockHTTPClientError("mock http client: DoFunc is not set")

// Do forwards req to DoFunc and returns whatever DoFunc returns.
//
// A nil receiver or an unset DoFunc yields (nil, errMockDoFuncUnset) instead
// of a nil-function-call panic, so a misconfigured mock shows up as an
// ordinary test failure.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if m == nil || m.DoFunc == nil {
		return nil, errMockDoFuncUnset
	}
	return m.DoFunc(req)
}
|
|
|
|
func TestAnalyzeImage(t *testing.T) {
|
|
// Setup
|
|
originalClient := http.DefaultClient
|
|
defer func() { http.DefaultClient = originalClient }()
|
|
|
|
// Test cases
|
|
testCases := []struct {
|
|
name string
|
|
responseStatus int
|
|
responseBody string
|
|
expectedError bool
|
|
expectedResult *RoboflowResponse
|
|
}{
|
|
{
|
|
name: "Successful API response",
|
|
responseStatus: http.StatusOK,
|
|
responseBody: `{"predictions":[{"class":"bottle","confidence":0.95}]}`,
|
|
expectedError: false,
|
|
expectedResult: &RoboflowResponse{
|
|
Predictions: []struct {
|
|
Class string `json:"class"`
|
|
Confidence float64 `json:"confidence"`
|
|
}{
|
|
{Class: "bottle", Confidence: 0.95},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "API error response",
|
|
responseStatus: http.StatusInternalServerError,
|
|
responseBody: `{"error":"Internal server error"}`,
|
|
expectedError: true,
|
|
expectedResult: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Mock HTTP client
|
|
mockClient := &http.Client{
|
|
Transport: &mockTransport{
|
|
response: &http.Response{
|
|
StatusCode: tc.responseStatus,
|
|
Body: io.NopCloser(strings.NewReader(tc.responseBody)),
|
|
Header: make(http.Header),
|
|
},
|
|
},
|
|
}
|
|
http.DefaultClient = mockClient
|
|
|
|
service := &Service{apiKey: "api-key"}
|
|
result, err := service.AnalyzeImage([]byte("test-image-data"))
|
|
|
|
// Check error
|
|
if tc.expectedError && err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
if !tc.expectedError && err != nil {
|
|
t.Errorf("Expected no error but got: %v", err)
|
|
}
|
|
|
|
// Check result
|
|
if !tc.expectedError {
|
|
if result == nil {
|
|
t.Error("Expected result but got nil")
|
|
} else if len(result.Predictions) != len(tc.expectedResult.Predictions) {
|
|
t.Errorf("Expected %d predictions, got %d",
|
|
len(tc.expectedResult.Predictions),
|
|
len(result.Predictions))
|
|
} else {
|
|
for i, pred := range result.Predictions {
|
|
expected := tc.expectedResult.Predictions[i]
|
|
if pred.Class != expected.Class || pred.Confidence != expected.Confidence {
|
|
t.Errorf("Expected prediction %v, got %v", expected, pred)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockTransport is a canned http.RoundTripper for tests: it never touches the
// network and answers every request with the response it was built with.
type mockTransport struct {
	// response is returned as-is from every RoundTrip call. The same pointer
	// is handed out each time, so its Body can only be consumed once.
	response *http.Response
}

// RoundTrip ignores the incoming request and returns the stored response
// together with a nil error.
func (mt *mockTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := mt.response
	return canned, nil
}
|
|
|
|
func TestHandleImageRequest(t *testing.T) {
|
|
// Test cases
|
|
testCases := []struct {
|
|
name string
|
|
requestBody string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Valid request",
|
|
requestBody: `{"image":"dGVzdC1pbWFnZS1kYXRh"}`, // Base64 of "test-image-data"
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
requestBody: `{invalid-json}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Empty image",
|
|
requestBody: `{"image":""}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Invalid base64",
|
|
requestBody: `{"image":"not-base64"}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
// Set up a mock for AnalyzeImage
|
|
var AnalyzeImage func(imageData []byte) (*RoboflowResponse, error)
|
|
originalAnalyzeImage := AnalyzeImage
|
|
defer func() {
|
|
AnalyzeImage = originalAnalyzeImage
|
|
}()
|
|
|
|
AnalyzeImage = func(imageData []byte) (*RoboflowResponse, error) {
|
|
return &RoboflowResponse{
|
|
Predictions: []struct {
|
|
Class string `json:"class"`
|
|
Confidence float64 `json:"confidence"`
|
|
}{
|
|
{Class: "test-class", Confidence: 0.95},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create request
|
|
req, err := http.NewRequest("POST", "/analyze", bytes.NewBufferString(tc.requestBody))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Create recorder
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Call handler
|
|
handler := http.HandlerFunc(HandleImageRequest)
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
// Check status
|
|
if status := rr.Code; status != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, status)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRoboflowIntegration(t *testing.T) {
|
|
|
|
if os.Getenv("ROBOFLOW_API_KEY") == "" {
|
|
t.Skip("Skipping test; ROBOFLOW_API_KEY not set")
|
|
}
|
|
|
|
t.Run("AnalyzeImage with valid image", func(t *testing.T) {
|
|
// Load a test image (you should place a suitable test image in testdata/)
|
|
imageData, err := os.ReadFile("./test2.jpg")
|
|
require.NoError(t, err, "Failed to load test image")
|
|
|
|
service := NewService()
|
|
resp, err := service.AnalyzeImage(imageData)
|
|
|
|
require.NoError(t, err, "AnalyzeImage failed")
|
|
require.NotNil(t, resp, "Response should not be nil")
|
|
assert.NotEmpty(t, resp.Predictions, "Expected some predictions")
|
|
})
|
|
}
|
|
|
|
// TestMain sets up and tears down test environment
|
|
func TestMain(m *testing.M) {
|
|
// Setup
|
|
originalAPIKey := os.Getenv("ROBOFLOW_API_KEY")
|
|
os.Setenv("ROBOFLOW_API_KEY", "")
|
|
|
|
// Run tests
|
|
code := m.Run()
|
|
|
|
// Teardown
|
|
os.Setenv("ROBOFLOW_API_KEY", originalAPIKey)
|
|
|
|
os.Exit(code)
|
|
}
|