diff --git a/conf.go b/conf.go index 301533c..dd1be0a 100644 --- a/conf.go +++ b/conf.go @@ -4,12 +4,14 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" ) -const CONFIG_LOG string = "/var/lib/gocustomurls/" +type Config struct { + MappingFilePath string + MappingRules ImportRulesMappings +} -type ConfigFile struct { +type ImportRulesMappings struct { Mappings []struct { Protocol string `json:"protocol"` VanityUrl string `json:"vanity_url"` @@ -17,6 +19,12 @@ type ConfigFile struct { } `json:"mappings"` } +type ImportRuleStruct struct { + VanityUrl string + Proto string + RepoUrl string +} + // isFile - check if fp is a valid file func isFile(fp string) bool { info, err := os.Stat(fp) @@ -27,26 +35,28 @@ func isFile(fp string) bool { } // load mapping file -func LoadFile() (ConfigFile, error) { - var mapping ConfigFile - dirname, err := os.UserConfigDir() +func (c *Config) LoadMappingFile(fp string) error { + var mapping ImportRulesMappings + mappingFilePath := fp + if len(c.MappingFilePath) == 0 { + ok := isFile(mappingFilePath) + if !ok { + return fmt.Errorf("%s is not found", mappingFilePath) + } + } + mappingFile, err := os.Open(mappingFilePath) if err != nil { - return mapping, err + return err } - configFilePath := filepath.Join(dirname, "gocustomurls/config.json") - ok := isFile(configFilePath) - if !ok { - return mapping, fmt.Errorf("%s/gocustomurls/config.json file is not found", dirname) - } - configFile, err := os.Open(configFilePath) - if err != nil { - return mapping, err - } - defer configFile.Close() + defer mappingFile.Close() - err = json.NewDecoder(configFile).Decode(&mapping) + err = json.NewDecoder(mappingFile).Decode(&mapping) if err != nil { - return mapping, err + return err } - return mapping, nil + c.MappingRules = mapping + if len(c.MappingFilePath) == 0 { + c.MappingFilePath = mappingFilePath + } + return nil } diff --git a/main.go b/main.go index 7905807..288ca4a 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,115 @@ package main -func main() { +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) +var errorLog *log.Logger = log.New(os.Stderr, "", log.LstdFlags) + +const ( + DEFAULT_RULES_FILE string = "/var/lib/gocustomurls/rules.json" + DEFAULT_LOG_FILE string = "/var/log/gocustomurls/app.log" +) + +// flagsSet returns a set of all the flags what were actually set on the +// command line. +func flagsSet(flags *flag.FlagSet) map[string]bool { + s := make(map[string]bool) + flags.Visit(func(f *flag.Flag) { + s[f.Name] = true + }) + return s +} + +func main() { + programName := os.Args[0] + // errorLog = log.New(os.Stderr, "", log.LstdFlags) + + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.Usage = func() { + out := flags.Output() + fmt.Fprintf(out, "Usage: %v [flags]\n\n", programName) + fmt.Fprint(out, " This utility serves vanity urls for the go get/install command.\n") + fmt.Fprint(out, " By default, the server listens on localhost:7070.\n") + flags.PrintDefaults() + } + + portFlag := flags.String("port", "7070", "port to listen to") + rulesFileFlag := flags.String("config", DEFAULT_RULES_FILE, "contains go-import mapping") + logFileFlag := flags.String("logfile", DEFAULT_LOG_FILE, "default log file") + flags.Parse(os.Args[1:]) + + if len(flags.Args()) > 1 { + errorLog.Println("Error: too many command-line arguments") + flags.Usage() + os.Exit(1) + } + + allSetFlags := flagsSet(flags) + + var port string + if allSetFlags["port"] { + port = *portFlag + } + + var rulesFile string + if allSetFlags["config"] { + rulesFile = *rulesFileFlag + } + + var logFile string + if allSetFlags["logFile"] { + logFile = *logFileFlag + } + + // load rules mapping + c := &Config{} + err := c.LoadMappingFile(rulesFile) + if err != nil { + errorLog.Println(err) + os.Exit(1) + } + l, err := newFileLogger(logFile) + if err != nil { + errorLog.Println(err) + os.Exit(1) + } + + app := &Application{ + Config: c, + Log: l, + } + srv := app.Setup(port) + + // For graceful shutdowns + go func() { + err := srv.ListenAndServe() + if !errors.Is(err, http.ErrServerClosed) { + errorLog.Printf("HTTP Server error: %+v\n", err) + os.Exit(1) + } + app.Log.logger.Println("Stopped serving new connections.") + }() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownRelease() + + if err := srv.Shutdown(shutdownCtx); err != nil { + errorLog.Printf("HTTP shutdown error: %+v\n", err) + os.Exit(1) + } + app.Log.logger.Println("Graceful shutdown complete.") }