base.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package dao
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/astaxie/beego/logs"
  6. "github.com/gogf/gf/v2/util/gconv"
  7. "gorm.io/gorm"
  8. )
  // Condition lists the optional clauses that Split layers onto a query.
  // Every field is skipped when it holds its zero value, so an empty
  // Condition adds nothing beyond the Base's own table.
  type Condition struct {
  	// Table, when non-empty, replaces the Base's default table name.
  	Table string
  	// Where holds query/argument pairs; each pair is passed to gorm's
  	// Where as (key, value).
  	Where map[string]any
  	// OrderBy is passed verbatim to gorm's Order.
  	OrderBy string
  	// Limit caps the number of rows; 0 means "not set". Count sets it to -1.
  	Limit int
  	// Fields restricts the selected columns; empty selects everything.
  	Fields []string
  	// Offset skips leading rows; 0 means "not set". Count sets it to -1.
  	Offset int
  	// GroupBy and Having are passed verbatim to gorm's Group and Having.
  	GroupBy, Having string
  }
  19. type Base struct {
  20. TableName string `gorm:"-" json:"-"`
  21. Db *gorm.DB
  22. }
  23. func newBase(tableName string, db *gorm.DB) *Base {
  24. b := Base{TableName: tableName, Db: db}
  25. return &b
  26. }
  27. func (c Base) DB() *gorm.DB {
  28. return c.Db
  29. }
  30. func (c Base) Split(sqlSplit *Condition) *gorm.DB {
  31. dbcon := c.DB()
  32. //if cfg.Debug.Mysql {
  33. // dbcon = dbcon.Debug()
  34. //}
  35. dbcon = dbcon.Table(c.TableName)
  36. if sqlSplit == nil {
  37. return dbcon
  38. }
  39. if sqlSplit.Table != "" {
  40. dbcon = dbcon.Table(sqlSplit.Table)
  41. }
  42. for key, val := range sqlSplit.Where {
  43. dbcon = dbcon.Where(key, val)
  44. }
  45. if len(sqlSplit.Fields) > 0 {
  46. dbcon = dbcon.Select(sqlSplit.Fields)
  47. }
  48. if sqlSplit.OrderBy != "" {
  49. dbcon = dbcon.Order(sqlSplit.OrderBy)
  50. }
  51. if sqlSplit.Limit != 0 {
  52. dbcon = dbcon.Limit(sqlSplit.Limit)
  53. }
  54. if sqlSplit.Offset != 0 {
  55. dbcon = dbcon.Offset(sqlSplit.Offset)
  56. }
  57. if sqlSplit.GroupBy != "" {
  58. dbcon = dbcon.Group(sqlSplit.GroupBy)
  59. }
  60. if sqlSplit.Having != "" {
  61. dbcon = dbcon.Having(sqlSplit.Having)
  62. }
  63. return dbcon
  64. }
  65. // Find
  66. // @Description: 根据条件查询数据库
  67. // @receiver c
  68. // @param sqlCondition 条件
  69. // @param data 对象指针,或者对象切片
  70. // @return error
  71. func (c Base) Find(sqlCondition *Condition, data any) error {
  72. return c.Split(sqlCondition).Find(data).Error
  73. }
  74. func (c Base) Raw(sql string, data any) error {
  75. return c.DB().Raw(sql).Scan(data).Error
  76. }
  77. func (c Base) Count(sqlCondition *Condition, num *int64) error {
  78. sqlCondition.Limit = -1
  79. sqlCondition.Offset = -1
  80. return c.Split(sqlCondition).Count(num).Error
  81. }
  82. func (c Base) Update(sqlCondition *Condition, data any) error {
  83. return c.Split(sqlCondition).Updates(data).Error
  84. }
  85. func (c Base) Delete(data map[string]any) error {
  86. if data == nil || len(data) == 0 {
  87. return errors.New("删除参数异常")
  88. }
  89. return c.DB().Table(c.TableName).Where(data).Delete(c).Error
  90. }
  91. func (c Base) Deletes(cond *Condition) error {
  92. return c.Split(cond).Delete(c).Error
  93. }
  94. func (c Base) Save(data any) error {
  95. return c.DB().Table(c.TableName).Save(data).Error
  96. }
  97. // Replace
  98. // @Description: 1、替代insert;2、原子更新某个字段
  99. // @receiver c
  100. // @param cond
  101. // @return error
  102. func (c Base) Replace(cond map[string]any) error {
  103. if id, ok := cond["id"]; !ok || gconv.Int(id) == 0 {
  104. delete(cond, "id")
  105. }
  106. field := ""
  107. values := ""
  108. updateStr := ""
  109. for key, val := range cond {
  110. field = fmt.Sprintf("%s,`%s`", field, key)
  111. values = fmt.Sprintf("%s,'%s'", values, gconv.String(val))
  112. updateStr = fmt.Sprintf("%s,`%s` = values(`%s`)", updateStr, key, key)
  113. }
  114. field = field[1:]
  115. values = values[1:]
  116. updateStr = updateStr[1:]
  117. sql := fmt.Sprintf(`insert into %s (%s) values (%s) ON DUPLICATE KEY UPDATE %s`, c.TableName, field, values, updateStr)
  118. return c.Exec(sql)
  119. }
  120. func (c Base) Insert(data any) error {
  121. return c.DB().Table(c.TableName).Create(data).Error
  122. }
  123. func (c Base) Exec(sql string) error {
  124. return c.DB().Exec(sql).Error
  125. }
  126. func (c Base) Truncate(table string) error {
  127. sql := fmt.Sprintf(`truncate table %s`, table)
  128. return c.Exec(sql)
  129. }
  130. func (c Base) GetValueByField(tableName, field, value, GetField string) string {
  131. var (
  132. resp = ""
  133. err error
  134. )
  135. sql := fmt.Sprintf(`select %s from %s where %s = "%s"`, GetField, tableName, field, value)
  136. err = c.Raw(sql, &resp)
  137. if err != nil {
  138. logs.Error("sql:%s查询数据异常 err:%s", sql, err.Error())
  139. }
  140. return resp
  141. }