You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelarts.go 12 kB

4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package modelarts
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "path"
  6. "strconv"
  7. "code.gitea.io/gitea/models"
  8. "code.gitea.io/gitea/modules/context"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/setting"
  11. "code.gitea.io/gitea/modules/storage"
  12. )
  13. const (
  14. //notebook
  15. storageTypeOBS = "obs"
  16. autoStopDuration = 4 * 60 * 60
  17. DataSetMountPath = "/home/ma-user/work"
  18. NotebookEnv = "Python3"
  19. NotebookType = "Ascend"
  20. FlavorInfo = "Ascend: 1*Ascend 910 CPU: 24 核 96GiB (modelarts.kat1.xlarge)"
  21. //train-job
  22. // ResourcePools = "{\"resource_pool\":[{\"id\":\"pool1328035d\", \"value\":\"专属资源池\"}]}"
  23. // Engines = "{\"engine\":[{\"id\":1, \"value\":\"Ascend-Powered-Engine\"}]}"
  24. // EngineVersions = "{\"version\":[{\"id\":118,\"value\":\"MindSpore-1.0.0-c75-python3.7-euleros2.8-aarch64\"}," +
  25. // "{\"id\":119,\"value\":\"MindSpore-1.1.1-c76-python3.7-euleros2.8-aarch64\"}," +
  26. // "{\"id\":120,\"value\":\"MindSpore-1.1.1-c76-tr5-python3.7-euleros2.8-aarch64\"}," +
  27. // "{\"id\":117,\"value\":\"TF-1.15-c75-python3.7-euleros2.8-aarch64\"}" +
  28. // "]}"
  29. // TrainJobFlavorInfo = "{\"flavor\":[{\"code\":\"modelarts.bm.910.arm.public.2\",\"value\":\"Ascend : 2 * Ascend 910 CPU:48 核 512GiB\"}," +
  30. // "{\"code\":\"modelarts.bm.910.arm.public.8\",\"value\":\"Ascend : 8 * Ascend 910 CPU:192 核 2048GiB\"}," +
  31. // "{\"code\":\"modelarts.bm.910.arm.public.4\",\"value\":\"Ascend : 4 * Ascend 910 CPU:96 核 1024GiB\"}," +
  32. // "{\"code\":\"modelarts.bm.910.arm.public.1\",\"value\":\"Ascend : 1 * Ascend 910 CPU:24 核 256GiB\"}" +
  33. // "]}"
  34. CodePath = "/code/"
  35. OutputPath = "/output/"
  36. LogPath = "/log/"
  37. JobPath = "/job/"
  38. OrderDesc = "desc" //向下查询
  39. OrderAsc = "asc" //向上查询
  40. Lines = 500
  41. TrainUrl = "train_url"
  42. DataUrl = "data_url"
  43. PerPage = 10
  44. IsLatestVersion = "1"
  45. NotLatestVersion = "0"
  46. DebugType = -1
  47. VersionCount = 1
  48. SortByCreateTime = "create_time"
  49. ConfigTypeCustom = "custom"
  50. TotalVersionCount = 1
  51. )
  52. var (
  53. poolInfos *models.PoolInfos
  54. FlavorInfos *models.FlavorInfos
  55. )
  56. type GenerateTrainJobReq struct {
  57. JobName string
  58. Uuid string
  59. Description string
  60. CodeObsPath string
  61. BootFile string
  62. BootFileUrl string
  63. DataUrl string
  64. TrainUrl string
  65. FlavorCode string
  66. LogUrl string
  67. PoolID string
  68. WorkServerNumber int
  69. EngineID int64
  70. Parameters []models.Parameter
  71. CommitID string
  72. IsLatestVersion string
  73. Params string
  74. BranchName string
  75. PreVersionId int64
  76. PreVersionName string
  77. FlavorName string
  78. VersionCount int
  79. EngineName string
  80. TotalVersionCount int
  81. }
  82. type GenerateTrainJobVersionReq struct {
  83. JobName string
  84. Uuid string
  85. Description string
  86. CodeObsPath string
  87. BootFile string
  88. BootFileUrl string
  89. DataUrl string
  90. TrainUrl string
  91. FlavorCode string
  92. LogUrl string
  93. PoolID string
  94. WorkServerNumber int
  95. EngineID int64
  96. Parameters []models.Parameter
  97. Params string
  98. PreVersionId int64
  99. CommitID string
  100. BranchName string
  101. FlavorName string
  102. EngineName string
  103. PreVersionName string
  104. TotalVersionCount int
  105. }
  106. type VersionInfo struct {
  107. Version []struct {
  108. ID int `json:"id"`
  109. Value string `json:"value"`
  110. } `json:"version"`
  111. }
  112. type Flavor struct {
  113. Info []struct {
  114. Code string `json:"code"`
  115. Value string `json:"value"`
  116. } `json:"flavor"`
  117. }
  118. type Engine struct {
  119. Info []struct {
  120. ID int `json:"id"`
  121. Value string `json:"value"`
  122. } `json:"engine"`
  123. }
  124. type ResourcePool struct {
  125. Info []struct {
  126. ID string `json:"id"`
  127. Value string `json:"value"`
  128. } `json:"resource_pool"`
  129. }
  130. // type Parameter struct {
  131. // Label string `json:"label"`
  132. // Value string `json:"value"`
  133. // }
  134. // type Parameters struct {
  135. // Parameter []Parameter `json:"parameter"`
  136. // }
  137. type Parameters struct {
  138. Parameter []struct {
  139. Label string `json:"label"`
  140. Value string `json:"value"`
  141. } `json:"parameter"`
  142. }
  143. func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error {
  144. var dataActualPath string
  145. if uuid != "" {
  146. dataActualPath = setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + "/"
  147. } else {
  148. userPath := setting.UserBasePath + ctx.User.Name + "/"
  149. isExist, err := storage.ObsHasObject(userPath)
  150. if err != nil {
  151. log.Error("ObsHasObject failed:%v", err.Error(), ctx.Data["MsgID"])
  152. return err
  153. }
  154. if !isExist {
  155. if err = storage.ObsCreateObject(userPath); err != nil {
  156. log.Error("ObsCreateObject failed:%v", err.Error(), ctx.Data["MsgID"])
  157. return err
  158. }
  159. }
  160. dataActualPath = setting.Bucket + "/" + userPath
  161. }
  162. if poolInfos == nil {
  163. json.Unmarshal([]byte(setting.PoolInfos), &poolInfos)
  164. }
  165. jobResult, err := CreateJob(models.CreateNotebookParams{
  166. JobName: jobName,
  167. Description: description,
  168. ProfileID: setting.ProfileID,
  169. Flavor: flavor,
  170. Pool: models.Pool{
  171. ID: poolInfos.PoolInfo[0].PoolId,
  172. Name: poolInfos.PoolInfo[0].PoolName,
  173. Type: poolInfos.PoolInfo[0].PoolType,
  174. },
  175. Spec: models.Spec{
  176. Storage: models.Storage{
  177. Type: storageTypeOBS,
  178. Location: models.Location{
  179. Path: dataActualPath,
  180. },
  181. },
  182. AutoStop: models.AutoStop{
  183. Enable: true,
  184. Duration: autoStopDuration,
  185. },
  186. },
  187. })
  188. if err != nil {
  189. log.Error("CreateJob failed: %v", err.Error())
  190. return err
  191. }
  192. err = models.CreateCloudbrain(&models.Cloudbrain{
  193. Status: string(models.JobWaiting),
  194. UserID: ctx.User.ID,
  195. RepoID: ctx.Repo.Repository.ID,
  196. JobID: jobResult.ID,
  197. JobName: jobName,
  198. JobType: string(models.JobTypeDebug),
  199. Type: models.TypeCloudBrainTwo,
  200. Uuid: uuid,
  201. ComputeResource: models.NPUResource,
  202. })
  203. if err != nil {
  204. return err
  205. }
  206. return nil
  207. }
  208. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  209. jobResult, err := createTrainJob(models.CreateTrainJobParams{
  210. JobName: req.JobName,
  211. Description: req.Description,
  212. Config: models.Config{
  213. WorkServerNum: req.WorkServerNumber,
  214. AppUrl: req.CodeObsPath,
  215. BootFileUrl: req.BootFileUrl,
  216. DataUrl: req.DataUrl,
  217. EngineID: req.EngineID,
  218. TrainUrl: req.TrainUrl,
  219. LogUrl: req.LogUrl,
  220. PoolID: req.PoolID,
  221. CreateVersion: true,
  222. Flavor: models.Flavor{
  223. Code: req.FlavorCode,
  224. },
  225. Parameter: req.Parameters,
  226. },
  227. })
  228. if err != nil {
  229. log.Error("CreateJob failed: %v", err.Error())
  230. return err
  231. }
  232. attach, err := models.GetAttachmentByUUID(req.Uuid)
  233. if err != nil {
  234. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  235. return err
  236. }
  237. err = models.CreateCloudbrain(&models.Cloudbrain{
  238. Status: TransTrainJobStatus(jobResult.Status),
  239. UserID: ctx.User.ID,
  240. RepoID: ctx.Repo.Repository.ID,
  241. JobID: strconv.FormatInt(jobResult.JobID, 10),
  242. JobName: req.JobName,
  243. JobType: string(models.JobTypeTrain),
  244. Type: models.TypeCloudBrainTwo,
  245. VersionID: jobResult.VersionID,
  246. VersionName: jobResult.VersionName,
  247. Uuid: req.Uuid,
  248. DatasetName: attach.Name,
  249. CommitID: req.CommitID,
  250. IsLatestVersion: req.IsLatestVersion,
  251. ComputeResource: models.NPUResource,
  252. EngineID: req.EngineID,
  253. TrainUrl: req.TrainUrl,
  254. BranchName: req.BranchName,
  255. Parameters: req.Params,
  256. BootFile: req.BootFile,
  257. DataUrl: req.DataUrl,
  258. LogUrl: req.LogUrl,
  259. FlavorCode: req.FlavorCode,
  260. Description: req.Description,
  261. WorkServerNumber: req.WorkServerNumber,
  262. FlavorName: req.FlavorName,
  263. EngineName: req.EngineName,
  264. VersionCount: req.VersionCount,
  265. TotalVersionCount: req.TotalVersionCount,
  266. })
  267. if err != nil {
  268. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  269. return err
  270. }
  271. return nil
  272. }
  273. func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, jobId string) (err error) {
  274. jobResult, err := createTrainJobVersion(models.CreateTrainJobVersionParams{
  275. Description: req.Description,
  276. Config: models.TrainJobVersionConfig{
  277. WorkServerNum: req.WorkServerNumber,
  278. AppUrl: req.CodeObsPath,
  279. BootFileUrl: req.BootFileUrl,
  280. DataUrl: req.DataUrl,
  281. EngineID: req.EngineID,
  282. TrainUrl: req.TrainUrl,
  283. LogUrl: req.LogUrl,
  284. PoolID: req.PoolID,
  285. Flavor: models.Flavor{
  286. Code: req.FlavorCode,
  287. },
  288. Parameter: req.Parameters,
  289. PreVersionId: req.PreVersionId,
  290. },
  291. }, jobId)
  292. if err != nil {
  293. log.Error("CreateJob failed: %v", err.Error())
  294. return err
  295. }
  296. attach, err := models.GetAttachmentByUUID(req.Uuid)
  297. if err != nil {
  298. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  299. return err
  300. }
  301. repo := ctx.Repo.Repository
  302. VersionTaskList, VersionListCount, err := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
  303. RepoID: repo.ID,
  304. Type: models.TypeCloudBrainTwo,
  305. JobType: string(models.JobTypeTrain),
  306. JobID: strconv.FormatInt(jobResult.JobID, 10),
  307. })
  308. if err != nil {
  309. ctx.ServerError("Cloudbrain", err)
  310. return err
  311. }
  312. //将当前版本的isLatestVersion设置为"1"和任务数量更新,任务数量包括当前版本数VersionCount和历史创建的总版本数TotalVersionCount
  313. err = models.CreateCloudbrain(&models.Cloudbrain{
  314. Status: TransTrainJobStatus(jobResult.Status),
  315. UserID: ctx.User.ID,
  316. RepoID: ctx.Repo.Repository.ID,
  317. JobID: strconv.FormatInt(jobResult.JobID, 10),
  318. JobName: req.JobName,
  319. JobType: string(models.JobTypeTrain),
  320. Type: models.TypeCloudBrainTwo,
  321. VersionID: jobResult.VersionID,
  322. VersionName: jobResult.VersionName,
  323. Uuid: req.Uuid,
  324. DatasetName: attach.Name,
  325. CommitID: req.CommitID,
  326. IsLatestVersion: req.IsLatestVersion,
  327. PreVersionName: req.PreVersionName,
  328. ComputeResource: models.NPUResource,
  329. EngineID: req.EngineID,
  330. TrainUrl: req.TrainUrl,
  331. BranchName: req.BranchName,
  332. Parameters: req.Params,
  333. BootFile: req.BootFile,
  334. DataUrl: req.DataUrl,
  335. LogUrl: req.LogUrl,
  336. PreVersionId: req.PreVersionId,
  337. FlavorCode: req.FlavorCode,
  338. Description: req.Description,
  339. WorkServerNumber: req.WorkServerNumber,
  340. FlavorName: req.FlavorName,
  341. EngineName: req.EngineName,
  342. TotalVersionCount: VersionTaskList[0].TotalVersionCount + 1,
  343. VersionCount: VersionListCount + 1,
  344. })
  345. if err != nil {
  346. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  347. return err
  348. }
  349. //将训练任务的上一版本的isLatestVersion设置为"0"
  350. err = models.SetVersionCountAndLatestVersion(strconv.FormatInt(jobResult.JobID, 10), VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
  351. if err != nil {
  352. ctx.ServerError("Update IsLatestVersion failed", err)
  353. return err
  354. }
  355. return err
  356. }
  357. func TransTrainJobStatus(status int) string {
  358. switch status {
  359. case 0:
  360. return "UNKNOWN"
  361. case 1:
  362. return "INIT"
  363. case 2:
  364. return "IMAGE_CREATING"
  365. case 3:
  366. return "IMAGE_FAILED"
  367. case 4:
  368. return "SUBMIT_TRYING"
  369. case 5:
  370. return "SUBMIT_FAILED"
  371. case 6:
  372. return "DELETE_FAILED"
  373. case 7:
  374. return "WAITING"
  375. case 8:
  376. return "RUNNING"
  377. case 9:
  378. return "KILLING"
  379. case 10:
  380. return "COMPLETED"
  381. case 11:
  382. return "FAILED"
  383. case 12:
  384. return "KILLED"
  385. case 13:
  386. return "CANCELED"
  387. case 14:
  388. return "LOST"
  389. case 15:
  390. return "SCALING"
  391. case 16:
  392. return "SUBMIT_MODEL_FAILED"
  393. case 17:
  394. return "DEPLOY_SERVICE_FAILED"
  395. case 18:
  396. return "CHECK_INIT"
  397. case 19:
  398. return "CHECK_RUNNING"
  399. case 20:
  400. return "CHECK_RUNNING_COMPLETED"
  401. case 21:
  402. return "CHECK_FAILED"
  403. default:
  404. return strconv.Itoa(status)
  405. }
  406. }
  407. func GetVersionOutputPathByTotalVersionCount(TotalVersionCount int) (VersionOutputPath string) {
  408. talVersionCountToString := fmt.Sprintf("%04d", TotalVersionCount)
  409. VersionOutputPath = "V" + talVersionCountToString
  410. return VersionOutputPath
  411. }