You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.go 40 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. package cloudbrainTask
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net/http"
  9. "os"
  10. "path"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "code.gitea.io/gitea/modules/urfs_client/urchin"
  15. "code.gitea.io/gitea/modules/timeutil"
  16. "code.gitea.io/gitea/modules/notification"
  17. "code.gitea.io/gitea/modules/obs"
  18. "code.gitea.io/gitea/modules/git"
  19. "code.gitea.io/gitea/modules/storage"
  20. "github.com/unknwon/com"
  21. "code.gitea.io/gitea/models"
  22. "code.gitea.io/gitea/modules/cloudbrain"
  23. "code.gitea.io/gitea/modules/context"
  24. "code.gitea.io/gitea/modules/grampus"
  25. "code.gitea.io/gitea/modules/log"
  26. "code.gitea.io/gitea/modules/modelarts"
  27. "code.gitea.io/gitea/modules/redis/redis_key"
  28. "code.gitea.io/gitea/modules/redis/redis_lock"
  29. "code.gitea.io/gitea/modules/setting"
  30. api "code.gitea.io/gitea/modules/structs"
  31. "code.gitea.io/gitea/modules/util"
  32. "code.gitea.io/gitea/services/cloudbrain/resource"
  33. "code.gitea.io/gitea/services/reward/point/account"
  34. )
  // jobNamePattern validates a train job's display name: 3-36 characters drawn
  // from lowercase letters, digits, '-' and '_', starting with a letter or a
  // digit and not ending with '_'.
  var jobNamePattern = regexp.MustCompile("^[a-z0-9][a-z0-9-_]{1,34}[a-z0-9-]$")
  
  // Task types accepted in api.CreateTrainJobOption.Type.
  const (
  	TaskTypeCloudbrainOne = iota // CloudBrain One (GPU)
  	TaskTypeModelArts            // ModelArts (NPU)
  	TaskTypeGrampusGPU           // Grampus / C2Net GPU
  	TaskTypeGrampusNPU           // Grampus / C2Net NPU
  )
  40. func CloudbrainOneTrainJobCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  41. displayJobName := option.DisplayJobName
  42. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  43. image := strings.TrimSpace(option.Image)
  44. uuids := option.Attachment
  45. jobType := string(models.JobTypeTrain)
  46. codePath := setting.JobPath + jobName + cloudbrain.CodeMountPath
  47. branchName := option.BranchName
  48. repo := ctx.Repo.Repository
  49. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), jobType, displayJobName))
  50. defer lock.UnLock()
  51. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  52. if err != nil {
  53. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  54. return
  55. }
  56. command, err := getTrainJobCommand(option)
  57. if err != nil {
  58. log.Error("getTrainJobCommand failed: %v", err)
  59. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  60. return
  61. }
  62. errStr := loadCodeAndMakeModelPath(repo, codePath, branchName, jobName, cloudbrain.ModelMountPath)
  63. if errStr != "" {
  64. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  65. return
  66. }
  67. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  68. req := cloudbrain.GenerateCloudBrainTaskReq{
  69. Ctx: ctx,
  70. DisplayJobName: displayJobName,
  71. JobName: jobName,
  72. Image: image,
  73. Command: command,
  74. Uuids: uuids,
  75. DatasetNames: datasetNames,
  76. DatasetInfos: datasetInfos,
  77. CodePath: storage.GetMinioPath(jobName, cloudbrain.CodeMountPath+"/"),
  78. ModelPath: storage.GetMinioPath(jobName, cloudbrain.ModelMountPath+"/"),
  79. BenchmarkPath: storage.GetMinioPath(jobName, cloudbrain.BenchMarkMountPath+"/"),
  80. Snn4ImageNetPath: storage.GetMinioPath(jobName, cloudbrain.Snn4imagenetMountPath+"/"),
  81. BrainScorePath: storage.GetMinioPath(jobName, cloudbrain.BrainScoreMountPath+"/"),
  82. JobType: jobType,
  83. Description: option.Description,
  84. BranchName: branchName,
  85. BootFile: option.BootFile,
  86. Params: option.Params,
  87. CommitID: commitID,
  88. BenchmarkTypeID: 0,
  89. BenchmarkChildTypeID: 0,
  90. ResultPath: storage.GetMinioPath(jobName, cloudbrain.ResultPath+"/"),
  91. Spec: spec,
  92. }
  93. if option.ModelName != "" { //使用预训练模型训练
  94. req.ModelName = option.ModelName
  95. req.LabelName = option.LabelName
  96. req.CkptName = option.CkptName
  97. req.ModelVersion = option.ModelVersion
  98. req.PreTrainModelPath = setting.Attachment.Minio.RealPath + option.PreTrainModelUrl
  99. req.PreTrainModelUrl = option.PreTrainModelUrl
  100. }
  101. jobId, err := cloudbrain.GenerateTask(req)
  102. if err != nil {
  103. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  104. return
  105. }
  106. ctx.JSON(http.StatusOK, models.BaseMessageApi{
  107. Code: 0,
  108. Message: jobId,
  109. })
  110. }
  111. func ModelArtsTrainJobNpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  112. VersionOutputPath := modelarts.GetOutputPathByCount(modelarts.TotalVersionCount)
  113. displayJobName := option.DisplayJobName
  114. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  115. uuid := option.Attachment
  116. description := option.Description
  117. workServerNumber := option.WorkServerNumber
  118. engineID, _ := strconv.Atoi(option.ImageID)
  119. bootFile := strings.TrimSpace(option.BootFile)
  120. params := option.Params
  121. repo := ctx.Repo.Repository
  122. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  123. codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/"
  124. outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
  125. logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
  126. branchName := option.BranchName
  127. isLatestVersion := modelarts.IsLatestVersion
  128. VersionCount := modelarts.VersionCountOne
  129. EngineName := option.Image
  130. errStr := checkMultiNode(ctx.User.ID, option.WorkServerNumber)
  131. if errStr != "" {
  132. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  133. return
  134. }
  135. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  136. defer lock.UnLock()
  137. spec, _, _, err := checkParameters(ctx, option, lock, repo)
  138. if err != nil {
  139. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  140. return
  141. }
  142. //todo: del the codeLocalPath
  143. _, err = ioutil.ReadDir(codeLocalPath)
  144. if err == nil {
  145. os.RemoveAll(codeLocalPath)
  146. }
  147. gitRepo, _ := git.OpenRepository(repo.RepoPath())
  148. commitID, _ := gitRepo.GetBranchCommitID(branchName)
  149. if err := downloadCode(repo, codeLocalPath, branchName); err != nil {
  150. log.Error("downloadCode failed, server timed out: %s (%v)", repo.FullName(), err)
  151. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  152. return
  153. }
  154. //todo: upload code (send to file_server todo this work?)
  155. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath + VersionOutputPath + "/"); err != nil {
  156. log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err)
  157. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_output"))
  158. return
  159. }
  160. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.LogPath + VersionOutputPath + "/"); err != nil {
  161. log.Error("Failed to obsMkdir_log: %s (%v)", repo.FullName(), err)
  162. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_log"))
  163. return
  164. }
  165. parentDir := VersionOutputPath + "/"
  166. if err := uploadCodeToObs(codeLocalPath, jobName, parentDir); err != nil {
  167. // if err := uploadCodeToObs(codeLocalPath, jobName, parentDir); err != nil {
  168. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  169. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  170. return
  171. }
  172. var parameters models.Parameters
  173. param := make([]models.Parameter, 0)
  174. existDeviceTarget := false
  175. if len(params) != 0 {
  176. err := json.Unmarshal([]byte(params), &parameters)
  177. if err != nil {
  178. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  179. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("运行参数错误"))
  180. return
  181. }
  182. for _, parameter := range parameters.Parameter {
  183. if parameter.Label == modelarts.DeviceTarget {
  184. existDeviceTarget = true
  185. }
  186. if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
  187. param = append(param, models.Parameter{
  188. Label: parameter.Label,
  189. Value: parameter.Value,
  190. })
  191. }
  192. }
  193. }
  194. if !existDeviceTarget {
  195. param = append(param, models.Parameter{
  196. Label: modelarts.DeviceTarget,
  197. Value: modelarts.Ascend,
  198. })
  199. }
  200. datasUrlList, dataUrl, datasetNames, isMultiDataset, err := getDatasUrlListByUUIDS(uuid)
  201. if err != nil {
  202. log.Error("Failed to getDatasUrlListByUUIDS: %v", err)
  203. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to getDatasUrlListByUUIDS:"+err.Error()))
  204. return
  205. }
  206. dataPath := dataUrl
  207. jsondatas, err := json.Marshal(datasUrlList)
  208. if err != nil {
  209. log.Error("Failed to Marshal: %v", err)
  210. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("json error:"+err.Error()))
  211. return
  212. }
  213. if isMultiDataset {
  214. param = append(param, models.Parameter{
  215. Label: modelarts.MultiDataUrl,
  216. Value: string(jsondatas),
  217. })
  218. }
  219. if option.ModelName != "" { //使用预训练模型训练
  220. ckptUrl := "/" + option.PreTrainModelUrl + option.CkptName
  221. param = append(param, models.Parameter{
  222. Label: modelarts.CkptUrl,
  223. Value: "s3:/" + ckptUrl,
  224. })
  225. }
  226. req := &modelarts.GenerateTrainJobReq{
  227. JobName: jobName,
  228. DisplayJobName: displayJobName,
  229. DataUrl: dataPath,
  230. Description: description,
  231. CodeObsPath: codeObsPath,
  232. BootFileUrl: codeObsPath + bootFile,
  233. BootFile: bootFile,
  234. TrainUrl: outputObsPath,
  235. WorkServerNumber: workServerNumber,
  236. EngineID: int64(engineID),
  237. LogUrl: logObsPath,
  238. PoolID: getPoolId(),
  239. Uuid: uuid,
  240. Parameters: param,
  241. CommitID: commitID,
  242. IsLatestVersion: isLatestVersion,
  243. BranchName: branchName,
  244. Params: option.Params,
  245. EngineName: EngineName,
  246. VersionCount: VersionCount,
  247. TotalVersionCount: modelarts.TotalVersionCount,
  248. DatasetName: datasetNames,
  249. Spec: spec,
  250. }
  251. if option.ModelName != "" { //使用预训练模型训练
  252. req.ModelName = option.ModelName
  253. req.LabelName = option.LabelName
  254. req.CkptName = option.CkptName
  255. req.ModelVersion = option.ModelVersion
  256. req.PreTrainModelUrl = option.PreTrainModelUrl
  257. }
  258. userCommand, userImageUrl := getUserCommand(engineID, req)
  259. req.UserCommand = userCommand
  260. req.UserImageUrl = userImageUrl
  261. //将params转换Parameters.Parameter,出错时返回给前端
  262. var Parameters modelarts.Parameters
  263. if err := json.Unmarshal([]byte(params), &Parameters); err != nil {
  264. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("json.Unmarshal failed:"+err.Error()))
  265. return
  266. }
  267. jobId, err := modelarts.GenerateTrainJob(ctx, req)
  268. if err != nil {
  269. log.Error("GenerateTrainJob failed:%v", err.Error())
  270. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  271. return
  272. }
  273. ctx.JSON(http.StatusOK, models.BaseMessageApi{
  274. Code: 0,
  275. Message: jobId,
  276. })
  277. }
  278. func GrampusTrainJobGpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  279. displayJobName := option.DisplayJobName
  280. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  281. uuid := option.Attachment
  282. description := option.Description
  283. bootFile := strings.TrimSpace(option.BootFile)
  284. params := option.Params
  285. repo := ctx.Repo.Repository
  286. codeLocalPath := setting.JobPath + jobName + cloudbrain.CodeMountPath + "/"
  287. codeMinioPath := setting.CBCodePathPrefix + jobName + cloudbrain.CodeMountPath + "/"
  288. branchName := option.BranchName
  289. image := strings.TrimSpace(option.Image)
  290. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  291. defer lock.UnLock()
  292. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  293. if err != nil {
  294. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  295. return
  296. }
  297. //prepare code and out path
  298. _, err = ioutil.ReadDir(codeLocalPath)
  299. if err == nil {
  300. os.RemoveAll(codeLocalPath)
  301. }
  302. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  303. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  304. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  305. }
  306. //todo: upload code (send to file_server todo this work?)
  307. //upload code
  308. if err := uploadCodeToMinio(codeLocalPath+"/", jobName, cloudbrain.CodeMountPath+"/"); err != nil {
  309. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  310. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  311. return
  312. }
  313. modelPath := setting.JobPath + jobName + cloudbrain.ModelMountPath + "/"
  314. if err := mkModelPath(modelPath); err != nil {
  315. log.Error("Failed to mkModelPath: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  316. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  317. return
  318. }
  319. //init model readme
  320. if err := uploadCodeToMinio(modelPath, jobName, cloudbrain.ModelMountPath+"/"); err != nil {
  321. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  322. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  323. return
  324. }
  325. var datasetRemotePath, allFileName string
  326. for _, datasetInfo := range datasetInfos {
  327. if datasetRemotePath == "" {
  328. datasetRemotePath = datasetInfo.DataLocalPath
  329. allFileName = datasetInfo.FullName
  330. } else {
  331. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath
  332. allFileName = allFileName + ";" + datasetInfo.FullName
  333. }
  334. }
  335. //prepare command
  336. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  337. command, err := generateCommand(repo.Name, grampus.ProcessorTypeGPU, codeMinioPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CBCodePathPrefix+jobName+cloudbrain.ModelMountPath+"/", allFileName, preTrainModelPath, option.CkptName, "")
  338. if err != nil {
  339. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  340. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  341. return
  342. }
  343. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  344. req := &grampus.GenerateTrainJobReq{
  345. JobName: jobName,
  346. DisplayJobName: displayJobName,
  347. ComputeResource: models.GPUResource,
  348. ProcessType: grampus.ProcessorTypeGPU,
  349. Command: command,
  350. ImageUrl: image,
  351. Description: description,
  352. BootFile: bootFile,
  353. Uuid: uuid,
  354. CommitID: commitID,
  355. BranchName: branchName,
  356. Params: option.Params,
  357. EngineName: image,
  358. DatasetNames: datasetNames,
  359. DatasetInfos: datasetInfos,
  360. IsLatestVersion: modelarts.IsLatestVersion,
  361. VersionCount: modelarts.VersionCountOne,
  362. WorkServerNumber: 1,
  363. Spec: spec,
  364. }
  365. if option.ModelName != "" { //使用预训练模型训练
  366. req.ModelName = option.ModelName
  367. req.LabelName = option.LabelName
  368. req.CkptName = option.CkptName
  369. req.ModelVersion = option.ModelVersion
  370. req.PreTrainModelUrl = option.PreTrainModelUrl
  371. }
  372. jobId, err := grampus.GenerateTrainJob(ctx, req)
  373. if err != nil {
  374. log.Error("GenerateTrainJob failed:%v", err.Error(), ctx.Data["MsgID"])
  375. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  376. return
  377. }
  378. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  379. }
  380. func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock *redis_lock.DistributeLock, repo *models.Repository) (*models.Specification, map[string]models.DatasetInfo, string, error) {
  381. isOk, err := lock.Lock(models.CloudbrainKeyDuration)
  382. if !isOk {
  383. log.Error("lock processed failed:%v", err, ctx.Data["MsgID"])
  384. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_samejob_err"))
  385. }
  386. if !jobNamePattern.MatchString(option.DisplayJobName) {
  387. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_jobname_err"))
  388. }
  389. bootFileExist, err := ctx.Repo.FileExists(option.BootFile, option.BranchName)
  390. if err != nil || !bootFileExist {
  391. log.Error("Get bootfile error:", err, ctx.Data["MsgID"])
  392. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_bootfile_err"))
  393. }
  394. computeResource := models.GPUResource
  395. if isNpuTask(option) {
  396. computeResource = models.NPUResource
  397. }
  398. //check count limit
  399. taskType := option.Type
  400. if isC2NetTask(option) {
  401. taskType = 2
  402. }
  403. count, err := GetNotFinalStatusTaskCount(ctx.User.ID, taskType, string(models.JobTypeTrain), computeResource)
  404. if err != nil {
  405. log.Error("GetCountByUserID failed:%v", err, ctx.Data["MsgID"])
  406. return nil, nil, "", fmt.Errorf("system error")
  407. } else {
  408. if count >= 1 {
  409. log.Error("the user already has running or waiting task", ctx.Data["MsgID"])
  410. return nil, nil, "", fmt.Errorf("you have already a running or waiting task, can not create more.")
  411. }
  412. }
  413. //check param
  414. if err := paramCheckCreateTrainJob(option.BootFile, option.BranchName); err != nil {
  415. log.Error("paramCheckCreateTrainJob failed:(%v)", err, ctx.Data["MsgID"])
  416. return nil, nil, "", err
  417. }
  418. //check whether the task name in the project is duplicated
  419. tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, string(models.JobTypeTrain), option.DisplayJobName)
  420. if err == nil {
  421. if len(tasks) != 0 {
  422. log.Error("the job name did already exist", ctx.Data["MsgID"])
  423. return nil, nil, "", fmt.Errorf("The job name did already exist.")
  424. }
  425. } else {
  426. if !models.IsErrJobNotExist(err) {
  427. log.Error("system error, %v", err, ctx.Data["MsgID"])
  428. return nil, nil, "", fmt.Errorf("system error")
  429. }
  430. }
  431. //check specification
  432. computeType := models.GPU
  433. if isNpuTask(option) {
  434. computeType = models.NPU
  435. }
  436. cluster := models.OpenICluster
  437. if isC2NetTask(option) {
  438. cluster = models.C2NetCluster
  439. }
  440. aiCenterCode := ""
  441. if option.Type == TaskTypeCloudbrainOne {
  442. aiCenterCode = models.AICenterOfCloudBrainOne
  443. } else if option.Type == TaskTypeModelArts {
  444. aiCenterCode = models.AICenterOfCloudBrainTwo
  445. }
  446. spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{
  447. JobType: models.JobTypeTrain,
  448. ComputeResource: computeType,
  449. Cluster: cluster,
  450. AiCenterCode: aiCenterCode,
  451. })
  452. if err != nil || spec == nil {
  453. return nil, nil, "", fmt.Errorf("Resource specification is not available.")
  454. }
  455. if !account.IsPointBalanceEnough(ctx.User.ID, spec.UnitPrice) {
  456. log.Error("point balance is not enough,userId=%d specId=%d", ctx.User.ID, spec.ID)
  457. return nil, nil, "", fmt.Errorf(ctx.Tr("points.insufficient_points_balance"))
  458. }
  459. //check dataset
  460. var datasetInfos map[string]models.DatasetInfo
  461. var datasetNames string
  462. if option.Type != TaskTypeModelArts {
  463. if isC2NetTask(option) {
  464. datasetInfos, datasetNames, err = models.GetDatasetInfo(option.Attachment, computeType)
  465. } else {
  466. datasetInfos, datasetNames, err = models.GetDatasetInfo(option.Attachment)
  467. }
  468. if err != nil {
  469. log.Error("GetDatasetInfo failed: %v", err, ctx.Data["MsgID"])
  470. return nil, nil, "", fmt.Errorf(ctx.Tr("cloudbrain.error.dataset_select"))
  471. }
  472. }
  473. return spec, datasetInfos, datasetNames, err
  474. }
  475. func isNpuTask(option api.CreateTrainJobOption) bool {
  476. return option.Type == TaskTypeModelArts || option.Type == TaskTypeGrampusNPU
  477. }
  478. func isC2NetTask(option api.CreateTrainJobOption) bool {
  479. return option.Type == TaskTypeGrampusGPU || option.Type == TaskTypeGrampusNPU
  480. }
  481. func GrampusTrainJobNpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  482. displayJobName := option.DisplayJobName
  483. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  484. uuid := option.Attachment
  485. description := option.Description
  486. bootFile := strings.TrimSpace(option.BootFile)
  487. params := option.Params
  488. repo := ctx.Repo.Repository
  489. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  490. codeObsPath := grampus.JobPath + jobName + modelarts.CodePath
  491. branchName := option.BranchName
  492. isLatestVersion := modelarts.IsLatestVersion
  493. versionCount := modelarts.VersionCountOne
  494. engineName := option.Image
  495. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  496. defer lock.UnLock()
  497. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  498. if err != nil {
  499. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  500. return
  501. }
  502. //prepare code and out path
  503. _, err = ioutil.ReadDir(codeLocalPath)
  504. if err == nil {
  505. os.RemoveAll(codeLocalPath)
  506. }
  507. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  508. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err)
  509. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  510. return
  511. }
  512. //todo: upload code (send to file_server todo this work?)
  513. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath); err != nil {
  514. log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err)
  515. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  516. return
  517. }
  518. if err := uploadCodeToObs(codeLocalPath, jobName, ""); err != nil {
  519. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  520. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  521. return
  522. }
  523. var datasetRemotePath, allFileName string
  524. for _, datasetInfo := range datasetInfos {
  525. if datasetRemotePath == "" {
  526. datasetRemotePath = datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  527. allFileName = datasetInfo.FullName
  528. } else {
  529. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  530. allFileName = allFileName + ";" + datasetInfo.FullName
  531. }
  532. }
  533. //prepare command
  534. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  535. command, err := generateCommand(repo.Name, grampus.ProcessorTypeNPU, codeObsPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CodePathPrefix+jobName+modelarts.OutputPath, allFileName, preTrainModelPath, option.CkptName, grampus.GetNpuModelRemoteObsUrl(jobName))
  536. if err != nil {
  537. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  538. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  539. return
  540. }
  541. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  542. req := &grampus.GenerateTrainJobReq{
  543. JobName: jobName,
  544. DisplayJobName: displayJobName,
  545. ComputeResource: models.NPUResource,
  546. ProcessType: grampus.ProcessorTypeNPU,
  547. Command: command,
  548. ImageId: option.ImageID,
  549. Description: description,
  550. CodeObsPath: codeObsPath,
  551. BootFileUrl: codeObsPath + bootFile,
  552. BootFile: bootFile,
  553. WorkServerNumber: option.WorkServerNumber,
  554. Uuid: uuid,
  555. CommitID: commitID,
  556. IsLatestVersion: isLatestVersion,
  557. BranchName: branchName,
  558. Params: option.Params,
  559. EngineName: engineName,
  560. VersionCount: versionCount,
  561. TotalVersionCount: modelarts.TotalVersionCount,
  562. DatasetNames: datasetNames,
  563. DatasetInfos: datasetInfos,
  564. Spec: spec,
  565. CodeName: strings.ToLower(repo.Name),
  566. }
  567. if option.ModelName != "" { //使用预训练模型训练
  568. req.ModelName = option.ModelName
  569. req.LabelName = option.LabelName
  570. req.CkptName = option.CkptName
  571. req.ModelVersion = option.ModelVersion
  572. req.PreTrainModelUrl = option.PreTrainModelUrl
  573. req.PreTrainModelPath = preTrainModelPath
  574. }
  575. jobId, err := grampus.GenerateTrainJob(ctx, req)
  576. if err != nil {
  577. log.Error("GenerateTrainJob failed:%v", err.Error())
  578. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  579. return
  580. }
  581. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  582. }
  583. func obsMkdir(dir string) error {
  584. input := &obs.PutObjectInput{}
  585. input.Bucket = setting.Bucket
  586. input.Key = dir
  587. _, err := storage.ObsCli.PutObject(input)
  588. if err != nil {
  589. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  590. return err
  591. }
  592. return nil
  593. }
  594. func uploadCodeToObs(codePath, jobName, parentDir string) error {
  595. files, err := readDir(codePath)
  596. if err != nil {
  597. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  598. return err
  599. }
  600. for _, file := range files {
  601. if file.IsDir() {
  602. input := &obs.PutObjectInput{}
  603. input.Bucket = setting.Bucket
  604. input.Key = parentDir + file.Name() + "/"
  605. _, err = storage.ObsCli.PutObject(input)
  606. if err != nil {
  607. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  608. return err
  609. }
  610. if err = uploadCodeToObs(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  611. log.Error("uploadCodeToObs(%s) failed: %s", file.Name(), err.Error())
  612. return err
  613. }
  614. } else {
  615. input := &obs.PutFileInput{}
  616. input.Bucket = setting.Bucket
  617. input.Key = setting.CodePathPrefix + jobName + "/code/" + parentDir + file.Name()
  618. input.SourceFile = codePath + file.Name()
  619. _, err = storage.ObsCli.PutFile(input)
  620. if err != nil {
  621. log.Error("PutFile(%s) failed: %s", input.SourceFile, err.Error())
  622. return err
  623. }
  624. }
  625. }
  626. return nil
  627. }
  628. func paramCheckCreateTrainJob(bootFile string, branchName string) error {
  629. if !strings.HasSuffix(strings.TrimSpace(bootFile), ".py") {
  630. log.Error("the boot file(%s) must be a python file", bootFile)
  631. return errors.New("启动文件必须是python文件")
  632. }
  633. if branchName == "" {
  634. log.Error("the branch must not be null!", branchName)
  635. return errors.New("代码分支不能为空!")
  636. }
  637. return nil
  638. }
  639. func downloadZipCode(ctx *context.Context, codePath, branchName string) error {
  640. archiveType := git.ZIP
  641. archivePath := codePath
  642. if !com.IsDir(archivePath) {
  643. if err := os.MkdirAll(archivePath, os.ModePerm); err != nil {
  644. log.Error("MkdirAll failed:" + err.Error())
  645. return err
  646. }
  647. }
  648. // Get corresponding commit.
  649. var (
  650. commit *git.Commit
  651. err error
  652. )
  653. gitRepo := ctx.Repo.GitRepo
  654. if err != nil {
  655. log.Error("OpenRepository failed:" + err.Error())
  656. return err
  657. }
  658. if gitRepo.IsBranchExist(branchName) {
  659. commit, err = gitRepo.GetBranchCommit(branchName)
  660. if err != nil {
  661. log.Error("GetBranchCommit failed:" + err.Error())
  662. return err
  663. }
  664. } else {
  665. log.Error("the branch is not exist: " + branchName)
  666. return fmt.Errorf("The branch does not exist.")
  667. }
  668. archivePath = path.Join(archivePath, grampus.CodeArchiveName)
  669. if !com.IsFile(archivePath) {
  670. if err := commit.CreateArchive(archivePath, git.CreateArchiveOpts{
  671. Format: archiveType,
  672. Prefix: setting.Repository.PrefixArchiveFiles,
  673. }); err != nil {
  674. log.Error("CreateArchive failed:" + err.Error())
  675. return err
  676. }
  677. }
  678. return nil
  679. }
  680. func uploadCodeToMinio(codePath, jobName, parentDir string) error {
  681. files, err := readDir(codePath)
  682. if err != nil {
  683. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  684. return err
  685. }
  686. for _, file := range files {
  687. if file.IsDir() {
  688. if err = uploadCodeToMinio(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  689. log.Error("uploadCodeToMinio(%s) failed: %s", file.Name(), err.Error())
  690. return err
  691. }
  692. } else {
  693. destObject := setting.CBCodePathPrefix + jobName + parentDir + file.Name()
  694. sourceFile := codePath + file.Name()
  695. err = storage.Attachments.UploadObject(destObject, sourceFile)
  696. if err != nil {
  697. log.Error("UploadObject(%s) failed: %s", file.Name(), err.Error())
  698. return err
  699. }
  700. }
  701. }
  702. return nil
  703. }
  704. func uploadOneFileToMinio(codePath, filePath, jobName, parentDir string) error {
  705. destObject := setting.CBCodePathPrefix + jobName + parentDir + path.Base(filePath)
  706. sourceFile := codePath + "/" + filePath
  707. err := storage.Attachments.UploadObject(destObject, sourceFile)
  708. if err != nil {
  709. log.Error("UploadObject(%s) failed: %s", filePath, err.Error())
  710. return err
  711. }
  712. return nil
  713. }
  714. func readDir(dirname string) ([]os.FileInfo, error) {
  715. f, err := os.Open(dirname)
  716. if err != nil {
  717. return nil, err
  718. }
  719. list, err := f.Readdir(0)
  720. f.Close()
  721. if err != nil {
  722. //todo: can not upload empty folder
  723. if err == io.EOF {
  724. return nil, nil
  725. }
  726. return nil, err
  727. }
  728. //sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() })
  729. return list, nil
  730. }
  731. func mkModelPath(modelPath string) error {
  732. return mkPathAndReadMeFile(modelPath, "You can put the files into this directory and download the files by the web page.")
  733. }
  734. func mkPathAndReadMeFile(path string, text string) error {
  735. err := os.MkdirAll(path, os.ModePerm)
  736. if err != nil {
  737. log.Error("MkdirAll(%s) failed:%v", path, err)
  738. return err
  739. }
  740. fileName := path + "README"
  741. f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
  742. if err != nil {
  743. log.Error("OpenFile failed", err.Error())
  744. return err
  745. }
  746. defer f.Close()
  747. _, err = f.WriteString(text)
  748. if err != nil {
  749. log.Error("WriteString failed", err.Error())
  750. return err
  751. }
  752. return nil
  753. }
  754. func getPreTrainModelPath(pretrainModelDir string, fileName string) string {
  755. index := strings.Index(pretrainModelDir, "/")
  756. if index > 0 {
  757. filterBucket := pretrainModelDir[index+1:]
  758. return filterBucket + fileName
  759. } else {
  760. return ""
  761. }
  762. }
  763. func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bootFile, paramSrc, outputRemotePath, datasetName, pretrainModelPath, pretrainModelFileName, modelRemoteObsUrl string) (string, error) {
  764. var command string
  765. //prepare
  766. workDir := grampus.NpuWorkDir
  767. if processorType == grampus.ProcessorTypeNPU {
  768. command += "pwd;cd " + workDir + grampus.CommandPrepareScriptNpu
  769. } else if processorType == grampus.ProcessorTypeGPU {
  770. workDir = grampus.GpuWorkDir
  771. command += "pwd;cd " + workDir + fmt.Sprintf(grampus.CommandPrepareScriptGpu, setting.Grampus.SyncScriptProject, setting.Grampus.SyncScriptProject)
  772. }
  773. //download code & dataset
  774. if processorType == grampus.ProcessorTypeNPU {
  775. //no need to download code & dataset by internet
  776. } else if processorType == grampus.ProcessorTypeGPU {
  777. commandDownload := "./downloader_for_minio " + setting.Grampus.Env + " " + codeRemotePath + " " + grampus.CodeArchiveName + " '" + dataRemotePath + "' '" + datasetName + "'"
  778. commandDownload = processPretrainModelParameter(pretrainModelPath, pretrainModelFileName, commandDownload)
  779. command += commandDownload
  780. }
  781. //unzip code & dataset
  782. if processorType == grampus.ProcessorTypeNPU {
  783. //no need to process
  784. } else if processorType == grampus.ProcessorTypeGPU {
  785. unZipDatasetCommand := GenerateDatasetUnzipCommand(datasetName)
  786. commandUnzip := "cd " + workDir + "code;unzip -q master.zip;rm -f master.zip;echo \"start to unzip dataset\";cd " + workDir + "dataset;" + unZipDatasetCommand
  787. command += commandUnzip
  788. }
  789. command += "echo \"unzip finished;start to exec code;\";"
  790. // set export
  791. var commandExport string
  792. if processorType == grampus.ProcessorTypeNPU {
  793. commandExport = "export bucket=" + setting.Bucket + " && export remote_path=" + outputRemotePath + ";"
  794. } else if processorType == grampus.ProcessorTypeGPU {
  795. commandExport = "export env=" + setting.Grampus.Env + " && export remote_path=" + outputRemotePath + ";"
  796. }
  797. command += commandExport
  798. //exec code
  799. var parameters models.Parameters
  800. var paramCode string
  801. if len(paramSrc) != 0 {
  802. err := json.Unmarshal([]byte(paramSrc), &parameters)
  803. if err != nil {
  804. log.Error("Failed to Unmarshal params: %s (%v)", paramSrc, err)
  805. return command, err
  806. }
  807. for _, parameter := range parameters.Parameter {
  808. paramCode += " --" + parameter.Label + "=" + parameter.Value
  809. }
  810. }
  811. var commandCode string
  812. if processorType == grampus.ProcessorTypeNPU {
  813. paramCode += " --model_url=" + modelRemoteObsUrl
  814. commandCode = "/bin/bash /home/work/run_train_for_openi.sh /home/work/openi.py " + grampus.NpuLocalLogUrl + paramCode + ";"
  815. } else if processorType == grampus.ProcessorTypeGPU {
  816. if pretrainModelFileName != "" {
  817. paramCode += " --ckpt_url" + "=" + workDir + "pretrainmodel/" + pretrainModelFileName
  818. }
  819. commandCode = "cd " + workDir + "code/" + strings.ToLower(repoName) + ";python " + bootFile + paramCode + ";"
  820. }
  821. command += commandCode
  822. //get exec result
  823. commandGetRes := "result=$?;"
  824. command += commandGetRes
  825. //upload models
  826. if processorType == grampus.ProcessorTypeNPU {
  827. // no need to upload
  828. } else if processorType == grampus.ProcessorTypeGPU {
  829. commandUpload := "cd " + workDir + setting.Grampus.SyncScriptProject + "/;./uploader_for_gpu " + setting.Grampus.Env + " " + outputRemotePath + " " + workDir + "output/;"
  830. command += commandUpload
  831. }
  832. //check exec result
  833. commandCheckRes := "bash -c \"[[ $result -eq 0 ]] && exit 0 || exit -1\""
  834. command += commandCheckRes
  835. return command, nil
  836. }
  837. func processPretrainModelParameter(pretrainModelPath string, pretrainModelFileName string, commandDownload string) string {
  838. commandDownloadTemp := commandDownload
  839. if pretrainModelPath != "" {
  840. commandDownloadTemp += " '" + pretrainModelPath + "' '" + pretrainModelFileName + "'"
  841. }
  842. commandDownloadTemp += ";"
  843. return commandDownloadTemp
  844. }
  845. func GenerateDatasetUnzipCommand(datasetName string) string {
  846. var unZipDatasetCommand string
  847. datasetNameArray := strings.Split(datasetName, ";")
  848. if len(datasetNameArray) == 1 { //单数据集
  849. unZipDatasetCommand = "unzip -q '" + datasetName + "';"
  850. if strings.HasSuffix(datasetNameArray[0], ".tar.gz") {
  851. unZipDatasetCommand = "tar --strip-components=1 -zxvf '" + datasetName + "';"
  852. }
  853. unZipDatasetCommand += "rm -f '" + datasetName + "';"
  854. } else { //多数据集
  855. for _, datasetNameTemp := range datasetNameArray {
  856. if strings.HasSuffix(datasetNameTemp, ".tar.gz") {
  857. unZipDatasetCommand = unZipDatasetCommand + "tar -zxvf '" + datasetNameTemp + "';"
  858. } else {
  859. unZipDatasetCommand = unZipDatasetCommand + "unzip -q '" + datasetNameTemp + "' -d './" + strings.TrimSuffix(datasetNameTemp, ".zip") + "';"
  860. }
  861. unZipDatasetCommand += "rm -f '" + datasetNameTemp + "';"
  862. }
  863. }
  864. return unZipDatasetCommand
  865. }
  866. func getPoolId() string {
  867. var resourcePools modelarts.ResourcePool
  868. json.Unmarshal([]byte(setting.ResourcePools), &resourcePools)
  869. return resourcePools.Info[0].ID
  870. }
  871. func PrepareSpec4Show(task *models.Cloudbrain) {
  872. s, err := resource.GetCloudbrainSpec(task.ID)
  873. if err != nil {
  874. log.Info("error:" + err.Error())
  875. return
  876. }
  877. task.Spec = s
  878. }
  879. func IsTaskNotStop(task *models.Cloudbrain) bool {
  880. statuses := CloudbrainOneNotFinalStatuses
  881. if task.Type == models.TypeCloudBrainTwo || task.Type == models.TypeCDCenter {
  882. statuses = CloudbrainTwoNotFinalStatuses
  883. } else {
  884. statuses = GrampusNotFinalStatuses
  885. }
  886. for _, status := range statuses {
  887. if task.Status == status {
  888. return true
  889. }
  890. }
  891. return false
  892. }
  893. func SyncTaskStatus(task *models.Cloudbrain) error {
  894. if task.Type == models.TypeCloudBrainOne {
  895. result, err := cloudbrain.GetJob(task.JobID)
  896. if err != nil {
  897. log.Info("error:" + err.Error())
  898. return fmt.Errorf("repo.cloudbrain_query_fail")
  899. }
  900. if result != nil {
  901. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  902. taskRoles := jobRes.TaskRoles
  903. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  904. oldStatus := task.Status
  905. task.Status = taskRes.TaskStatuses[0].State
  906. task.ContainerID = taskRes.TaskStatuses[0].ContainerID
  907. models.ParseAndSetDurationFromCloudBrainOne(jobRes, task)
  908. if task.DeletedAt.IsZero() { //normal record
  909. if oldStatus != task.Status {
  910. notification.NotifyChangeCloudbrainStatus(task, oldStatus)
  911. }
  912. err = models.UpdateJob(task)
  913. if err != nil {
  914. return fmt.Errorf("repo.cloudbrain_query_fail")
  915. }
  916. }
  917. } else {
  918. log.Info("error:" + err.Error())
  919. return fmt.Errorf("repo.cloudbrain_query_fail")
  920. }
  921. } else if task.Type == models.TypeCloudBrainTwo || task.Type == models.TypeCDCenter {
  922. err := modelarts.HandleTrainJobInfo(task)
  923. if err != nil {
  924. return fmt.Errorf("repo.cloudbrain_query_fail")
  925. }
  926. } else if task.Type == models.TypeC2Net {
  927. result, err := grampus.GetJob(task.JobID)
  928. if err != nil {
  929. log.Error("GetJob failed:" + err.Error())
  930. return fmt.Errorf("repo.cloudbrain_query_fail")
  931. }
  932. if result != nil {
  933. if len(result.JobInfo.Tasks[0].CenterID) == 1 && len(result.JobInfo.Tasks[0].CenterName) == 1 {
  934. task.AiCenter = result.JobInfo.Tasks[0].CenterID[0] + "+" + result.JobInfo.Tasks[0].CenterName[0]
  935. }
  936. oldStatus := task.Status
  937. task.Status = grampus.TransTrainJobStatus(result.JobInfo.Status)
  938. if task.Status != oldStatus || task.Status == models.GrampusStatusRunning {
  939. task.Duration = result.JobInfo.RunSec
  940. if task.Duration < 0 {
  941. task.Duration = 0
  942. }
  943. task.TrainJobDuration = models.ConvertDurationToStr(task.Duration)
  944. if task.StartTime == 0 && result.JobInfo.StartedAt > 0 {
  945. task.StartTime = timeutil.TimeStamp(result.JobInfo.StartedAt)
  946. }
  947. if task.EndTime == 0 && models.IsTrainJobTerminal(task.Status) && task.StartTime > 0 {
  948. task.EndTime = task.StartTime.Add(task.Duration)
  949. }
  950. task.CorrectCreateUnix()
  951. if oldStatus != task.Status {
  952. notification.NotifyChangeCloudbrainStatus(task, oldStatus)
  953. if models.IsTrainJobTerminal(task.Status) && task.ComputeResource == models.NPUResource {
  954. if len(result.JobInfo.Tasks[0].CenterID) == 1 {
  955. urchin.GetBackNpuModel(task.ID, grampus.GetRemoteEndPoint(result.JobInfo.Tasks[0].CenterID[0]), grampus.BucketRemote, grampus.GetNpuModelObjectKey(task.JobName), grampus.GetCenterProxy(setting.Grampus.LocalCenterID))
  956. }
  957. }
  958. }
  959. err = models.UpdateJob(task)
  960. if err != nil {
  961. log.Error("UpdateJob failed:" + err.Error())
  962. return fmt.Errorf("repo.cloudbrain_query_fail")
  963. }
  964. }
  965. }
  966. }
  967. return nil
  968. }
  969. func getTrainJobCommand(option api.CreateTrainJobOption) (string, error) {
  970. var command string
  971. bootFile := strings.TrimSpace(option.BootFile)
  972. params := option.Params
  973. if !strings.HasSuffix(bootFile, ".py") {
  974. log.Error("bootFile(%s) format error", bootFile)
  975. return command, errors.New("bootFile format error")
  976. }
  977. var parameters models.Parameters
  978. var param string
  979. if len(params) != 0 {
  980. err := json.Unmarshal([]byte(params), &parameters)
  981. if err != nil {
  982. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  983. return command, err
  984. }
  985. for _, parameter := range parameters.Parameter {
  986. param += " --" + parameter.Label + "=" + parameter.Value
  987. }
  988. }
  989. if option.CkptName != "" {
  990. param += " --ckpt_url" + "=" + "/pretrainmodel/" + option.CkptName
  991. }
  992. command += "python /code/" + bootFile + param + " > " + cloudbrain.ModelMountPath + "/" + option.DisplayJobName + "-" + cloudbrain.LogFile
  993. return command, nil
  994. }
  995. func checkMultiNode(userId int64, serverNum int) string {
  996. if serverNum == 1 {
  997. return ""
  998. }
  999. modelarts.InitMultiNode()
  1000. var isServerNumValid = false
  1001. if modelarts.MultiNodeConfig != nil {
  1002. for _, info := range modelarts.MultiNodeConfig.Info {
  1003. if isInOrg, _ := models.IsOrganizationMemberByOrgName(info.Org, userId); isInOrg {
  1004. if isInNodes(info.Node, serverNum) {
  1005. isServerNumValid = true
  1006. break
  1007. }
  1008. }
  1009. }
  1010. }
  1011. if isServerNumValid {
  1012. return ""
  1013. } else {
  1014. return "repo.modelarts.no_node_right"
  1015. }
  1016. }
  1017. func isInNodes(nodes []int, num int) bool {
  1018. for _, node := range nodes {
  1019. if node == num {
  1020. return true
  1021. }
  1022. }
  1023. return false
  1024. }
  1025. func getUserCommand(engineId int, req *modelarts.GenerateTrainJobReq) (string, string) {
  1026. userImageUrl := ""
  1027. userCommand := ""
  1028. if engineId < 0 {
  1029. tmpCodeObsPath := strings.Trim(req.CodeObsPath, "/")
  1030. tmpCodeObsPaths := strings.Split(tmpCodeObsPath, "/")
  1031. lastCodeDir := "code"
  1032. if len(tmpCodeObsPaths) > 0 {
  1033. lastCodeDir = tmpCodeObsPaths[len(tmpCodeObsPaths)-1]
  1034. }
  1035. userCommand = "/bin/bash /home/work/run_train.sh 's3://" + req.CodeObsPath + "' '" + lastCodeDir + "/" + req.BootFile + "' '/tmp/log/train.log' --'data_url'='s3://" + req.DataUrl + "' --'train_url'='s3://" + req.TrainUrl + "'"
  1036. var versionInfos modelarts.VersionInfo
  1037. if err := json.Unmarshal([]byte(setting.EngineVersions), &versionInfos); err != nil {
  1038. log.Info("json parse err." + err.Error())
  1039. } else {
  1040. for _, engine := range versionInfos.Version {
  1041. if engine.ID == engineId {
  1042. userImageUrl = engine.Url
  1043. break
  1044. }
  1045. }
  1046. }
  1047. for _, param := range req.Parameters {
  1048. userCommand += " --'" + param.Label + "'='" + param.Value + "'"
  1049. }
  1050. return userCommand, userImageUrl
  1051. }
  1052. return userCommand, userImageUrl
  1053. }