package openai_test

import (
	"errors"
	"net/http"
	"reflect"
	"testing"

	"github.com/sashabaranov/go-openai"
)

// TestAPIErrorUnmarshalJSON is a table-driven test for APIError.UnmarshalJSON.
//
// Each case feeds a raw JSON payload to UnmarshalJSON, checks whether an error
// was (or was not) returned as declared by hasError, and then, when checkFunc
// is set, inspects the resulting APIError. checkFunc runs regardless of
// hasError so that cases can also assert on the state left behind by a failed
// parse.
func TestAPIErrorUnmarshalJSON(t *testing.T) {
	type testCase struct {
		name      string                                     // subtest name
		response  string                                     // raw JSON handed to UnmarshalJSON
		hasError  bool                                       // whether UnmarshalJSON must return an error
		checkFunc func(t *testing.T, apiErr openai.APIError) // optional assertions on the parsed value
	}
	testCases := []testCase{
		// testcase for message field
		{
			name:     "parse succeeds when the message is string",
			response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorMessage(t, apiErr, "foo")
			},
		},
		{
			name:     "parse succeeds when the message is array with single item",
			response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorMessage(t, apiErr, "foo")
			},
		},
		{
			name:     "parse succeeds when the message is array with multiple items",
			response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorMessage(t, apiErr, "foo, bar, baz")
			},
		},
		{
			name:     "parse succeeds when the message is empty array",
			response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorMessage(t, apiErr, "")
			},
		},
		{
			name:     "parse succeeds when the message is null",
			response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorMessage(t, apiErr, "")
			},
		},
		// testcase for innererror field (Azure OpenAI)
		{
			name: "parse succeeds when the innerError exists (Azure Openai)",
			response: `{
						"message": "test message",
						"type": null,
						"param": "prompt",
						"code": "content_filter",
						"status": 400,
						"innererror": {
							"code": "ResponsibleAIPolicyViolation",
							"content_filter_result": {
								"hate": {
									"filtered": false,
									"severity": "safe"
								},
								"self_harm": {
									"filtered": false,
									"severity": "safe"
								},
								"sexual": {
									"filtered": true,
									"severity": "medium"
								},
								"violence": {
									"filtered": false,
									"severity": "safe"
								}
							}
						}
					}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorInnerError(t, apiErr, &openai.InnerError{
					Code: "ResponsibleAIPolicyViolation",
					ContentFilterResults: openai.ContentFilterResults{
						Hate: openai.Hate{
							Filtered: false,
							Severity: "safe",
						},
						SelfHarm: openai.SelfHarm{
							Filtered: false,
							Severity: "safe",
						},
						Sexual: openai.Sexual{
							Filtered: true,
							Severity: "medium",
						},
						Violence: openai.Violence{
							Filtered: false,
							Severity: "safe",
						},
					},
				})
			},
		},
		{
			name:     "parse succeeds when the innerError is empty (Azure Openai)",
			response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
			},
		},
		{
			// The parse is expected to fail; the check still runs afterwards and
			// asserts that InnerError was left as an empty struct.
			name:     "parse failed when the innerError is not InnerError struct (Azure Openai)",
			response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`,
			hasError: true,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
			},
		},
		{
			name:     "parse failed when the message is object",
			response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`,
			hasError: true,
		},
		{
			name:     "parse failed when the message is int",
			response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`,
			hasError: true,
		},
		{
			name:     "parse failed when the message is float",
			response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`,
			hasError: true,
		},
		{
			name:     "parse failed when the message is bool",
			response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`,
			hasError: true,
		},
		{
			name:     "parse failed when the message is not exists",
			response: `{"type":"invalid_request_error","param":null,"code":null}`,
			hasError: true,
		},
		// testcase for code field
		{
			name:     "parse succeeds when the code is int",
			response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorCode(t, apiErr, 418)
			},
		},
		{
			name:     "parse succeeds when the code is string",
			response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorCode(t, apiErr, "teapot")
			},
		},
		{
			name:     "parse succeeds when the code is not exists",
			response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
			hasError: false,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorCode(t, apiErr, nil)
			},
		},
		// testcase for param field
		{
			name:     "parse failed when the param is bool",
			response: `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`,
			hasError: true,
		},
		// testcase for type field
		{
			name:     "parse failed when the type is bool",
			response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`,
			hasError: true,
		},
		// testcase for error response
		{
			// An unparsable payload must leave every inspected field at its zero value.
			name:     "parse failed when the response is invalid json",
			response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
			hasError: true,
			checkFunc: func(t *testing.T, apiErr openai.APIError) {
				assertAPIErrorCode(t, apiErr, nil)
				assertAPIErrorMessage(t, apiErr, "")
				assertAPIErrorParam(t, apiErr, nil)
				assertAPIErrorType(t, apiErr, "")
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var apiErr openai.APIError
			err := apiErr.UnmarshalJSON([]byte(tc.response))
			// Report the two mismatch directions separately so that a missing
			// error is not rendered as "Unexpected error: <nil>".
			switch {
			case tc.hasError && err == nil:
				t.Error("Expected an error, got nil")
			case !tc.hasError && err != nil:
				t.Errorf("Unexpected error: %v", err)
			}
			if tc.checkFunc != nil {
				tc.checkFunc(t, apiErr)
			}
		})
	}
}

// assertAPIErrorMessage fails the test (without stopping it) when
// apiErr.Message differs from expected.
func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) {
	// Mark as helper so failures are attributed to the calling test case.
	t.Helper()
	if apiErr.Message != expected {
		t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected)
	}
}

// assertAPIErrorInnerError fails the test (without stopping it) when
// apiErr.InnerError is not deeply equal to expected.
//
// The comparison uses reflect.DeepEqual, so expected must carry the same
// dynamic type as apiErr.InnerError for the two to be considered equal.
func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) {
	// Mark as helper so failures are attributed to the calling test case.
	t.Helper()
	if !reflect.DeepEqual(apiErr.InnerError, expected) {
		t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected)
	}
}

// assertAPIErrorCode fails the test (without stopping it) when apiErr.Code
// does not match expected.
//
// Code is accepted only as an int, a string, or nil; any other dynamic type
// is reported as a failure. A nil Code matches only a nil expected value.
func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) {
	// Mark as helper so failures are attributed to the calling test case.
	t.Helper()
	switch v := apiErr.Code.(type) {
	case int:
		// Comparing a concrete value with an interface is false when the
		// dynamic types differ, so a type mismatch is reported here too.
		if v != expected {
			t.Errorf("Unexpected APIError code integer: %d; expected %v", v, expected)
		}
	case string:
		if v != expected {
			t.Errorf("Unexpected APIError code string: %s; expected %v", v, expected)
		}
	case nil:
		// A missing code is only acceptable when no code was expected.
		if expected != nil {
			t.Errorf("Unexpected APIError code: <nil>; expected %v", expected)
		}
	default:
		t.Errorf("Unexpected APIError error code type: %T", v)
	}
}

// assertAPIErrorParam fails the test (without stopping it) when apiErr.Param
// does not match expected.
//
// The two are compared by pointed-to value rather than by pointer identity:
// they match when both are nil, or when both are non-nil and reference equal
// strings. Either side may be nil without causing a panic.
func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) {
	// Mark as helper so failures are attributed to the calling test case.
	t.Helper()

	got := apiErr.Param
	if got == nil && expected == nil {
		return
	}
	if got != nil && expected != nil && *got == *expected {
		return
	}

	// show renders a possibly-nil string pointer without dereferencing nil.
	show := func(p *string) string {
		if p == nil {
			return "<nil>"
		}
		return *p
	}
	t.Errorf("Unexpected APIError param: %s; expected: %s", show(got), show(expected))
}

// assertAPIErrorType fails the test (without stopping it) when apiErr.Type
// differs from typ.
func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) {
	// Mark as helper so failures are attributed to the calling test case.
	t.Helper()
	if apiErr.Type != typ {
		t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ)
	}
}

// TestRequestError checks that a *openai.RequestError held in an error value
// can be recovered with errors.As, that it keeps its HTTP status code, and
// that Unwrap returns a non-nil error.
func TestRequestError(t *testing.T) {
	cause := errors.New("i am a teapot")
	var err error = &openai.RequestError{
		HTTPStatusCode: http.StatusTeapot,
		Err:            cause,
	}

	// Recover the concrete error type from the plain error value.
	var target *openai.RequestError
	if matched := errors.As(err, &target); !matched {
		t.Fatalf("Error is not a RequestError: %+v", err)
	}

	if status := target.HTTPStatusCode; status != 418 {
		t.Fatalf("Unexpected request error status code: %d", status)
	}

	if inner := target.Unwrap(); inner == nil {
		t.Fatalf("Empty request error occurred")
	}
}
