package dao import ( _ "github.com/go-sql-driver/mysql" "github.com/spf13/viper" "os" "os/signal" "strings" "sync" "time" "xorm.io/xorm" ) var ( dbConnMap = make(map[string]*DBInfo) dbOnce sync.Once ) type DBInfo struct { Name string DB *xorm.Engine } func init() { viper.SetDefault("db.global.max_idle", 4000) viper.SetDefault("db.global.max_open", 8000) viper.SetDefault("db.global.max_lifetime", 8000000000) viper.SetDefault("db.spider.host", "root:123456@tcp(localhost:3306)") viper.SetDefault("db.spider.name", "spider") } // 获取特定数据库访问对象的实例 func DB(name string) *xorm.Engine { if len(strings.TrimSpace(name)) == 0 { return nil } openDB() if info, ok := dbConnMap[name]; ok { return info.DB } else { println("DB "+name+" disconnected") return nil } } // 打开所有数据库 func openDB() { dbOnce.Do(func() { // 一次程序运行仅执行一次,以达到“单例”的效果 readDBConfigAndOpen() go func() { defer closeDB() quit := make(chan os.Signal) signal.Notify(quit, os.Interrupt) <-quit }() }) } // 关闭所有数据库 func closeDB() { if nil == dbConnMap || len(dbConnMap) == 0 { return } for _, dbInfo := range dbConnMap { if nil != dbInfo && nil != dbInfo.DB { err := dbInfo.DB.Close() if nil != err { println("Close db "+dbInfo.Name+" failed:", err.Error()) continue } println("Close db " + dbInfo.Name) } } } // 读取配置文件并进行数据库连接,方法名首字母小写即为包私有方法,否则若大写则为公共方法 func readDBConfigAndOpen() { dbs := viper.GetStringMap("db") if nil == dbs || len(dbs) == 0 { println("not found any DB config") return } globalMaxIdle := viper.GetInt("db.global.max_idle") globalMaxOpen := viper.GetInt("db.global.max_open") globalMaxLifetime := viper.GetDuration("db.global.max_lifetime") for name := range dbs { key := name if key == "global" { continue } cnf := dbs[key].(map[string]interface{}) if nil == cnf { println("wrong DB config") return } var dbHost, dbName, maxIdle, maxOpen, maxLifetime interface{} var ok bool if dbHost, ok = cnf["host"]; !ok { println("wrong DB host config") return } if dbName, ok = cnf["name"]; !ok { println("wrong DB name config") return } if maxIdle, ok = cnf["max_idle"]; !ok { maxIdle = globalMaxIdle } if maxOpen, ok = cnf["max_open"]; !ok { maxOpen = globalMaxOpen } if maxLifetime, ok = cnf["max_lifetime"]; !ok { maxLifetime = globalMaxLifetime } newDB(name, dbHost.(string), dbName.(string), maxIdle.(int), maxOpen.(int), maxLifetime.(time.Duration)) } } func newDB(name string, dbHost string, dbName string, maxIdle int, maxOpen int, maxLifetime time.Duration) { db, err := xorm.NewEngine("mysql", dbHost+"/"+dbName+"?charset=utf8&parseTime=True&loc=Local") if err != nil { println("Open DB conn failed", err.Error()) return } // 设置连接池的空闲数大小 if maxIdle > 0 { db.SetMaxIdleConns(maxIdle) } // 设置最大打开连接数 if maxOpen > 0 { db.SetMaxOpenConns(maxOpen) } // 设置最大连接超时时间 if maxLifetime > 0 { db.SetConnMaxLifetime(maxLifetime) } // 连接测试 if err = db.DB().Ping(); err != nil { println("Can not conn to DB", err.Error()) return } println("Open db " + dbName) dbConnMap[name] = &DBInfo{Name: dbName, DB: db} }