You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package repo
  import (
  	"archive/zip"
  	"encoding/json"
  	"errors"
  	"fmt"
  	"io"
  	"net/http"
  	"path"
  	"strings"

  	"code.gitea.io/gitea/models"
  	"code.gitea.io/gitea/modules/context"
  	"code.gitea.io/gitea/modules/log"
  	"code.gitea.io/gitea/modules/setting"
  	"code.gitea.io/gitea/modules/storage"
  	uuid "github.com/satori/go.uuid"
  )
  // Model_prefix is the object-storage key prefix under which saved AI models are copied.
  const Model_prefix = "aimodels/"

  // Templates rendered by the model-management pages.
  const (
  	tplModelManageIndex    = "repo/modelmanage/index"
  	tplModelManageDownload = "repo/modelmanage/download"
  	tplModelInfo           = "repo/modelmanage/showinfo"
  )

  // Values stored in AiModelManage.New. MODEL_LATEST marks the newest saved version of a
  // model name within a repository; the previous newest is switched to MODEL_NOT_LATEST.
  const (
  	MODEL_LATEST     = 1
  	MODEL_NOT_LATEST = 0
  )
  // saveModelByParameters registers a new version of the model called name for the
  // training task identified by jobId/versionName.
  //
  // For CloudBrain-two tasks the task's output directory is first copied inside object
  // storage (downloadModelFromCloudBrainTwo) and Path/Size describe that copy. For any other
  // task type Path is just the generated model ID and Size is 0.
  //
  // The new record is saved with New = MODEL_LATEST and VersionCount equal to the number
  // of existing models with this name in the repository plus one. The model that was
  // previously flagged MODEL_LATEST, if any, is then switched to MODEL_NOT_LATEST.
  //
  // It returns the first error from the task lookup, the storage copy, JSON encoding of
  // the task, or the database insert.
  func saveModelByParameters(jobId string, versionName string, name string, version string, label string, description string, userId int64, userName string, userHeadUrl string) error {
  	aiTask, err := models.GetCloudbrainByJobIDAndVersionName(jobId, versionName)
  	if err != nil {
  		log.Info("query task error." + err.Error())
  		return err
  	}
  	log.Info("find task name:" + aiTask.JobName)

  	// Build the ID directly so no local variable shadows the uuid package.
  	id := uuid.NewV4().String()
  	modelPath := id
  	var modelSize int64

  	// Remember the current latest version so it can be demoted after the new one is saved.
  	aimodels := models.QueryModelByName(name, aiTask.RepoID)
  	var lastNewModelId string
  	for _, m := range aimodels {
  		if m.New == MODEL_LATEST {
  			lastNewModelId = m.ID
  		}
  	}

  	cloudType := aiTask.Type
  	// Only CloudBrain-two training output is copied into the model store.
  	if cloudType == models.TypeCloudBrainTwo {
  		modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  		if err != nil {
  			log.Info("download model from CloudBrainTwo faild." + err.Error())
  			return err
  		}
  	}

  	// All metric values start empty. Marshalling a map[string]string cannot fail.
  	accuracyJson, _ := json.Marshal(map[string]string{
  		"F1":        "",
  		"Recall":    "",
  		"Accuracy":  "",
  		"Precision": "",
  	})
  	log.Info("accuracyJson=" + string(accuracyJson))

  	// The full task is snapshotted as JSON. Do not silently store an empty snapshot.
  	aiTaskJson, err := json.Marshal(aiTask)
  	if err != nil {
  		log.Info("marshal train task error." + err.Error())
  		return err
  	}

  	model := &models.AiModelManage{
  		ID:                id,
  		Version:           version,
  		VersionCount:      len(aimodels) + 1,
  		Label:             label,
  		Name:              name,
  		Description:       description,
  		New:               MODEL_LATEST,
  		Type:              cloudType,
  		Path:              modelPath,
  		Size:              modelSize,
  		AttachmentId:      aiTask.Uuid,
  		RepoId:            aiTask.RepoID,
  		UserId:            userId,
  		UserName:          userName,
  		UserRelAvatarLink: userHeadUrl,
  		CodeBranch:        aiTask.BranchName,
  		CodeCommitID:      aiTask.CommitID,
  		Engine:            aiTask.EngineID,
  		TrainTaskInfo:     string(aiTaskJson),
  		Accuracy:          string(accuracyJson),
  	}
  	if err = models.SaveModelToDb(model); err != nil {
  		return err
  	}

  	if len(lastNewModelId) > 0 {
  		// Demote the previous latest version.
  		// NOTE(review): any result of ModifyModelNewProperty is not checked here.
  		models.ModifyModelNewProperty(lastNewModelId, MODEL_NOT_LATEST, 0)
  	}
  	log.Info("save model end.")
  	return nil
  }
  // SaveModel is the HTTP handler that saves a training job's output as a new model version.
  //
  // Query parameters:
  //   - JobId, VersionName, Name and Version are required.
  //   - Label and Description are optional.
  //
  // A missing required parameter is rejected with 400 Bad Request. A failure while saving
  // is reported with 500. The handler itself writes no body on success.
  func SaveModel(ctx *context.Context) {
  	log.Info("save model start.")
  	JobId := ctx.Query("JobId")
  	VersionName := ctx.Query("VersionName")
  	name := ctx.Query("Name")
  	version := ctx.Query("Version")
  	label := ctx.Query("Label")
  	description := ctx.Query("Description")

  	// Missing parameters are client errors, not server errors.
  	if JobId == "" || VersionName == "" {
  		ctx.Error(http.StatusBadRequest, "JobId or VersionName is null.")
  		return
  	}
  	if name == "" || version == "" {
  		ctx.Error(http.StatusBadRequest, "name or version is null.")
  		return
  	}

  	err := saveModelByParameters(JobId, VersionName, name, version, label, description, ctx.User.ID, ctx.User.Name, ctx.User.RelAvatarLink())
  	if err != nil {
  		log.Info("save model error." + err.Error())
  		ctx.Error(http.StatusInternalServerError, fmt.Sprintf("save model error. %v", err))
  		return
  	}
  	log.Info("save model end.")
  }
  // downloadModelFromCloudBrainTwo copies the output directory of the CloudBrain-two
  // training job jobName to the model-store prefix derived from modelUUID. When parentDir
  // is non-empty, only that sub-directory is copied. Source and destination are both in
  // setting.Bucket.
  //
  // On success it returns:
  //   - the stored location as "<bucket>/<destination prefix>"
  //   - the size reported by storage.ObsCopyManyFile
  //
  // It returns an error when listing the source fails, the source has no objects, or the
  // copy fails.
  func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  	objectkey := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
  	log.Info("bucket=" + setting.Bucket + " objectkey=" + objectkey)
  	modelDbResult, err := storage.GetOneLevelAllObjectUnderDir(setting.Bucket, objectkey, "")
  	if err != nil {
  		log.Info("get TrainJobListModel failed: %v", err)
  		return "", 0, err
  	}
  	if len(modelDbResult) == 0 {
  		return "", 0, errors.New("cannot create model, as model is empty.")
  	}

  	prefix := objectkey + "/"
  	destKeyNamePrefix := Model_prefix + models.AttachmentRelativePath(modelUUID) + "/"
  	size, err := storage.ObsCopyManyFile(setting.Bucket, prefix, setting.Bucket, destKeyNamePrefix)
  	if err != nil {
  		// A failed copy must not be recorded as a successfully saved model.
  		log.Info("copy model files failed: %v", err)
  		return "", 0, err
  	}
  	return setting.Bucket + "/" + destKeyNamePrefix, size, nil
  }
  139. func DeleteModel(ctx *context.Context) {
  140. log.Info("delete model start.")
  141. id := ctx.Query("ID")
  142. err := DeleteModelByID(id)
  143. if err != nil {
  144. ctx.JSON(500, err.Error())
  145. } else {
  146. ctx.JSON(200, map[string]string{
  147. "result_code": "0",
  148. })
  149. }
  150. }
  // DeleteModelByID deletes the model record with the given id.
  //
  // If the model's Path lies under "<bucket>/Model_prefix", the stored object is removed
  // first. A failure there aborts the deletion.
  // NOTE(review): ObsRemoveObject is given the directory key; whether that removes every
  // object beneath it depends on the storage package — confirm.
  //
  // When the deleted model was the latest version, the first entry returned by
  // models.QueryModelByName for the same name/repo is flagged MODEL_LATEST. That entry also
  // receives the remaining version count (presumably the newest remaining version —
  // confirm ordering in QueryModelByName).
  func DeleteModelByID(id string) error {
  	log.Info("delete model start. id=" + id)
  	model, err := models.QueryModelById(id)
  	if err != nil {
  		return err
  	}

  	log.Info("bucket=" + setting.Bucket + " path=" + model.Path)
  	bucketPrefix := setting.Bucket + "/"
  	if strings.HasPrefix(model.Path, bucketPrefix+Model_prefix) {
  		if removeErr := storage.ObsRemoveObject(setting.Bucket, model.Path[len(bucketPrefix):]); removeErr != nil {
  			log.Info("Failed to delete model. id=" + id)
  			return removeErr
  		}
  	}

  	if err = models.DeleteModelById(id); err != nil {
  		return err
  	}
  	if model.New != MODEL_LATEST {
  		return nil
  	}

  	// The latest version is gone: promote another version of the same model name.
  	remaining := models.QueryModelByName(model.Name, model.RepoId)
  	if len(remaining) > 0 {
  		models.ModifyModelNewProperty(remaining[0].ID, MODEL_LATEST, len(remaining))
  	}
  	return nil
  }
  176. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  177. return models.QueryModel(&models.AiModelQueryOptions{
  178. ListOptions: models.ListOptions{
  179. Page: page,
  180. PageSize: setting.UI.IssuePagingNum,
  181. },
  182. RepoID: repoId,
  183. Type: -1,
  184. New: MODEL_LATEST,
  185. })
  186. }
  187. func DownloadMultiModelFile(ctx *context.Context) {
  188. log.Info("DownloadMultiModelFile start.")
  189. id := ctx.Query("ID")
  190. log.Info("id=" + id)
  191. task, err := models.QueryModelById(id)
  192. if err != nil {
  193. log.Error("no such model!", err.Error())
  194. ctx.ServerError("no such model:", err)
  195. return
  196. }
  197. path := Model_prefix + models.AttachmentRelativePath(id) + "/"
  198. allFile, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, path)
  199. if err == nil {
  200. //count++
  201. models.ModifyModelDownloadCount(id)
  202. returnFileName := task.Name + "_" + task.Version + ".zip"
  203. ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+returnFileName)
  204. ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
  205. w := zip.NewWriter(ctx.Resp)
  206. defer w.Close()
  207. for _, oneFile := range allFile {
  208. if oneFile.IsDir {
  209. log.Info("zip dir name:" + oneFile.FileName)
  210. } else {
  211. log.Info("zip file name:" + oneFile.FileName)
  212. fDest, err := w.Create(oneFile.FileName)
  213. if err != nil {
  214. log.Info("create zip entry error, download file failed: %s\n", err.Error())
  215. ctx.ServerError("download file failed:", err)
  216. return
  217. }
  218. body, err := storage.ObsDownloadAFile(setting.Bucket, path+oneFile.FileName)
  219. if err != nil {
  220. log.Info("download file failed: %s\n", err.Error())
  221. ctx.ServerError("download file failed:", err)
  222. return
  223. } else {
  224. defer body.Close()
  225. p := make([]byte, 1024)
  226. var readErr error
  227. var readCount int
  228. // 读取对象内容
  229. for {
  230. readCount, readErr = body.Read(p)
  231. if readCount > 0 {
  232. fDest.Write(p[:readCount])
  233. }
  234. if readErr != nil {
  235. break
  236. }
  237. }
  238. }
  239. }
  240. }
  241. } else {
  242. log.Info("error,msg=" + err.Error())
  243. ctx.ServerError("no file to download.", err)
  244. }
  245. }
  // QueryTrainJobVersionList is the HTTP handler that returns, as JSON, the result of
  // models.QueryModelTrainJobVersionList for the "JobID" query parameter.
  // Lookup failures are reported through ctx.ServerError.
  func QueryTrainJobVersionList(ctx *context.Context) {
  	log.Info("query train job version list. start.")
  	jobID := ctx.Query("JobID")
  	versions, total, err := models.QueryModelTrainJobVersionList(jobID)
  	log.Info("query return count=" + fmt.Sprint(total))
  	if err != nil {
  		ctx.ServerError("QueryTrainJobList:", err)
  		return
  	}
  	ctx.JSON(200, versions)
  }
  // QueryTrainJobList is the HTTP handler that returns, as JSON, the result of
  // models.QueryModelTrainJobList for the "repoId" query parameter.
  // Lookup failures are reported through ctx.ServerError.
  func QueryTrainJobList(ctx *context.Context) {
  	log.Info("query train job list. start.")
  	tasks, total, err := models.QueryModelTrainJobList(ctx.QueryInt64("repoId"))
  	log.Info("query return count=" + fmt.Sprint(total))
  	if err != nil {
  		ctx.ServerError("QueryTrainJobList:", err)
  		return
  	}
  	ctx.JSON(200, tasks)
  }
  268. func DownloadSingleModelFile(ctx *context.Context) {
  269. log.Info("DownloadSingleModelFile start.")
  270. id := ctx.Params(":ID")
  271. parentDir := ctx.Query("parentDir")
  272. fileName := ctx.Query("fileName")
  273. path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName
  274. if setting.PROXYURL != "" {
  275. body, err := storage.ObsDownloadAFile(setting.Bucket, path)
  276. if err != nil {
  277. log.Info("download error.")
  278. } else {
  279. //count++
  280. models.ModifyModelDownloadCount(id)
  281. defer body.Close()
  282. ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName)
  283. ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
  284. p := make([]byte, 1024)
  285. var readErr error
  286. var readCount int
  287. // 读取对象内容
  288. for {
  289. readCount, readErr = body.Read(p)
  290. if readCount > 0 {
  291. ctx.Resp.Write(p[:readCount])
  292. //fmt.Printf("%s", p[:readCount])
  293. }
  294. if readErr != nil {
  295. break
  296. }
  297. }
  298. }
  299. } else {
  300. url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path)
  301. if err != nil {
  302. log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
  303. ctx.ServerError("GetObsCreateSignedUrl", err)
  304. return
  305. }
  306. //count++
  307. models.ModifyModelDownloadCount(id)
  308. http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
  309. }
  310. }
  // ShowModelInfo renders the model information page template (tplModelInfo) with status 200.
  func ShowModelInfo(ctx *context.Context) {
  	ctx.HTML(http.StatusOK, tplModelInfo)
  }
  // ShowSingleModel renders the model download page for the model identified by the ":ID"
  // route parameter. It lists one directory level of the model's stored files, optionally
  // below the "parentDir" query parameter.
  //
  // Template data:
  //   - "Dirs": the listing
  //   - "task": the model record
  //   - "ID": the requested id
  func ShowSingleModel(ctx *context.Context) {
  	id := ctx.Params(":ID")
  	parentDir := ctx.Query("parentDir")
  	log.Info("Show single ModelInfo start.id=" + id)
  	task, err := models.QueryModelById(id)
  	if err != nil {
  		log.Error("no such model: %v", err)
  		ctx.ServerError("no such model:", err)
  		return
  	}

  	// Copied models store Path as "<bucket>/<key>". Trimming the prefix, rather than slicing
  	// off len(bucket)+1 bytes, cannot panic on a short Path. Models that were not copied
  	// store only their ID as Path.
  	objectKey := strings.TrimPrefix(task.Path, setting.Bucket+"/")
  	log.Info("bucket=" + setting.Bucket + " key=" + objectKey)

  	// Named "files" so the models package is not shadowed.
  	files, err := storage.GetOneLevelAllObjectUnderDir(setting.Bucket, objectKey, parentDir)
  	if err != nil {
  		log.Info("get model list failed: %v", err)
  		ctx.ServerError("GetObsListObject:", err)
  		return
  	}
  	log.Info("get model file,size=" + fmt.Sprint(len(files)))

  	ctx.Data["Dirs"] = files
  	ctx.Data["task"] = task
  	ctx.Data["ID"] = id
  	ctx.HTML(200, tplModelManageDownload)
  }
  338. func ShowOneVersionOtherModel(ctx *context.Context) {
  339. repoId := ctx.Repo.Repository.ID
  340. name := ctx.Query("name")
  341. aimodels := models.QueryModelByName(name, repoId)
  342. if len(aimodels) > 0 {
  343. ctx.JSON(200, aimodels[1:])
  344. } else {
  345. ctx.JSON(200, aimodels)
  346. }
  347. }
  // ShowModelTemplate renders the model-management index page (tplModelManageIndex) with status 200.
  func ShowModelTemplate(ctx *context.Context) {
  	ctx.HTML(http.StatusOK, tplModelManageIndex)
  }
  // ShowModelPageInfo is the HTTP handler that returns one page of the current
  // repository's latest-version models as JSON: {"data": [...], "count": N}.
  // The "page" query parameter falls back to 1 when it is missing or not positive.
  // Query failures are reported through ctx.ServerError.
  func ShowModelPageInfo(ctx *context.Context) {
  	log.Info("ShowModelPageInfo start.")
  	page := ctx.QueryInt("page")
  	if page <= 0 {
  		page = 1
  	}
  	// Same options as QueryModelByParameters; reuse it instead of duplicating them here.
  	modelResult, count, err := QueryModelByParameters(ctx.Repo.Repository.ID, page)
  	if err != nil {
  		ctx.ServerError("Cloudbrain", err)
  		return
  	}
  	ctx.JSON(http.StatusOK, map[string]interface{}{
  		"data":  modelResult,
  		"count": count,
  	})
  }
  377. func ModifyModel(id string, description string) error {
  378. err := models.ModifyModelDescription(id, description)
  379. if err == nil {
  380. log.Info("modify success.")
  381. } else {
  382. log.Info("Failed to modify.id=" + id + " desc=" + description + " error:" + err.Error())
  383. }
  384. return err
  385. }
  // ModifyModelInfo is the HTTP handler that updates a model's description. The model is
  // named by the "ID" query parameter and the new text comes from "Description".
  //
  // NOTE(review): ctx.HTML takes a template name (see the tpl* constants), so "success" and
  // "Failed." are looked up as templates rather than written as body text. If no such
  // templates exist, the client receives a render error. Consider replying with ctx.JSON, as
  // DeleteModel does, after confirming what the frontend expects.
  func ModifyModelInfo(ctx *context.Context) {
  	// The log line identifies this handler; it previously read "delete model start.".
  	log.Info("modify model start.")
  	id := ctx.Query("ID")
  	description := ctx.Query("Description")
  	if err := ModifyModel(id, description); err != nil {
  		ctx.HTML(http.StatusInternalServerError, "Failed.")
  		return
  	}
  	ctx.HTML(http.StatusOK, "success")
  }