package logger import ( "context" "errors" "testing" "time" "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" ) // newObservedLogger creates a zap logger backed by an observer for test assertions. func newObservedLogger() (*zap.Logger, *observer.ObservedLogs) { core, recorded := observer.New(zapcore.DebugLevel) return zap.New(core), recorded } func TestNewGormLogger_DefaultLevel(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "") g, ok := gl.(*GormLogger) if !ok { t.Fatal("expected *GormLogger") } if g.level != zapcore.WarnLevel { t.Fatalf("expected warn level, got %v", g.level) } if g.silent { t.Fatal("expected silent=false") } } func TestNewGormLogger_Silent(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "silent") g, ok := gl.(*GormLogger) if !ok { t.Fatal("expected *GormLogger") } if !g.silent { t.Fatal("expected silent=true") } } func TestNewGormLogger_ExplicitLevel(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "error") g, ok := gl.(*GormLogger) if !ok { t.Fatal("expected *GormLogger") } if g.level != zapcore.ErrorLevel { t.Fatalf("expected error level, got %v", g.level) } } func TestNewGormLogger_InvalidLevel(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "bogus") g, ok := gl.(*GormLogger) if !ok { t.Fatal("expected *GormLogger") } // Should default to warn if g.level != zapcore.WarnLevel { t.Fatalf("expected warn level fallback, got %v", g.level) } } func TestGormLogger_Info(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "info") gl.Info(context.Background(), "test info", "key", "value") entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 log entry, got %d", len(entries)) } if entries[0].Message != "test info" { t.Fatalf("expected message 'test info', got %q", entries[0].Message) } if entries[0].Level != zapcore.InfoLevel { t.Fatalf("expected InfoLevel, got %v", entries[0].Level) } } func TestGormLogger_Warn(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") gl.Warn(context.Background(), "test warn", "code", 42) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 log entry, got %d", len(entries)) } if entries[0].Message != "test warn" { t.Fatalf("expected message 'test warn', got %q", entries[0].Message) } if entries[0].Level != zapcore.WarnLevel { t.Fatalf("expected WarnLevel, got %v", entries[0].Level) } } func TestGormLogger_Error(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "error") gl.Error(context.Background(), "test error", "module", "gorm") entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 log entry, got %d", len(entries)) } if entries[0].Message != "test error" { t.Fatalf("expected message 'test error', got %q", entries[0].Message) } if entries[0].Level != zapcore.ErrorLevel { t.Fatalf("expected ErrorLevel, got %v", entries[0].Level) } } func TestGormLogger_LevelFiltering(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") // Info should be suppressed at warn level gl.Info(context.Background(), "should be suppressed") if len(recorded.All()) != 0 { t.Fatal("info should be suppressed at warn level") } // Warn should pass gl.Warn(context.Background(), "should pass") if len(recorded.All()) != 1 { t.Fatal("warn should pass at warn level") } } func TestGormLogger_SilentSuppressesAll(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "silent") gl.Info(context.Background(), "info msg") gl.Warn(context.Background(), "warn msg") gl.Error(context.Background(), "error msg") if len(recorded.All()) != 0 { t.Fatalf("silent mode should suppress all logs, got %d", len(recorded.All())) } } func TestGormLogger_LogMode_ReturnsNewInstance(t *testing.T) { log, _ := newObservedLogger() original := NewGormLogger(log, "warn").(*GormLogger) modified := original.LogMode(gormlogger.Info) // Must be a different instance if modified == original { t.Fatal("LogMode should return a new instance") } // Original should be unchanged if original.level != zapcore.WarnLevel { t.Fatalf("original level should remain warn, got %v", original.level) } // New instance should have InfoLevel mod := modified.(*GormLogger) if mod.level != zapcore.InfoLevel { t.Fatalf("modified level should be info, got %v", mod.level) } } func TestGormLogger_LogMode_Silent(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "info").(*GormLogger) silent := gl.LogMode(gormlogger.Silent).(*GormLogger) if !silent.silent { t.Fatal("LogMode(Silent) should set silent=true") } // Original should not be silent if gl.silent { t.Fatal("original should not be affected") } } func TestGormLogger_LogMode_ErrorLevel(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "info").(*GormLogger) modified := gl.LogMode(gormlogger.Error).(*GormLogger) if modified.level != zapcore.ErrorLevel { t.Fatalf("expected ErrorLevel, got %v", modified.level) } } func TestGormLogger_LogMode_WarnLevel(t *testing.T) { log, _ := newObservedLogger() gl := NewGormLogger(log, "info").(*GormLogger) modified := gl.LogMode(gormlogger.Warn).(*GormLogger) if modified.level != zapcore.WarnLevel { t.Fatalf("expected WarnLevel, got %v", modified.level) } } func TestGormLogger_Trace_WithError(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") begin := time.Now() fc := func() (string, int64) { return "SELECT * FROM users", 0 } err := errors.New("connection refused") gl.Trace(context.Background(), begin, fc, err) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } if entries[0].Level != zapcore.ErrorLevel { t.Fatalf("expected ErrorLevel for real errors, got %v", entries[0].Level) } if entries[0].Message != "gorm query error" { t.Fatalf("unexpected message: %q", entries[0].Message) } } func TestGormLogger_Trace_ErrRecordNotFound(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "info") begin := time.Now() fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 } gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound) // ErrRecordNotFound should NOT be logged as error // At default level with no slow query, it should log as info entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } if entries[0].Level == zapcore.ErrorLevel { t.Fatal("ErrRecordNotFound should NOT be logged at ErrorLevel") } if entries[0].Message != "gorm query" { t.Fatalf("expected 'gorm query', got %q", entries[0].Message) } } func TestGormLogger_Trace_ErrRecordNotFound_SuppressedAtWarnLevel(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") begin := time.Now() fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 } gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound) // ErrRecordNotFound is not a real error, so it falls through to the default path. // At warn level, the default path (info-level) is suppressed. entries := recorded.All() if len(entries) != 0 { t.Fatalf("expected 0 entries at warn level, got %d", len(entries)) } } func TestGormLogger_Trace_SlowQuery(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") // Simulate a begin time far enough in the past to exceed the threshold begin := time.Now().Add(-500 * time.Millisecond) fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 } gl.Trace(context.Background(), begin, fc, nil) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } if entries[0].Level != zapcore.WarnLevel { t.Fatalf("expected WarnLevel for slow query, got %v", entries[0].Level) } if entries[0].Message != "gorm slow query" { t.Fatalf("expected 'gorm slow query', got %q", entries[0].Message) } } func TestGormLogger_Trace_NormalQuery(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "info") begin := time.Now() fc := func() (string, int64) { return "SELECT 1", 1 } gl.Trace(context.Background(), begin, fc, nil) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } if entries[0].Level != zapcore.InfoLevel { t.Fatalf("expected InfoLevel for normal query, got %v", entries[0].Level) } if entries[0].Message != "gorm query" { t.Fatalf("expected 'gorm query', got %q", entries[0].Message) } } func TestGormLogger_Trace_NormalQuerySuppressedAtWarnLevel(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") begin := time.Now() fc := func() (string, int64) { return "SELECT 1", 1 } gl.Trace(context.Background(), begin, fc, nil) // Normal queries are info-level, suppressed at warn if len(recorded.All()) != 0 { t.Fatal("normal query should be suppressed at warn level") } } func TestGormLogger_Trace_SilentSuppressesAll(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "silent") begin := time.Now().Add(-500 * time.Millisecond) fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 } err := errors.New("some error") gl.Trace(context.Background(), begin, fc, err) if len(recorded.All()) != 0 { t.Fatal("silent mode should suppress even Trace errors") } } func TestGormLogger_Trace_StructuredFields(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "info") begin := time.Now() fc := func() (string, int64) { return "SELECT * FROM users", 42 } gl.Trace(context.Background(), begin, fc, nil) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } // Verify structured fields fields := entries[0].ContextMap() if fields["sql"] != "SELECT * FROM users" { t.Fatalf("expected sql field, got %v", fields["sql"]) } if fields["rows"] != int64(42) { t.Fatalf("expected rows=42, got %v", fields["rows"]) } if _, ok := fields["elapsed"]; !ok { t.Fatal("expected elapsed field") } } func TestGormLogger_Trace_ErrorFields(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") begin := time.Now() fc := func() (string, int64) { return "INSERT INTO users", 0 } err := errors.New("duplicate key") gl.Trace(context.Background(), begin, fc, err) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } fields := entries[0].ContextMap() if fields["sql"] != "INSERT INTO users" { t.Fatalf("expected sql field, got %v", fields["sql"]) } if _, ok := fields["error"]; !ok { t.Fatal("expected error field") } } func TestGormLogger_Trace_SlowQueryFields(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "warn") begin := time.Now().Add(-500 * time.Millisecond) fc := func() (string, int64) { return "SELECT * FROM large_table", 1000 } gl.Trace(context.Background(), begin, fc, nil) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } fields := entries[0].ContextMap() if fields["sql"] != "SELECT * FROM large_table" { t.Fatalf("expected sql field, got %v", fields["sql"]) } if fields["rows"] != int64(1000) { t.Fatalf("expected rows=1000, got %v", fields["rows"]) } if _, ok := fields["elapsed_ms"]; !ok { t.Fatal("expected elapsed_ms field for slow query") } } func TestArgsToFields(t *testing.T) { tests := []struct { name string args []interface{} expected int }{ {"empty", nil, 0}, {"single_pair", []interface{}{"key", "value"}, 1}, {"two_pairs", []interface{}{"a", 1, "b", 2}, 2}, {"odd_args_ignores_last", []interface{}{"key", "value", "orphan"}, 1}, {"non_string_key_ignored", []interface{}{123, "value"}, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fields := argsToFields(tt.args) if len(fields) != tt.expected { t.Fatalf("expected %d fields, got %d", tt.expected, len(fields)) } }) } } func TestArgsToFields_FieldValues(t *testing.T) { fields := argsToFields([]interface{}{"name", "test", "count", 42}) if len(fields) != 2 { t.Fatalf("expected 2 fields, got %d", len(fields)) } if fields[0].Key != "name" { t.Fatalf("expected key 'name', got %q", fields[0].Key) } if fields[1].Key != "count" { t.Fatalf("expected key 'count', got %q", fields[1].Key) } } func TestParseGormLevel(t *testing.T) { tests := []struct { input string expected zapcore.Level }{ {"debug", zapcore.DebugLevel}, {"info", zapcore.InfoLevel}, {"warn", zapcore.WarnLevel}, {"error", zapcore.ErrorLevel}, {"", zapcore.WarnLevel}, {"silent", zapcore.WarnLevel}, {"invalid", zapcore.WarnLevel}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { got := parseGormLevel(tt.input) if got != tt.expected { t.Fatalf("expected %v, got %v", tt.expected, got) } }) } } func TestGormLogger_InfoWithMultipleFields(t *testing.T) { log, recorded := newObservedLogger() gl := NewGormLogger(log, "info") gl.Info(context.Background(), "multi fields", "key1", "val1", "key2", 123, "key3", true) entries := recorded.All() if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } fields := entries[0].ContextMap() if fields["key1"] != "val1" { t.Fatalf("expected key1=val1, got %v", fields["key1"]) } if fields["key2"] != int64(123) { t.Fatalf("expected key2=123, got %v", fields["key2"]) } if fields["key3"] != true { t.Fatalf("expected key3=true, got %v", fields["key3"]) } }