// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package hicli

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/rs/zerolog"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/pushrules"

	"go.mau.fi/gomuks/pkg/hicli/database"
	"go.mau.fi/gomuks/pkg/hicli/jsoncmd"
)

// handleJSONCommand dispatches a single JSON command to the matching HiClient,
// database, or Matrix client operation based on req.Command.
//
// For commands that take parameters, req.Data is decoded into the
// command-specific parameter struct with unmarshalAndCall before the operation
// runs. Commands that only perform an action typically return true on success.
// An unrecognized command name returns an error.
func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) {
	switch req.Command {
	case jsoncmd.ReqGetState:
		return h.State(), nil
	case jsoncmd.ReqCancel:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.CancelRequestParams) (bool, error) {
			// Only hold the lock for the map lookup; the cancel callback itself
			// runs outside the lock.
			h.jsonRequestsLock.Lock()
			cancelTarget, ok := h.jsonRequests[params.RequestID]
			h.jsonRequestsLock.Unlock()
			if !ok {
				return false, nil
			}
			if params.Reason == "" {
				cancelTarget(nil)
			} else {
				cancelTarget(errors.New(params.Reason))
			}
			return true, nil
		})
	case jsoncmd.ReqSendMessage:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SendMessageParams) (*database.Event, error) {
			return h.SendMessage(ctx, params.RoomID, params.BaseContent, params.Extra, params.Text, params.RelatesTo, params.Mentions, params.URLPreviews)
		})
	case jsoncmd.ReqSendEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SendEventParams) (*database.Event, error) {
			return h.Send(ctx, params.RoomID, params.EventType, params.Content, params.DisableEncryption, params.Synchronous)
		})
	case jsoncmd.ReqResendEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.ResendEventParams) (*database.Event, error) {
			return h.Resend(ctx, params.TransactionID)
		})
	case jsoncmd.ReqReportEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.ReportEventParams) (bool, error) {
			return true, h.Client.ReportEvent(ctx, params.RoomID, params.EventID, params.Reason)
		})
	case jsoncmd.ReqRedactEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.RedactEventParams) (*mautrix.RespSendEvent, error) {
			return h.Client.RedactEvent(ctx, params.RoomID, params.EventID, mautrix.ReqRedact{
				Reason: params.Reason,
			})
		})
	case jsoncmd.ReqSetState:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SendStateEventParams) (id.EventID, error) {
			return h.SetState(ctx, params.RoomID, params.EventType, params.StateKey, params.Content, mautrix.ReqSendEvent{
				UnstableDelay: time.Duration(params.DelayMS) * time.Millisecond,
			})
		})
	case jsoncmd.ReqUpdateDelayedEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.UpdateDelayedEventParams) (*mautrix.RespUpdateDelayedEvent, error) {
			return h.Client.UpdateDelayedEvent(ctx, &mautrix.ReqUpdateDelayedEvent{
				DelayID: params.DelayID,
				Action:  params.Action,
			})
		})
	case jsoncmd.ReqSetMembership:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SetMembershipParams) (any, error) {
			switch params.Action {
			case "invite":
				return h.Client.InviteUser(ctx, params.RoomID, &mautrix.ReqInviteUser{UserID: params.UserID, Reason: params.Reason})
			case "kick":
				return h.Client.KickUser(ctx, params.RoomID, &mautrix.ReqKickUser{UserID: params.UserID, Reason: params.Reason})
			case "ban":
				return h.Client.BanUser(ctx, params.RoomID, &mautrix.ReqBanUser{UserID: params.UserID, Reason: params.Reason, MSC4293RedactEvents: params.MSC4293RedactEvents})
			case "unban":
				return h.Client.UnbanUser(ctx, params.RoomID, &mautrix.ReqUnbanUser{UserID: params.UserID, Reason: params.Reason})
			default:
				return nil, fmt.Errorf("unknown action %q", params.Action)
			}
		})
	case jsoncmd.ReqSetAccountData:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SetAccountDataParams) (bool, error) {
			// A room ID selects room-scoped account data; otherwise it's global.
			if params.RoomID != "" {
				return true, h.Client.SetRoomAccountData(ctx, params.RoomID, params.Type, params.Content)
			}
			return true, h.Client.SetAccountData(ctx, params.Type, params.Content)
		})
	case jsoncmd.ReqMarkRead:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.MarkReadParams) (bool, error) {
			return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType)
		})
	case jsoncmd.ReqSetTyping:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SetTypingParams) (bool, error) {
			return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond)
		})
	case jsoncmd.ReqGetProfile:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetProfileParams) (*mautrix.RespUserProfile, error) {
			return h.Client.GetProfile(mautrix.WithMaxRetries(ctx, 0), params.UserID)
		})
	case jsoncmd.ReqSetProfileField:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SetProfileFieldParams) (bool, error) {
			return true, h.Client.SetProfileField(ctx, params.Field, params.Value)
		})
	case jsoncmd.ReqGetMutualRooms:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetProfileParams) ([]id.RoomID, error) {
			return h.GetMutualRooms(mautrix.WithMaxRetries(ctx, 0), params.UserID)
		})
	case jsoncmd.ReqTrackUserDevices:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetProfileParams) (*jsoncmd.ProfileEncryptionInfo, error) {
			err := h.TrackUserDevices(ctx, params.UserID)
			if err != nil {
				return nil, err
			}
			return h.GetProfileEncryptionInfo(ctx, params.UserID)
		})
	case jsoncmd.ReqGetProfileEncryptionInfo:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetProfileParams) (*jsoncmd.ProfileEncryptionInfo, error) {
			return h.GetProfileEncryptionInfo(ctx, params.UserID)
		})
	case jsoncmd.ReqGetEvent:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetEventParams) (*database.Event, error) {
			if params.Unredact {
				return h.GetUnredactedEvent(mautrix.WithMaxRetries(ctx, 2), params.RoomID, params.EventID)
			}
			return h.GetEvent(mautrix.WithMaxRetries(ctx, 2), params.RoomID, params.EventID)
		})
	case jsoncmd.ReqGetRelatedEvents:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetRelatedEventsParams) ([]*database.Event, error) {
			return nonNilArray(h.DB.Event.GetRelatedEvents(ctx, params.RoomID, params.EventID, params.RelationType))
		})
	case jsoncmd.ReqGetEventContext:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetEventContextParams) (*jsoncmd.EventContextResponse, error) {
			return h.GetEventContext(mautrix.WithMaxRetries(ctx, 0), params.RoomID, params.EventID, params.Limit)
		})
	case jsoncmd.ReqPaginateManual:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.PaginateManualParams) (*jsoncmd.ManualPaginationResponse, error) {
			return h.PaginateManual(mautrix.WithMaxRetries(ctx, 0), params.RoomID, params.ThreadRoot, params.Since, params.Direction, params.Limit)
		})
	case jsoncmd.ReqGetRoomState:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetRoomStateParams) ([]*database.Event, error) {
			return h.GetRoomState(ctx, params.RoomID, params.IncludeMembers, params.FetchMembers, params.Refetch)
		})
	case jsoncmd.ReqGetSpecificRoomState:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetSpecificRoomStateParams) ([]*database.Event, error) {
			return nonNilArray(h.DB.CurrentState.GetMany(ctx, params.Keys))
		})
	case jsoncmd.ReqGetReceipts:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetReceiptsParams) (map[id.EventID][]*database.Receipt, error) {
			return h.GetReceipts(ctx, params.RoomID, params.EventIDs)
		})
	case jsoncmd.ReqPaginate:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.PaginateParams) (*jsoncmd.PaginationResponse, error) {
			return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit, params.Reset)
		})
	case jsoncmd.ReqGetRoomSummary:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.JoinRoomParams) (*mautrix.RespRoomSummary, error) {
			return h.Client.GetRoomSummary(mautrix.WithMaxRetries(ctx, 2), params.RoomIDOrAlias, params.Via...)
		})
	case jsoncmd.ReqGetSpaceHierarchy:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetHierarchyParams) (*mautrix.RespHierarchy, error) {
			return h.Client.Hierarchy(mautrix.WithMaxRetries(ctx, 0), params.RoomID, &mautrix.ReqHierarchy{
				From:          params.From,
				Limit:         params.Limit,
				MaxDepth:      params.MaxDepth,
				SuggestedOnly: params.SuggestedOnly,
			})
		})
	case jsoncmd.ReqJoinRoom:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.JoinRoomParams) (*mautrix.RespJoinRoom, error) {
			return h.Client.JoinRoom(mautrix.WithMaxRetries(ctx, 2), params.RoomIDOrAlias, &mautrix.ReqJoinRoom{
				Via:    params.Via,
				Reason: params.Reason,
			})
		})
	case jsoncmd.ReqKnockRoom:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.JoinRoomParams) (*mautrix.RespKnockRoom, error) {
			return h.Client.KnockRoom(mautrix.WithMaxRetries(ctx, 2), params.RoomIDOrAlias, &mautrix.ReqKnockRoom{
				Via:    params.Via,
				Reason: params.Reason,
			})
		})
	case jsoncmd.ReqLeaveRoom:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.LeaveRoomParams) (*mautrix.RespLeaveRoom, error) {
			resp, err := h.Client.LeaveRoom(mautrix.WithMaxRetries(ctx, 2), params.RoomID, &mautrix.ReqLeave{Reason: params.Reason})
			// On success, or when the server indicates the room can't be left
			// (not found / forbidden), drop any locally stored invite so it doesn't
			// linger. The original leave error is still returned to the caller.
			if err == nil ||
				errors.Is(err, mautrix.MNotFound) ||
				errors.Is(err, mautrix.MForbidden) ||
				// Synapse-specific hack: the server incorrectly returns M_UNKNOWN in some cases
				// instead of a sensible code like M_NOT_FOUND.
				strings.Contains(err.Error(), "Not a known room") {
				deleteInviteErr := h.DB.InvitedRoom.Delete(ctx, params.RoomID)
				if deleteInviteErr != nil {
					zerolog.Ctx(ctx).Err(deleteInviteErr).
						Stringer("room_id", params.RoomID).
						Msg("Failed to delete invite from database after leaving room")
				} else {
					zerolog.Ctx(ctx).Debug().
						Stringer("room_id", params.RoomID).
						Msg("Deleted invite from database after leaving room")
				}
			}
			return resp, err
		})
	case jsoncmd.ReqCreateRoom:
		return unmarshalAndCall(req.Data, func(params *mautrix.ReqCreateRoom) (*mautrix.RespCreateRoom, error) {
			return h.Client.CreateRoom(mautrix.WithMaxRetries(ctx, 0), params)
		})
	case jsoncmd.ReqMuteRoom:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.MuteRoomParams) (bool, error) {
			// The boolean result mirrors the requested muted state: true after
			// adding the room push rule with no actions, false after deleting it.
			if params.Muted {
				return true, h.Client.PutPushRule(ctx, "global", pushrules.RoomRule, string(params.RoomID), &mautrix.ReqPutPushRule{
					Actions: []pushrules.PushActionType{},
				})
			}
			return false, h.Client.DeletePushRule(ctx, "global", pushrules.RoomRule, string(params.RoomID))
		})
	case jsoncmd.ReqEnsureGroupSessionShared:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.EnsureGroupSessionSharedParams) (bool, error) {
			return true, h.EnsureGroupSessionShared(ctx, params.RoomID)
		})
	case jsoncmd.ReqSendToDevice:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.SendToDeviceParams) (*mautrix.RespSendToDevice, error) {
			// The event type class isn't part of the JSON payload, so force it here.
			params.EventType.Class = event.ToDeviceEventType
			return h.SendToDevice(ctx, params.EventType, params.ReqSendToDevice, params.Encrypted)
		})
	case jsoncmd.ReqResolveAlias:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.ResolveAliasParams) (*mautrix.RespAliasResolve, error) {
			return h.Client.ResolveAlias(mautrix.WithMaxRetries(ctx, 0), params.Alias)
		})
	case jsoncmd.ReqRequestOpenIDToken:
		return h.Client.RequestOpenIDToken(ctx)
	case jsoncmd.ReqLogout:
		if h.LogoutFunc == nil {
			return nil, errors.New("logout not supported")
		}
		return true, h.LogoutFunc(ctx)
	case jsoncmd.ReqLogin:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.LoginParams) (bool, error) {
			err := h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password)
			if err != nil {
				h.Log.Err(err).Msg("Failed to login")
			}
			return true, err
		})
	case jsoncmd.ReqLoginCustom:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.LoginCustomParams) (bool, error) {
			// Parse into a local first: url.Parse returns a nil URL on error, and
			// assigning it directly would wipe the client's existing homeserver URL
			// even though the request is rejected.
			homeserverURL, err := url.Parse(params.HomeserverURL)
			if err != nil {
				return false, err
			}
			h.Client.HomeserverURL = homeserverURL
			err = h.Login(ctx, params.Request)
			if err != nil {
				h.Log.Err(err).Msg("Failed to login")
			}
			return true, err
		})
	case jsoncmd.ReqVerify:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.VerifyParams) (bool, error) {
			return true, h.Verify(ctx, params.RecoveryKey)
		})
	case jsoncmd.ReqDiscoverHomeserver:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.DiscoverHomeserverParams) (*mautrix.ClientWellKnown, error) {
			_, homeserver, err := params.UserID.Parse()
			if err != nil {
				return nil, err
			}
			return mautrix.DiscoverClientAPI(ctx, homeserver)
		})
	case jsoncmd.ReqGetLoginFlows:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.GetLoginFlowsParams) (*mautrix.RespLoginFlows, error) {
			cli, err := h.tempClient(params.HomeserverURL)
			if err != nil {
				return nil, err
			}
			err = h.checkServerVersions(ctx, cli)
			if err != nil {
				return nil, err
			}
			return cli.GetLoginFlows(ctx)
		})
	case jsoncmd.ReqRegisterPush:
		return unmarshalAndCall(req.Data, func(params *database.PushRegistration) (bool, error) {
			return true, h.DB.PushRegistration.Put(ctx, params)
		})
	case jsoncmd.ReqListenToDevice:
		return unmarshalAndCall(req.Data, func(listen *bool) (bool, error) {
			// Returns the previous value of the flag.
			return h.ToDeviceInSync.Swap(*listen), nil
		})
	case jsoncmd.ReqGetTurnServers:
		return h.Client.TurnServer(ctx)
	case jsoncmd.ReqGetMediaConfig:
		return h.Client.GetMediaConfig(ctx)
	case jsoncmd.ReqCalculateRoomID:
		return unmarshalAndCall(req.Data, func(params *jsoncmd.CalculateRoomIDParams) (id.RoomID, error) {
			return h.CalculateRoomID(params.Timestamp, params.CreationContent)
		})
	default:
		return nil, fmt.Errorf("unknown command %q", req.Command)
	}
}

// nonNilArray passes a (slice, error) pair through unchanged, except that a nil
// slice paired with a nil error is replaced by an empty, non-nil slice. This
// makes successful empty results encode as [] instead of null in JSON.
func nonNilArray[T any](arr []T, err error) ([]T, error) {
	if err != nil || arr != nil {
		return arr, err
	}
	return make([]T, 0), nil
}

// unmarshalAndCall decodes data as JSON into a freshly allocated T and then
// invokes fn with a pointer to it, returning fn's results. If decoding fails,
// fn is not called and the zero value of O is returned with the decode error.
func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) {
	params := new(T)
	if err = json.Unmarshal(data, params); err != nil {
		return output, err
	}
	return fn(params)
}
