diff --git a/cmd/app/list_discussions.go b/cmd/app/list_discussions.go index 74a9a33..dfa2817 100644 --- a/cmd/app/list_discussions.go +++ b/cmd/app/list_discussions.go @@ -85,10 +85,7 @@ func (a discussionsListerService) ServeHTTP(w http.ResponseWriter, r *http.Reque request := r.Context().Value(payload(payload("payload"))).(*DiscussionsRequest) mergeRequestDiscussionOptions := gitlab.ListMergeRequestDiscussionsOptions{ - ListOptions: gitlab.ListOptions{ - Page: 1, - PerPage: 250, - }, + ListOptions: gitlab.ListOptions{}, } it, hasErr := gitlab.Scan(func(p gitlab.PaginationOptionFunc) ([]*gitlab.Discussion, *gitlab.Response, error) { diff --git a/cmd/app/merge_requests_by_branch.go b/cmd/app/merge_requests_by_branch.go new file mode 100644 index 0000000..2d47bed --- /dev/null +++ b/cmd/app/merge_requests_by_branch.go @@ -0,0 +1,120 @@ +package app + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + + gitlab "gitlab.com/gitlab-org/api/client-go" +) + +type MergeRequestListerByBranch interface { + ListProjectMergeRequests(pid interface{}, opt *gitlab.ListProjectMergeRequestsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.BasicMergeRequest, *gitlab.Response, error) +} + +type mergeRequestListerByBranchService struct { + data + client MergeRequestListerByBranch +} + +type MergeRequestByBranchRequest struct { + Branch string `json:"branch" validate:"required"` + State string `json:"state,omitempty"` +} + +// Returns a list of merge requests where the given branch is the source branch +func (a mergeRequestListerByBranchService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + request := r.Context().Value(payload("payload")).(*MergeRequestByBranchRequest) + + if request.State == "" { + request.State = "opened" + } + + payloads := []gitlab.ListProjectMergeRequestsOptions{ + { + SourceBranch: gitlab.Ptr(request.Branch), + State: gitlab.Ptr(request.State), + Scope: gitlab.Ptr("all"), + }, + } + + type apiResponse struct { + mrs []*gitlab.BasicMergeRequest + err error + } + + mrChan := make(chan apiResponse, len(payloads)) + wg := sync.WaitGroup{} + go func() { + wg.Wait() + close(mrChan) + }() + + for _, payload := range payloads { + wg.Add(1) + go func(p gitlab.ListProjectMergeRequestsOptions) { + defer wg.Done() + mrs, err := a.getMrs(&p) + mrChan <- apiResponse{mrs, err} + }(payload) + } + + var mergeRequests []*gitlab.BasicMergeRequest + existingIds := make(map[int64]bool) + var errs []error + for res := range mrChan { + if res.err != nil { + errs = append(errs, res.err) + } else { + for _, mr := range res.mrs { + if !existingIds[mr.ID] { + mergeRequests = append(mergeRequests, mr) + existingIds[mr.ID] = true + } + } + } + } + + if len(errs) > 0 { + combinedErr := "" + for _, err := range errs { + combinedErr += err.Error() + "; " + } + handleError(w, errors.New(combinedErr), "An error occurred", http.StatusInternalServerError) + return + } + + if len(mergeRequests) == 0 { + handleError(w, fmt.Errorf("%s did not have any MRs", request.Branch), "No MRs found", http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusOK) + response := ListMergeRequestResponse{ + SuccessResponse: SuccessResponse{Message: fmt.Sprintf("Merge requests fetched for %s", request.Branch)}, + MergeRequests: mergeRequests, + } + + err := json.NewEncoder(w).Encode(response) + if err != nil { + handleError(w, err, "Could not encode response", http.StatusInternalServerError) + } +} + +func (a mergeRequestListerByBranchService) getMrs(payload *gitlab.ListProjectMergeRequestsOptions) ([]*gitlab.BasicMergeRequest, error) { + mrs, res, err := a.client.ListProjectMergeRequests(a.projectInfo.ProjectId, payload) + if err != nil { + return []*gitlab.BasicMergeRequest{}, err + } + + if res.StatusCode >= 300 { + return []*gitlab.BasicMergeRequest{}, GenericError{endpoint: "/merge_requests_by_branch"} + } + + defer res.Body.Close() + + return mrs, err +} diff --git a/cmd/app/server.go b/cmd/app/server.go index f143904..094894a 100644 --- a/cmd/app/server.go +++ b/cmd/app/server.go @@ -233,6 +233,11 @@ func CreateRouter(gitlabClient *Client, projectInfo *ProjectInfo, s *shutdownSer withPayloadValidation(methodToPayload{http.MethodPost: newPayload[gitlab.ListProjectMergeRequestsOptions]}), // TODO: How to validate external object withMethodCheck(http.MethodPost), )) + m.HandleFunc("/merge_requests_by_branch", middleware( + mergeRequestListerByBranchService{d, gitlabClient}, + withPayloadValidation(methodToPayload{http.MethodPost: newPayload[MergeRequestByBranchRequest]}), + withMethodCheck(http.MethodPost), + )) m.HandleFunc("/merge_requests_by_username", middleware( mergeRequestListerByUsernameService{d, gitlabClient}, withPayloadValidation(methodToPayload{http.MethodPost: newPayload[MergeRequestByUsernameRequest]}), diff --git a/lua/gitlab/init.lua b/lua/gitlab/init.lua index da066e6..85b0beb 100644 --- a/lua/gitlab/init.lua +++ b/lua/gitlab/init.lua @@ -29,6 +29,7 @@ local latest_pipeline = state.dependencies.latest_pipeline local revisions = state.dependencies.revisions local merge_requests_dep = state.dependencies.merge_requests local merge_requests_by_username_dep = state.dependencies.merge_requests_by_username +local merge_requests_by_branch_dep = state.dependencies.merge_requests_by_branch local draft_notes_dep = state.dependencies.draft_notes local discussion_data = state.dependencies.discussion_data @@ -117,6 +118,10 @@ return { { project_members, merge_requests_by_username_dep }, merge_requests.choose_merge_request ), + choose_merge_request_by_branch = async.sequence( + { merge_requests_by_branch_dep }, + merge_requests.choose_merge_request + ), open_in_browser = async.sequence({ info }, function() local web_url = u.get_web_url() if web_url ~= nil then diff --git a/lua/gitlab/state.lua b/lua/gitlab/state.lua index 5db54a1..591a00b 100644 --- a/lua/gitlab/state.lua +++ b/lua/gitlab/state.lua @@ -325,7 +325,7 @@ M.set_global_keymaps = function() if keymaps.global.choose_merge_request then vim.keymap.set("n", keymaps.global.choose_merge_request, function() - require("gitlab").choose_merge_request() + require("gitlab").choose_merge_request_by_branch() end, { desc = "Choose MR for review", nowait = keymaps.global.choose_merge_request_nowait }) end @@ -568,6 +568,24 @@ M.dependencies = { return opts end, }, + merge_requests_by_branch = { + endpoint = "/merge_requests_by_branch", + key = "merge_requests", + state = "MERGE_REQUESTS", + refresh = true, + method = "POST", + body = function(opts) + if not opts then + opts = {} + end + local branch = require("gitlab.git").get_current_branch() + if branch == nil then + error("Invalid payload, branch could not be found!") + end + opts.branch = branch + return opts + end, + }, discussion_data = { endpoint = "/mr/discussions/list", state = "DISCUSSION_DATA",