You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package repo
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "path"
  7. "strings"
  8. "code.gitea.io/gitea/models"
  9. "code.gitea.io/gitea/modules/context"
  10. "code.gitea.io/gitea/modules/log"
  11. "code.gitea.io/gitea/modules/setting"
  12. "code.gitea.io/gitea/modules/storage"
  13. uuid "github.com/satori/go.uuid"
  14. )
  // Model_prefix is the object-storage key prefix under which every managed
  // model's files are copied (see downloadModelFromCloudBrainTwo).
  const Model_prefix = "aimodels/"

  // Templates rendered by the model-management pages.
  const (
  	tplModelManageIndex    = "repo/modelmanage/index"
  	tplModelManageDownload = "repo/modelmanage/download"
  )

  // Values stored in AiModelManage.New. A freshly saved model version is marked
  // MODEL_LATEST and the version it supersedes is set to MODEL_NOT_LATEST; the
  // index page lists only MODEL_LATEST records.
  const (
  	MODEL_LATEST     = 1
  	MODEL_NOT_LATEST = 0
  )
  22. func SaveModelByParameters(jobId string, name string, version string, label string, description string, userId int64) error {
  23. aiTask, err := models.GetCloudbrainByJobID(jobId)
  24. if err != nil {
  25. log.Info("query task error." + err.Error())
  26. return err
  27. }
  28. uuid := uuid.NewV4()
  29. id := uuid.String()
  30. modelPath := id
  31. var lastNewModelId string
  32. var modelSize int64
  33. cloudType := models.TypeCloudBrainTwo
  34. log.Info("find task name:" + aiTask.JobName)
  35. aimodels := models.QueryModelByName(name, aiTask.RepoID)
  36. if len(aimodels) > 0 {
  37. for _, model := range aimodels {
  38. if model.New == MODEL_LATEST {
  39. lastNewModelId = model.ID
  40. }
  41. }
  42. }
  43. cloudType = aiTask.Type
  44. //download model zip //train type
  45. if cloudType == models.TypeCloudBrainTwo {
  46. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  47. if err != nil {
  48. log.Info("download model from CloudBrainTwo faild." + err.Error())
  49. return err
  50. }
  51. }
  52. model := &models.AiModelManage{
  53. ID: id,
  54. Version: version,
  55. Label: label,
  56. Name: name,
  57. Description: description,
  58. New: MODEL_LATEST,
  59. Type: cloudType,
  60. Path: modelPath,
  61. Size: modelSize,
  62. AttachmentId: aiTask.Uuid,
  63. RepoId: aiTask.RepoID,
  64. UserId: userId,
  65. }
  66. models.SaveModelToDb(model)
  67. if len(lastNewModelId) > 0 {
  68. //udpate status
  69. models.ModifyModelNewProperty(lastNewModelId, MODEL_NOT_LATEST)
  70. }
  71. log.Info("save model end.")
  72. return nil
  73. }
  74. func SaveModel(ctx *context.Context) {
  75. log.Info("save model start.")
  76. JobId := ctx.Query("JobId")
  77. name := ctx.Query("Name")
  78. version := ctx.Query("Version")
  79. label := ctx.Query("Label")
  80. description := ctx.Query("Description")
  81. err := SaveModelByParameters(JobId, name, version, label, description, ctx.User.ID)
  82. if err != nil {
  83. log.Info("save model error." + err.Error())
  84. ctx.Error(500, fmt.Sprintf("save model error. %v", err))
  85. return
  86. }
  87. log.Info("save model end.")
  88. }
  89. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  90. objectkey := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
  91. modelDbResult, err := storage.GetAllObsListObjectUnderDir(setting.Bucket, objectkey)
  92. if err != nil {
  93. log.Info("get TrainJobListModel failed:", err)
  94. return "", 0, err
  95. }
  96. if len(modelDbResult) == 0 {
  97. return "", 0, errors.New("cannot create model, as model is empty.")
  98. }
  99. prefix := objectkey + "/"
  100. destKeyNamePrefix := Model_prefix + models.AttachmentRelativePath(modelUUID) + "/"
  101. size, err := storage.ObsCopyManyFile(setting.Bucket, prefix, setting.Bucket, destKeyNamePrefix)
  102. // for _, modelFile := range modelDbResult {
  103. // if modelFile.IsDir {
  104. // log.Info("copy dir, continue. dir=" + modelFile.FileName)
  105. // continue
  106. // }
  107. // srcKeyName := prefix + modelFile.FileName
  108. // log.Info("copy file, bucket=" + setting.Bucket + ", src keyname=" + srcKeyName)
  109. // destKeyName := destKeyNamePrefix + modelFile.FileName
  110. // log.Info("Dest key name=" + destKeyName)
  111. // err := storage.ObsCopyFile(setting.Bucket, srcKeyName, setting.Bucket, destKeyName)
  112. // if err != nil {
  113. // log.Info("copy failed.")
  114. // }
  115. // size += modelFile.Size
  116. // }
  117. dataActualPath := setting.Bucket + "/" + destKeyNamePrefix
  118. return dataActualPath, size, nil
  119. }
  120. func DeleteModel(ctx *context.Context) {
  121. log.Info("delete model start.")
  122. id := ctx.Query("ID")
  123. err := DeleteModelByID(id)
  124. if err != nil {
  125. ctx.JSON(500, err.Error())
  126. } else {
  127. ctx.JSON(200, map[string]string{
  128. "result_code": "0",
  129. })
  130. }
  131. }
  132. func DeleteModelByID(id string) error {
  133. log.Info("delete model start. id=" + id)
  134. model, err := models.QueryModelById(id)
  135. if err == nil {
  136. log.Info("bucket=" + setting.Bucket + " path=" + model.Path)
  137. if strings.HasPrefix(model.Path, setting.Bucket+"/"+Model_prefix) {
  138. err := storage.ObsRemoveObject(setting.Bucket, model.Path[len(setting.Bucket)+1:])
  139. if err != nil {
  140. log.Info("Failed to delete model. id=" + id)
  141. return err
  142. }
  143. }
  144. err = models.DeleteModelById(id)
  145. if err == nil { //find a model to change new
  146. if model.New == MODEL_LATEST {
  147. aimodels := models.QueryModelByName(model.Name, model.RepoId)
  148. if len(aimodels) > 0 {
  149. models.ModifyModelNewProperty(aimodels[0].ID, MODEL_LATEST)
  150. }
  151. }
  152. }
  153. }
  154. return err
  155. }
  156. func DownloadModel(ctx *context.Context) {
  157. log.Info("download model start.")
  158. }
  159. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  160. return models.QueryModel(&models.AiModelQueryOptions{
  161. ListOptions: models.ListOptions{
  162. Page: page,
  163. PageSize: setting.UI.IssuePagingNum,
  164. },
  165. RepoID: repoId,
  166. Type: -1,
  167. New: MODEL_LATEST,
  168. })
  169. }
  170. func DownloadMultiModelFile(ctx *context.Context) {
  171. log.Info("DownloadMultiModelFile start.")
  172. id := ctx.Query("ID")
  173. log.Info("id=" + id)
  174. }
  175. func DownloadSingleModelFile(ctx *context.Context) {
  176. log.Info("DownloadSingleModelFile start.")
  177. id := ctx.Params(":ID")
  178. parentDir := ctx.Query("parentDir")
  179. fileName := ctx.Query("fileName")
  180. path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName
  181. url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path)
  182. if err != nil {
  183. log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
  184. ctx.ServerError("GetObsCreateSignedUrl", err)
  185. return
  186. }
  187. http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
  188. }
  189. func ShowSingleModel(ctx *context.Context) {
  190. id := ctx.Params(":ID")
  191. parentDir := ctx.Query("parentDir")
  192. log.Info("Show single ModelInfo start.id=" + id)
  193. task, err := models.QueryModelById(id)
  194. if err != nil {
  195. log.Error("no such model!", err.Error())
  196. ctx.ServerError("no such model:", err)
  197. return
  198. }
  199. log.Info("bucket=" + setting.Bucket + " key=" + task.Path[len(setting.Bucket)+1:] + parentDir)
  200. models, err := storage.GetAllObsListObjectUnderDir(setting.Bucket, task.Path[len(setting.Bucket)+1:]+parentDir)
  201. if err != nil {
  202. log.Info("get model list failed:", err)
  203. ctx.ServerError("GetObsListObject:", err)
  204. return
  205. } else {
  206. log.Info("get model file,size=" + fmt.Sprint(len(models)))
  207. }
  208. ctx.Data["Dirs"] = models
  209. ctx.Data["task"] = task
  210. ctx.Data["ID"] = id
  211. ctx.HTML(200, tplModelManageDownload)
  212. }
  213. func ShowOneVersionOtherModel(ctx *context.Context) {
  214. repoId := ctx.Repo.Repository.ID
  215. name := ctx.Query("name")
  216. aimodels := models.QueryModelByName(name, repoId)
  217. if len(aimodels) > 0 {
  218. ctx.JSON(200, aimodels[1:])
  219. } else {
  220. ctx.JSON(200, aimodels)
  221. }
  222. }
  223. func ShowModelPageInfo(ctx *context.Context) {
  224. log.Info("ShowModelInfo start.")
  225. page := ctx.QueryInt("page")
  226. if page <= 0 {
  227. page = 1
  228. }
  229. repoId := ctx.Repo.Repository.ID
  230. Type := -1
  231. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  232. ListOptions: models.ListOptions{
  233. Page: page,
  234. PageSize: setting.UI.IssuePagingNum,
  235. },
  236. RepoID: repoId,
  237. Type: Type,
  238. New: MODEL_LATEST,
  239. })
  240. if err != nil {
  241. ctx.ServerError("Cloudbrain", err)
  242. return
  243. }
  244. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  245. pager.SetDefaultParams(ctx)
  246. ctx.Data["Page"] = pager
  247. ctx.Data["PageIsCloudBrain"] = true
  248. ctx.Data["Tasks"] = modelResult
  249. ctx.HTML(200, tplModelManageIndex)
  250. }
  251. func ModifyModel(id string, description string) error {
  252. err := models.ModifyModelDescription(id, description)
  253. if err == nil {
  254. log.Info("modify success.")
  255. } else {
  256. log.Info("Failed to modify.id=" + id + " desc=" + description + " error:" + err.Error())
  257. }
  258. return err
  259. }
  260. func ModifyModelInfo(ctx *context.Context) {
  261. log.Info("delete model start.")
  262. id := ctx.Query("ID")
  263. description := ctx.Query("Description")
  264. err := ModifyModel(id, description)
  265. if err == nil {
  266. ctx.HTML(200, "success")
  267. } else {
  268. ctx.HTML(500, "Failed.")
  269. }
  270. }