diff options
| author | Bobby <[email protected]> | 2025-07-07 11:58:04 +0530 |
|---|---|---|
| committer | Bobby <[email protected]> | 2025-07-07 11:58:04 +0530 |
| commit | 9e4d6b1e271032d14e55f16395343979276e8de5 (patch) | |
| tree | adb0915e2854113c401960e3c1f8cf97e5e7aca8 | |
| parent | 48c6de533c459a1bb923f292e43914689b1357df (diff) | |
| download | imageboard-9e4d6b1e271032d14e55f16395343979276e8de5.tar.xz imageboard-9e4d6b1e271032d14e55f16395343979276e8de5.zip | |
refactored config system with generic `Defaults` function, fixed database migration setup, and applied dry principles
| -rw-r--r-- | config/functions.go | 146 | ||||
| -rw-r--r-- | config/types.go | 13 | ||||
| -rw-r--r-- | database/database.go | 39 | ||||
| -rw-r--r-- | imageboard/main.go | 4 |
4 files changed, 152 insertions, 50 deletions
diff --git a/config/functions.go b/config/functions.go index 0ec2d8e..28ef3f7 100644 --- a/config/functions.go +++ b/config/functions.go @@ -51,17 +51,90 @@ func getEnvFloat64(key string, defaultVal float64) float64 { return defaultVal
}
-func Parse(config interface{}) error {
+func setFieldFromEnv(field reflect.Value, envKey, defaultVal string) {
+ switch field.Kind() {
+ case reflect.String:
+ field.SetString(getEnv(envKey, defaultVal))
+ case reflect.Bool:
+ defaultBool, _ := strconv.ParseBool(defaultVal)
+ field.SetBool(getEnvBool(envKey, defaultBool))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ defaultInt, _ := strconv.ParseInt(defaultVal, 10, 64)
+ field.SetInt(getEnvInt64(envKey, defaultInt))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ defaultUint, _ := strconv.ParseUint(defaultVal, 10, 64)
+ setUintField(field, envKey, defaultUint)
+ case reflect.Float32, reflect.Float64:
+ defaultFloat, _ := strconv.ParseFloat(defaultVal, 64)
+ field.SetFloat(getEnvFloat64(envKey, defaultFloat))
+ default:
+ setDurationField(field, envKey, defaultVal)
+ }
+}
+
+func setUintField(field reflect.Value, envKey string, defaultVal uint64) {
+ if value := os.Getenv(envKey); value != "" {
+ if parsed, err := strconv.ParseUint(value, 10, 64); err == nil {
+ field.SetUint(parsed)
+ return
+ }
+ }
+ field.SetUint(defaultVal)
+}
+
+func setDurationField(field reflect.Value, envKey, defaultVal string) {
+ if field.Type() == reflect.TypeOf(time.Duration(0)) {
+ defaultDuration, _ := time.ParseDuration(defaultVal)
+ field.Set(reflect.ValueOf(getEnvDuration(envKey, defaultDuration)))
+ }
+}
+
+func setFieldDefault(field reflect.Value, defaultVal string) {
+ switch field.Kind() {
+ case reflect.String:
+ field.SetString(defaultVal)
+ case reflect.Bool:
+ if defaultBool, err := strconv.ParseBool(defaultVal); err == nil {
+ field.SetBool(defaultBool)
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if defaultInt, err := strconv.ParseInt(defaultVal, 10, 64); err == nil {
+ field.SetInt(defaultInt)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if defaultUint, err := strconv.ParseUint(defaultVal, 10, 64); err == nil {
+ field.SetUint(defaultUint)
+ }
+ case reflect.Float32, reflect.Float64:
+ if defaultFloat, err := strconv.ParseFloat(defaultVal, 64); err == nil {
+ field.SetFloat(defaultFloat)
+ }
+ default:
+ if field.Type() == reflect.TypeOf(time.Duration(0)) {
+ if defaultDuration, err := time.ParseDuration(defaultVal); err == nil {
+ field.Set(reflect.ValueOf(defaultDuration))
+ }
+ }
+ }
+}
+
+func validateConfigInput(config any) (reflect.Value, reflect.Type, error) {
v := reflect.ValueOf(config)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
- return fmt.Errorf("config must be a pointer to struct")
+ return reflect.Value{}, nil, fmt.Errorf("config must be a pointer to struct")
}
+ elem := v.Elem()
+ return elem, elem.Type(), nil
+}
- v = v.Elem()
- t := v.Type()
+func Parse(config any) error {
+ elem, t, err := validateConfigInput(config)
+ if err != nil {
+ return err
+ }
- for i := range v.NumField() {
- field := v.Field(i)
+ for i := range elem.NumField() {
+ field := elem.Field(i)
fieldType := t.Field(i)
if !field.CanSet() {
@@ -75,41 +148,38 @@ func Parse(config interface{}) error { continue
}
- switch field.Kind() {
- case reflect.String:
- field.SetString(getEnv(envKey, defaultVal))
-
- case reflect.Bool:
- defaultBool, _ := strconv.ParseBool(defaultVal)
- field.SetBool(getEnvBool(envKey, defaultBool))
-
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- defaultInt, _ := strconv.ParseInt(defaultVal, 10, 64)
- field.SetInt(getEnvInt64(envKey, defaultInt))
-
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- defaultUint, _ := strconv.ParseUint(defaultVal, 10, 64)
- if value := os.Getenv(envKey); value != "" {
- if parsed, err := strconv.ParseUint(value, 10, 64); err == nil {
- field.SetUint(parsed)
- continue
- }
- }
- field.SetUint(defaultUint)
+ setFieldFromEnv(field, envKey, defaultVal)
+ }
- case reflect.Float32, reflect.Float64:
- defaultFloat, _ := strconv.ParseFloat(defaultVal, 64)
- field.SetFloat(getEnvFloat64(envKey, defaultFloat))
+ return nil
+}
- default:
- if field.Type() == reflect.TypeOf(time.Duration(0)) {
- defaultDuration, _ := time.ParseDuration(defaultVal)
- field.Set(reflect.ValueOf(getEnvDuration(envKey, defaultDuration)))
- } else {
- return fmt.Errorf("unsupported field type: %s", field.Kind())
- }
+func Defaults[T any](config *T) *T {
+ v := reflect.ValueOf(config)
+ if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
+ return config
+ }
+
+ elem := v.Elem()
+ t := elem.Type()
+ newStruct := reflect.New(t)
+ newElem := newStruct.Elem()
+
+ for i := range elem.NumField() {
+ field := newElem.Field(i)
+ fieldType := t.Field(i)
+
+ if !field.CanSet() {
+ continue
+ }
+
+ defaultVal := fieldType.Tag.Get("default")
+ if defaultVal == "" {
+ continue
}
+
+ setFieldDefault(field, defaultVal)
}
- return nil
+ return newStruct.Interface().(*T)
}
diff --git a/config/types.go b/config/types.go index fc4ad3c..7786915 100644 --- a/config/types.go +++ b/config/types.go @@ -12,12 +12,13 @@ type ServerConfig struct { }
type DatabaseConfig struct {
- Host string `env:"DB_HOST" default:"localhost"`
- Port int `env:"DB_PORT" default:"5432"`
- Username string `env:"DB_USERNAME" default:"postgres"`
- Password string `env:"DB_PASSWORD" default:""`
- DatabaseName string `env:"DB_NAME" default:"imageboard"`
- SSLMode string `env:"DB_SSLMODE" default:"disable"`
+ Host string `env:"DB_HOST" default:"localhost"`
+ Port int `env:"DB_PORT" default:"5432"`
+ Username string `env:"DB_USERNAME" default:"postgres"`
+ Password string `env:"DB_PASSWORD" default:""`
+ DatabaseName string `env:"DB_NAME" default:"imageboard"`
+ SSLMode string `env:"DB_SSLMODE" default:"disable"`
+ WipeAndResetDatabase bool `env:"DB_WIPE_AND_RESET" default:"false"`
}
type SessionConfig struct {
diff --git a/database/database.go b/database/database.go index d063cfc..d284c0f 100644 --- a/database/database.go +++ b/database/database.go @@ -3,6 +3,7 @@ package database import ( "fmt" "imageboard/config" + "imageboard/models" "log" "gorm.io/driver/postgres" @@ -16,27 +17,39 @@ var ( ) func init() { - dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s", + dsn := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s", config.Database.Host, + config.Database.Port, config.Database.Username, - config.Database.Password, config.Database.DatabaseName, - config.Database.Port, config.Database.SSLMode, ) + if config.Database.Password != "" { + dsn += fmt.Sprintf(" password=%s", config.Database.Password) + } + logLevel := logger.Silent if config.Server.IsDevMode { logLevel = logger.Info } - DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ + dialector := postgres.Open(dsn) + + DB, err = gorm.Open(dialector, &gorm.Config{ Logger: logger.Default.LogMode(logLevel), }) if err != nil { log.Fatalf("failed to connect to database: %v", err) } + if config.Server.IsDevMode && config.Database.WipeAndResetDatabase { + if err := wipeAndResetDatabase(); err != nil { + log.Fatalf("failed to wipe and reset database: %v", err) + } + log.Println("Database wiped and reset successfully") + } + if err := autoMigrate(); err != nil { log.Fatalf("failed to auto migrate database: %v", err) } @@ -45,5 +58,21 @@ func init() { } func autoMigrate() error { - return DB.AutoMigrate() + return DB.AutoMigrate( + &models.User{}, + &models.Image{}, + &models.ImageSize{}, + &models.Tag{}, + &models.Comment{}, + ) +} + +func wipeAndResetDatabase() error { + if err := DB.Exec("DROP SCHEMA public CASCADE").Error; err != nil { + return err + } + if err := DB.Exec("CREATE SCHEMA public").Error; err != nil { + return err + } + return nil } diff --git a/imageboard/main.go b/imageboard/main.go index 49eda9c..fbba38c 100644 --- a/imageboard/main.go +++ b/imageboard/main.go @@ -8,6 +8,8 @@ import ( "imageboard/router"
"log"
+ _ "imageboard/database"
+
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/helmet"
@@ -17,7 +19,7 @@ import ( )
func main() {
- if config.Server.AppSecret == "default_secret" {
+ if config.Server.AppSecret == config.Defaults(&config.Server).AppSecret {
log.Println("Warning: AppSecret is set to a default value which is not secure. Please set a strong random secret in your APP_SECRET environment variable or .env file.")
}
|
