From 44b15190d40cfe6aeb27e8751dd646ac44db379d Mon Sep 17 00:00:00 2001 From: Patrick Steinhardt Date: Mon, 2 May 2022 09:40:49 +0200 Subject: tools: Move `module-updater` into top-level `tools/` directory Move the `module-updater` tool into the top-level `tools/` directory so that all of our custom build tools are in one place. This also makes its sources discoverable for our formatter. --- tools/module-updater/main.go | 324 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 tools/module-updater/main.go (limited to 'tools/module-updater/main.go') diff --git a/tools/module-updater/main.go b/tools/module-updater/main.go new file mode 100644 index 000000000..08c5db2d0 --- /dev/null +++ b/tools/module-updater/main.go @@ -0,0 +1,324 @@ +package main + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "go/format" + "go/parser" + "go/token" + "io/fs" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" +) + +var skipDirs = map[string]bool{ + ".git": true, + ".gitlab": true, + "_build": true, + "_support": true, + "changelogs": true, + "danger": true, + "doc": true, + "proto/go/gitalypb": true, + "ruby": true, + "scripts": true, + "tools/protoc-gen-gitaly-lint/testdata": true, + "unreleased": true, +} + +func main() { + if err := changeModuleVersion(); err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } +} + +func changeModuleVersion() error { + dirLocation, fromVersion, toVersion, err := getInput() + if err != nil { + return err + } + + moduleAbsRootPath, err := filepath.Abs(dirLocation) + if err != nil { + return fmt.Errorf("define absolute module path: %w", err) + } + + if err := verifyModulePath(moduleAbsRootPath); err != nil { + return err + } + + module, err := getModule(moduleAbsRootPath) + if err != nil { + return fmt.Errorf("define module path: %w", err) + } + + prev, next, err := normaliseDesiredImportReplacement(module, fromVersion, toVersion) + if err != nil { + return err + } + + if err := rewriteImports(moduleAbsRootPath, prev, next); err != nil { + return fmt.Errorf("re-write go imports: %s", err) + } + + if err := rewriteProto(moduleAbsRootPath, prev, next); err != nil { + return fmt.Errorf("re-write .proto files: %s", err) + } + + if err := rewriteGoMod(moduleAbsRootPath, next); err != nil { + return fmt.Errorf("re-write go.mod file: %s", err) + } + + return nil +} + +func getInput() (dirLocation string, fromVersion string, toVersion string, err error) { + flag.StringVar(&dirLocation, "dir", ".", "directory of the module with the go.mod file") + flag.StringVar(&fromVersion, "from", "", "module version to upgrade from") + flag.StringVar(&toVersion, "to", "", "module version to upgrade to") + flag.Parse() + return +} + +func isModuleVersion(input string) error { + if input == "" { + return errors.New("empty module version") + } + + if !strings.HasPrefix(input, "v") { + return fmt.Errorf("module version should start with 'v': %s", input) + } + + rawVersion, err := strconv.ParseInt(input[1:], 10, 64) + if err != nil { + return err + } + if rawVersion < 1 { + return fmt.Errorf("version number should be positive: %d", rawVersion) + } + + return nil +} + +func getModule(modDir string) (string, error) { + cmd := exec.Command("go", "mod", "edit", "-json") + cmd.Dir = modDir + data, err := cmd.Output() + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return "", fmt.Errorf("command %q: %v", strings.Join(cmd.Args, " "), exitErr.Stderr) + } + return "", fmt.Errorf("command %q: %w", strings.Join(cmd.Args, " "), err) + } + + modInfo := struct{ Module struct{ Path string } }{} + if err := json.Unmarshal(data, &modInfo); err != nil { + return "", err + } + + return modInfo.Module.Path, nil +} + +// rewriteImports rewrites go source files by replacing old import path for the module with the new one. +func rewriteImports(moduleAbsRootPath, prev, next string) error { + fileSet := token.NewFileSet() + if err := filepath.Walk(moduleAbsRootPath, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + relPath := strings.TrimPrefix(path, moduleAbsRootPath) + relPath = strings.TrimPrefix(relPath, "/") + if skipDirs[relPath] { + return fs.SkipDir + } + return nil + } + + if filepath.Ext(info.Name()) == ".go" { + fileSet.AddFile(path, fileSet.Base(), 0) + } + + return nil + }); err != nil { + return fmt.Errorf("scan directory for go source files: %w", err) + } + + seen := map[string]struct{}{} + var rerr error + fileSet.Iterate(func(file *token.File) bool { + if _, found := seen[file.Name()]; found { + return false + } + seen[file.Name()] = struct{}{} + + astFile, err := parser.ParseFile(fileSet, file.Name(), nil, parser.ParseComments) + if err != nil { + rerr = err + return false + } + + for _, imprt := range astFile.Imports { + oldImport, err := strconv.Unquote(imprt.Path.Value) + if err != nil { + rerr = fmt.Errorf("unquote import: %s :%w", imprt.Path.Value, err) + return false + } + + if newImport := strings.Replace(oldImport, prev, next, 1); newImport != oldImport { + imprt.EndPos = imprt.End() + imprt.Path.Value = strconv.Quote(newImport) + } + } + + f, err := os.Create(file.Name()) + if err != nil { + rerr = fmt.Errorf("open file %q: %w", file.Name(), err) + return false + } + + ferr := format.Node(f, fileSet, astFile) + cerr := f.Close() + + if ferr != nil { + rerr = fmt.Errorf("rewrite file %q: %w", file.Name(), err) + return false + } + + if cerr != nil { + rerr = fmt.Errorf("close file %q: %w", file.Name(), err) + return false + } + + return true + }) + + return rerr +} + +func normaliseDesiredImportReplacement(module, from, to string) (string, string, error) { + if err := isModuleVersion(from); err != nil { + return "", "", fmt.Errorf("invalid 'from' version: %w", err) + } + + if err := isModuleVersion(to); err != nil { + return "", "", fmt.Errorf("invalid 'to' version: %w", err) + } + + prev := module + next := filepath.Join(filepath.Dir(module), to) + current := filepath.Base(module) + if err := isModuleVersion(current); err == nil { + if current != from { + return "", "", fmt.Errorf("existing module version is %q, but 'from' specified as %q", current, from) + } + } else { + next = filepath.Join(module, to) + } + return prev, next, nil +} + +func verifyModulePath(moduleRootPath string) error { + st, err := os.Stat(moduleRootPath) + if err != nil { + return fmt.Errorf("inspect module root path: %w", err) + } + + if !st.IsDir() { + return fmt.Errorf("provided module root path is not a directory: %s", moduleRootPath) + } + + entries, err := os.ReadDir(moduleRootPath) + if err != nil { + return fmt.Errorf("inspect module root path: %w", err) + } + + var modExists bool + for _, entry := range entries { + if entry.Name() == "go.mod" { + modExists = true + break + } + } + + if !modExists { + return fmt.Errorf("provided module root path doesn't contain go.mod file: %s", moduleRootPath) + } + + return nil +} + +// rewriteProto re-write proto files by changing the go_package option declaration: +// 1. option go_package = "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"; +// 2. option go_package = "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"; +// 4. option go_package = "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"; +func rewriteProto(moduleAbsRootPath, prev, next string) error { + protoDirPath := filepath.Join(moduleAbsRootPath, "proto") + if err := filepath.Walk(protoDirPath, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + if info.Name() == "go" { + return fs.SkipDir + } + return nil + } + + if filepath.Ext(info.Name()) != ".proto" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return err + } + + var modified [][]byte + lines := bytes.Split(data, []byte{'\n'}) + for _, line := range lines { + tokens := bytes.Fields(line) + pckg := bytes.Join(tokens, []byte{'~'}) + if !bytes.HasPrefix(pckg, []byte(`option~go_package~=~"`+prev+`/proto/go/gitalypb";`)) { + modified = append(modified, line) + continue + } + + modified = append(modified, bytes.ReplaceAll(line, []byte(prev), []byte(next))) + } + if len(modified) == 0 { + return nil + } + modifiedData := bytes.Join(modified, []byte{'\n'}) + + if err := os.WriteFile(path, modifiedData, info.Mode()); err != nil { + return fmt.Errorf("write modified file content: %s, %w", info.Name(), err) + } + + return nil + }); err != nil { + return err + } + return nil +} + +// rewriteGoMod modifies name of the module in the go.mod file. +func rewriteGoMod(moduleAbsRootPath, next string) error { + cmd := exec.Command("go", "mod", "edit", "-module", next) + cmd.Dir = moduleAbsRootPath + if err := cmd.Run(); err != nil { + return fmt.Errorf("command %q: %w", strings.Join(cmd.Args, " "), err) + } + + return nil +} -- cgit v1.2.3