You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grampus.go 7.4 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. package grampus
  2. import (
  3. "encoding/json"
  4. "strings"
  5. "code.gitea.io/gitea/modules/setting"
  6. "code.gitea.io/gitea/models"
  7. "code.gitea.io/gitea/modules/context"
  8. "code.gitea.io/gitea/modules/log"
  9. "code.gitea.io/gitea/modules/notification"
  10. "code.gitea.io/gitea/modules/timeutil"
  11. )
  12. const (
  13. JobPath = "job/"
  14. ProcessorTypeNPU = "npu.huawei.com/NPU"
  15. ProcessorTypeGPU = "nvidia.com/gpu"
  16. GpuWorkDir = "/tmp/"
  17. NpuWorkDir = "/cache/"
  18. CodeArchiveName = "master.zip"
  19. )
  20. var (
  21. poolInfos *models.PoolInfos
  22. FlavorInfos *setting.StFlavorInfos
  23. ImageInfos *setting.StImageInfosModelArts
  24. SpecialPools *models.SpecialPools
  25. CommandPrepareScript = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;echo \"start loading script\";wget -q https://git.openi.org.cn/OpenIOSSG/%s/archive/master.zip;" +
  26. "echo \"finish loading script\";unzip -q master.zip;cd %s;chmod 777 downloader_for_obs uploader_for_npu downloader_for_minio uploader_for_gpu;"
  27. )
  28. type GenerateTrainJobReq struct {
  29. JobName string
  30. Command string
  31. ImageUrl string //与image_id二选一,都有的情况下优先image_url
  32. ImageId string
  33. DisplayJobName string
  34. Uuid string
  35. Description string
  36. CodeObsPath string
  37. BootFile string
  38. BootFileUrl string
  39. DataUrl string
  40. TrainUrl string
  41. WorkServerNumber int
  42. EngineID int64
  43. CommitID string
  44. IsLatestVersion string
  45. BranchName string
  46. PreVersionId int64
  47. PreVersionName string
  48. VersionCount int
  49. EngineName string
  50. TotalVersionCount int
  51. ComputeResource string
  52. ProcessType string
  53. DatasetNames string
  54. DatasetInfos map[string]models.DatasetInfo
  55. Params string
  56. ModelName string
  57. LabelName string
  58. CkptName string
  59. ModelVersion string
  60. PreTrainModelPath string
  61. PreTrainModelUrl string
  62. Spec *models.Specification
  63. }
  64. func getEndPoint() string {
  65. index := strings.Index(setting.Endpoint, "//")
  66. endpoint := setting.Endpoint[index+2:]
  67. return endpoint
  68. }
  69. func getDatasetGrampus(datasetInfos map[string]models.DatasetInfo) []models.GrampusDataset {
  70. var datasetGrampus []models.GrampusDataset
  71. endPoint := getEndPoint()
  72. for _, datasetInfo := range datasetInfos {
  73. datasetGrampus = append(datasetGrampus, models.GrampusDataset{
  74. Name: datasetInfo.FullName,
  75. Bucket: setting.Bucket,
  76. EndPoint: endPoint,
  77. ObjectKey: datasetInfo.DataLocalPath + datasetInfo.FullName,
  78. })
  79. }
  80. return datasetGrampus
  81. }
  82. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  83. createTime := timeutil.TimeStampNow()
  84. centerID, centerName := getCentersParamter(ctx, req)
  85. var datasetGrampus, modelGrampus []models.GrampusDataset
  86. if ProcessorTypeNPU == req.ProcessType {
  87. datasetGrampus = getDatasetGrampus(req.DatasetInfos)
  88. if len(req.ModelName) != 0 {
  89. modelGrampus = []models.GrampusDataset{
  90. {
  91. Name: req.ModelName,
  92. Bucket: setting.Bucket,
  93. EndPoint: getEndPoint(),
  94. ObjectKey: req.PreTrainModelPath,
  95. },
  96. }
  97. }
  98. }
  99. jobResult, err := createJob(models.CreateGrampusJobRequest{
  100. Name: req.JobName,
  101. Tasks: []models.GrampusTasks{
  102. {
  103. Name: req.JobName,
  104. Command: req.Command,
  105. ResourceSpecId: req.Spec.SourceSpecId,
  106. ImageId: req.ImageId,
  107. ImageUrl: req.ImageUrl,
  108. CenterID: centerID,
  109. CenterName: centerName,
  110. ReplicaNum: 1,
  111. Datasets: datasetGrampus,
  112. Models: modelGrampus,
  113. },
  114. },
  115. })
  116. if err != nil {
  117. log.Error("createJob failed: %v", err.Error())
  118. return err
  119. }
  120. jobID := jobResult.JobInfo.JobID
  121. err = models.CreateCloudbrain(&models.Cloudbrain{
  122. Status: TransTrainJobStatus(jobResult.JobInfo.Status),
  123. UserID: ctx.User.ID,
  124. RepoID: ctx.Repo.Repository.ID,
  125. JobID: jobID,
  126. JobName: req.JobName,
  127. DisplayJobName: req.DisplayJobName,
  128. JobType: string(models.JobTypeTrain),
  129. Type: models.TypeC2Net,
  130. Uuid: req.Uuid,
  131. DatasetName: req.DatasetNames,
  132. CommitID: req.CommitID,
  133. IsLatestVersion: req.IsLatestVersion,
  134. ComputeResource: req.ComputeResource,
  135. ImageID: req.ImageId,
  136. TrainUrl: req.TrainUrl,
  137. BranchName: req.BranchName,
  138. Parameters: req.Params,
  139. BootFile: req.BootFile,
  140. DataUrl: req.DataUrl,
  141. Description: req.Description,
  142. WorkServerNumber: req.WorkServerNumber,
  143. EngineName: req.EngineName,
  144. VersionCount: req.VersionCount,
  145. TotalVersionCount: req.TotalVersionCount,
  146. CreatedUnix: createTime,
  147. UpdatedUnix: createTime,
  148. Spec: req.Spec,
  149. ModelName: req.ModelName,
  150. ModelVersion: req.ModelVersion,
  151. LabelName: req.LabelName,
  152. PreTrainModelUrl: req.PreTrainModelUrl,
  153. CkptName: req.CkptName,
  154. })
  155. if err != nil {
  156. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, err.Error())
  157. return err
  158. }
  159. var actionType models.ActionType
  160. if req.ComputeResource == models.NPUResource {
  161. actionType = models.ActionCreateGrampusNPUTrainTask
  162. } else if req.ComputeResource == models.GPUResource {
  163. actionType = models.ActionCreateGrampusGPUTrainTask
  164. }
  165. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, actionType)
  166. return nil
  167. }
  168. func getCentersParamter(ctx *context.Context, req *GenerateTrainJobReq) ([]string, []string) {
  169. var centerID []string
  170. var centerName []string
  171. includeCenters := make(map[string]string)
  172. excludeCenters := make(map[string]string)
  173. if SpecialPools != nil {
  174. for _, pool := range SpecialPools.Pools {
  175. if !pool.IsExclusive && strings.Contains(req.ComputeResource, pool.Type) {
  176. org, _ := models.GetOrgByName(pool.Org)
  177. if org != nil {
  178. isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID)
  179. if isOrgMember {
  180. for _, info := range pool.Pool {
  181. includeCenters[info.Queue] = info.Value
  182. }
  183. } else {
  184. for _, info := range pool.Pool {
  185. excludeCenters[info.Queue] = info.Value
  186. }
  187. }
  188. }
  189. }
  190. }
  191. }
  192. if len(includeCenters) > 0 {
  193. //如果有专属资源池,根据专属资源池指定智算中心
  194. for k, v := range includeCenters {
  195. centerID = append(centerID, k)
  196. centerName = append(centerName, v)
  197. }
  198. } else if len(excludeCenters) > 0 {
  199. //否则,有要排除的中心,先获取所有中心,删除其中的排除中心,得到指定的智算中心
  200. allCenters := make(map[string]string)
  201. specs, err := GetResourceSpecs(req.ProcessType)
  202. if err == nil {
  203. for _, info := range specs.Infos {
  204. for _, center := range info.Centers {
  205. allCenters[center.ID] = center.Name
  206. }
  207. }
  208. }
  209. for k, _ := range excludeCenters {
  210. delete(allCenters, k)
  211. }
  212. for k, v := range allCenters {
  213. centerID = append(centerID, k)
  214. centerName = append(centerName, v)
  215. }
  216. }
  217. return centerID, centerName
  218. }
  219. func TransTrainJobStatus(status string) string {
  220. if status == models.GrampusStatusPending {
  221. status = models.GrampusStatusWaiting
  222. }
  223. return strings.ToUpper(status)
  224. }
  225. func InitSpecialPool() {
  226. if SpecialPools == nil && setting.Grampus.SpecialPools != "" {
  227. json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools)
  228. }
  229. }