aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBobby <[email protected]>2025-07-07 11:58:04 +0530
committerBobby <[email protected]>2025-07-07 11:58:04 +0530
commit9e4d6b1e271032d14e55f16395343979276e8de5 (patch)
treeadb0915e2854113c401960e3c1f8cf97e5e7aca8
parent48c6de533c459a1bb923f292e43914689b1357df (diff)
downloadimageboard-9e4d6b1e271032d14e55f16395343979276e8de5.tar.xz
imageboard-9e4d6b1e271032d14e55f16395343979276e8de5.zip
refactored config system with generic `Defaults` function, fixed database migration setup, and applied dry principles
-rw-r--r--config/functions.go146
-rw-r--r--config/types.go13
-rw-r--r--database/database.go39
-rw-r--r--imageboard/main.go4
4 files changed, 152 insertions, 50 deletions
diff --git a/config/functions.go b/config/functions.go
index 0ec2d8e..28ef3f7 100644
--- a/config/functions.go
+++ b/config/functions.go
@@ -51,17 +51,90 @@ func getEnvFloat64(key string, defaultVal float64) float64 {
return defaultVal
}
-func Parse(config interface{}) error {
+func setFieldFromEnv(field reflect.Value, envKey, defaultVal string) {
+ switch field.Kind() {
+ case reflect.String:
+ field.SetString(getEnv(envKey, defaultVal))
+ case reflect.Bool:
+ defaultBool, _ := strconv.ParseBool(defaultVal)
+ field.SetBool(getEnvBool(envKey, defaultBool))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ defaultInt, _ := strconv.ParseInt(defaultVal, 10, 64)
+ field.SetInt(getEnvInt64(envKey, defaultInt))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ defaultUint, _ := strconv.ParseUint(defaultVal, 10, 64)
+ setUintField(field, envKey, defaultUint)
+ case reflect.Float32, reflect.Float64:
+ defaultFloat, _ := strconv.ParseFloat(defaultVal, 64)
+ field.SetFloat(getEnvFloat64(envKey, defaultFloat))
+ default:
+ setDurationField(field, envKey, defaultVal)
+ }
+}
+
+func setUintField(field reflect.Value, envKey string, defaultVal uint64) {
+ if value := os.Getenv(envKey); value != "" {
+ if parsed, err := strconv.ParseUint(value, 10, 64); err == nil {
+ field.SetUint(parsed)
+ return
+ }
+ }
+ field.SetUint(defaultVal)
+}
+
+func setDurationField(field reflect.Value, envKey, defaultVal string) {
+ if field.Type() == reflect.TypeOf(time.Duration(0)) {
+ defaultDuration, _ := time.ParseDuration(defaultVal)
+ field.Set(reflect.ValueOf(getEnvDuration(envKey, defaultDuration)))
+ }
+}
+
+func setFieldDefault(field reflect.Value, defaultVal string) {
+ switch field.Kind() {
+ case reflect.String:
+ field.SetString(defaultVal)
+ case reflect.Bool:
+ if defaultBool, err := strconv.ParseBool(defaultVal); err == nil {
+ field.SetBool(defaultBool)
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if defaultInt, err := strconv.ParseInt(defaultVal, 10, 64); err == nil {
+ field.SetInt(defaultInt)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if defaultUint, err := strconv.ParseUint(defaultVal, 10, 64); err == nil {
+ field.SetUint(defaultUint)
+ }
+ case reflect.Float32, reflect.Float64:
+ if defaultFloat, err := strconv.ParseFloat(defaultVal, 64); err == nil {
+ field.SetFloat(defaultFloat)
+ }
+ default:
+ if field.Type() == reflect.TypeOf(time.Duration(0)) {
+ if defaultDuration, err := time.ParseDuration(defaultVal); err == nil {
+ field.Set(reflect.ValueOf(defaultDuration))
+ }
+ }
+ }
+}
+
+func validateConfigInput(config any) (reflect.Value, reflect.Type, error) {
v := reflect.ValueOf(config)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
- return fmt.Errorf("config must be a pointer to struct")
+ return reflect.Value{}, nil, fmt.Errorf("config must be a pointer to struct")
}
+ elem := v.Elem()
+ return elem, elem.Type(), nil
+}
- v = v.Elem()
- t := v.Type()
+func Parse(config any) error {
+ elem, t, err := validateConfigInput(config)
+ if err != nil {
+ return err
+ }
- for i := range v.NumField() {
- field := v.Field(i)
+ for i := range elem.NumField() {
+ field := elem.Field(i)
fieldType := t.Field(i)
if !field.CanSet() {
@@ -75,41 +148,38 @@ func Parse(config interface{}) error {
continue
}
- switch field.Kind() {
- case reflect.String:
- field.SetString(getEnv(envKey, defaultVal))
-
- case reflect.Bool:
- defaultBool, _ := strconv.ParseBool(defaultVal)
- field.SetBool(getEnvBool(envKey, defaultBool))
-
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- defaultInt, _ := strconv.ParseInt(defaultVal, 10, 64)
- field.SetInt(getEnvInt64(envKey, defaultInt))
-
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- defaultUint, _ := strconv.ParseUint(defaultVal, 10, 64)
- if value := os.Getenv(envKey); value != "" {
- if parsed, err := strconv.ParseUint(value, 10, 64); err == nil {
- field.SetUint(parsed)
- continue
- }
- }
- field.SetUint(defaultUint)
+ setFieldFromEnv(field, envKey, defaultVal)
+ }
- case reflect.Float32, reflect.Float64:
- defaultFloat, _ := strconv.ParseFloat(defaultVal, 64)
- field.SetFloat(getEnvFloat64(envKey, defaultFloat))
+ return nil
+}
- default:
- if field.Type() == reflect.TypeOf(time.Duration(0)) {
- defaultDuration, _ := time.ParseDuration(defaultVal)
- field.Set(reflect.ValueOf(getEnvDuration(envKey, defaultDuration)))
- } else {
- return fmt.Errorf("unsupported field type: %s", field.Kind())
- }
+func Defaults[T any](config *T) *T {
+ v := reflect.ValueOf(config)
+ if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
+ return config
+ }
+
+ elem := v.Elem()
+ t := elem.Type()
+ newStruct := reflect.New(t)
+ newElem := newStruct.Elem()
+
+ for i := range elem.NumField() {
+ field := newElem.Field(i)
+ fieldType := t.Field(i)
+
+ if !field.CanSet() {
+ continue
+ }
+
+ defaultVal := fieldType.Tag.Get("default")
+ if defaultVal == "" {
+ continue
}
+
+ setFieldDefault(field, defaultVal)
}
- return nil
+ return newStruct.Interface().(*T)
}
diff --git a/config/types.go b/config/types.go
index fc4ad3c..7786915 100644
--- a/config/types.go
+++ b/config/types.go
@@ -12,12 +12,13 @@ type ServerConfig struct {
}
type DatabaseConfig struct {
- Host string `env:"DB_HOST" default:"localhost"`
- Port int `env:"DB_PORT" default:"5432"`
- Username string `env:"DB_USERNAME" default:"postgres"`
- Password string `env:"DB_PASSWORD" default:""`
- DatabaseName string `env:"DB_NAME" default:"imageboard"`
- SSLMode string `env:"DB_SSLMODE" default:"disable"`
+ Host string `env:"DB_HOST" default:"localhost"`
+ Port int `env:"DB_PORT" default:"5432"`
+ Username string `env:"DB_USERNAME" default:"postgres"`
+ Password string `env:"DB_PASSWORD" default:""`
+ DatabaseName string `env:"DB_NAME" default:"imageboard"`
+ SSLMode string `env:"DB_SSLMODE" default:"disable"`
+ WipeAndResetDatabase bool `env:"DB_WIPE_AND_RESET" default:"false"`
}
type SessionConfig struct {
diff --git a/database/database.go b/database/database.go
index d063cfc..d284c0f 100644
--- a/database/database.go
+++ b/database/database.go
@@ -3,6 +3,7 @@ package database
import (
"fmt"
"imageboard/config"
+ "imageboard/models"
"log"
"gorm.io/driver/postgres"
@@ -16,27 +17,39 @@ var (
)
func init() {
- dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s",
+ dsn := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
config.Database.Host,
+ config.Database.Port,
config.Database.Username,
- config.Database.Password,
config.Database.DatabaseName,
- config.Database.Port,
config.Database.SSLMode,
)
+ if config.Database.Password != "" {
+ dsn += fmt.Sprintf(" password=%s", config.Database.Password)
+ }
+
logLevel := logger.Silent
if config.Server.IsDevMode {
logLevel = logger.Info
}
- DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
+ dialector := postgres.Open(dsn)
+
+ DB, err = gorm.Open(dialector, &gorm.Config{
Logger: logger.Default.LogMode(logLevel),
})
if err != nil {
log.Fatalf("failed to connect to database: %v", err)
}
+ if config.Server.IsDevMode && config.Database.WipeAndResetDatabase {
+ if err := wipeAndResetDatabase(); err != nil {
+ log.Fatalf("failed to wipe and reset database: %v", err)
+ }
+ log.Println("Database wiped and reset successfully")
+ }
+
if err := autoMigrate(); err != nil {
log.Fatalf("failed to auto migrate database: %v", err)
}
@@ -45,5 +58,21 @@ func init() {
}
func autoMigrate() error {
- return DB.AutoMigrate()
+ return DB.AutoMigrate(
+ &models.User{},
+ &models.Image{},
+ &models.ImageSize{},
+ &models.Tag{},
+ &models.Comment{},
+ )
+}
+
+func wipeAndResetDatabase() error {
+ if err := DB.Exec("DROP SCHEMA public CASCADE").Error; err != nil {
+ return err
+ }
+ if err := DB.Exec("CREATE SCHEMA public").Error; err != nil {
+ return err
+ }
+ return nil
}
diff --git a/imageboard/main.go b/imageboard/main.go
index 49eda9c..fbba38c 100644
--- a/imageboard/main.go
+++ b/imageboard/main.go
@@ -8,6 +8,8 @@ import (
"imageboard/router"
"log"
+ _ "imageboard/database"
+
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/helmet"
@@ -17,7 +19,7 @@ import (
)
func main() {
- if config.Server.AppSecret == "default_secret" {
+ if config.Server.AppSecret == config.Defaults(&config.Server).AppSecret {
log.Println("Warning: AppSecret is set to a default value which is not secure. Please set a strong random secret in your APP_SECRET environment variable or .env file.")
}