You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

API_knowledge_distill.md 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. 知识蒸馏
  2. =========
  3. "软标签蒸馏"算子: pred_distill
  4. ---------
  5. `knowledge_distill_util.pred_distill(args, student_logits, teacher_logits):`
  6. [源代码](../model_compress/distil/src/knowledge_distill_util.py#L381)
  7. `pred_distill`为teacher和student模型添加软标签损失,使得student模型可以学习教师模型的输出,达到student模型模仿teacher模型在预测层的表现的目的。
  8. 采用[soft_cross_entropy](../model_compress/distil/src/knowledge_distill_util.py#L336)来计算损失。
  9. **参数:**
  10. - **args**: 一些超参,如teacher_temperature和student_temperature,对student和teacher模型进行soft操作的温度值。
  11. - **student_logits**: student模型预测出的logits。
  12. - **teacher_logits**: teacher模型预测出的logits。
  13. **返回:** 由teacher模型和student模型组合得到的软标签损失。
  14. ---
  15. "层与层蒸馏"算子: layer_distill
  16. ---------
  17. `knowledge_distill_util.layer_distill(args, student_reps, teacher_reps):`
  18. [源代码](../model_compress/distil/src/knowledge_distill_util.py#L346)
  19. `layer_distill`为teacher和student模型添加层与层损失,使得student模型可以学习教师模型的隐藏层特征,达到用teacher模型的暗知识(Dark Knowledge)指导student模型学习的目的,将teacher模型中的知识更好的蒸馏到student模型中。通过[MSE](../model_compress/distil/src/knowledge_distill_util.py#L343)来计算student模型和teacher模型中间层的距离。
  20. **参数:**
  21. - **args**: 一些超参,暂未用到,仅留出接口。
  22. - **student_reps**: student模型的所有中间层表示。
  23. - **teacher_reps**: teacher模型的所有中间层表示。
  24. **返回:** 由teacher模型和student模型组合得到的层与层蒸馏损失。
  25. >注:该算子仅适用于BERT类的student和teacher模型。
  26. ---
  27. "注意力蒸馏"算子: att_distill
  28. ---------
  29. `knowledge_distill_util.att_distill(args, student_atts, teacher_atts):`
  30. [源代码](../model_compress/distil/src/knowledge_distill_util.py#L363)
  31. `att_distill`为teacher和student模型添加注意力损失,使得student模型可以学习教师模型的attention score矩阵,学习到其中包含语义知识,例如语法和相互关系等。通过[MSE](../model_compress/distil/src/knowledge_distill_util.py#L343)来计算损失。
  32. **参数:**
  33. - **args**: 一些超参,暂未用到,仅留出接口。
  34. - **student_reps**: student模型的所有的attention score矩阵。
  35. - **teacher_reps**: teacher模型的所有的attention score矩阵。
  36. **返回:** 由teacher模型和student模型组合得到的注意力蒸馏损失。
  37. >注:该算子仅适用于BERT类的student和teacher模型。

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能