feat(dpmodel): add backend-independent trainer abstraction#5603

Queued
njzjz wants to merge 6 commits into deepmodeling:master from njzjz:feat/dpmodel-abstract-trainer-5229

@njzjz njzjz commented Jun 28, 2026

Member

Summary

  • Add backend-independent training abstractions under deepmd.dpmodel.train for task/rank normalization, display scheduling, learning-curve output, checkpoint cadence, lifecycle hooks, and shared train entrypoint orchestration.
  • Factor common training-data helpers so single-task training is handled as a one-task collection and multi-task data construction/summary/probability handling is shared where possible.
  • Add a backend-independent finetune rule builder in deepmd.utils.finetune, and reduce the PT, PT-exportable, Paddle, and JAX backend finetune modules to backend-specific checkpoint loading plus shared rule generation.
  • Migrate JAX train entrypoint/trainer onto the shared pipeline and add JAX finetune plus multi-task support on top of the new abstractions.
  • Migrate pt_expt train entrypoint/trainer behavior further onto the shared pipeline, including single-task-as-multi-task normalization, data summaries, checkpoint retention, stat-file parent creation, relative latest checkpoint symlinks, and checkpoint parent creation.
  • Address PR review comments around task-key validation, learning-curve metric ordering, lifecycle cleanup, print_summary fallback behavior, broken __len__ handling, JAX finetune branch/alias validation, numeric-looking JAX task keys, HDF5 stat paths, and pt_expt checkpoint symlinks.
  • Move the new dpmodel trainer/entrypoint tests from source/tests/test_dpmodel_*.py into source/tests/common/dpmodel/.

Refs #5229, #5230, #5231

Tests

  • ruff format .
  • ruff check .
  • git diff --check
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit pytest source/tests/common/dpmodel/test_train_abstract_trainer.py source/tests/common/dpmodel/test_train_entrypoint.py source/tests/common/dpmodel/test_train_data.py source/tests/common/dpmodel/test_training_utils.py source/tests/common/test_finetune_utils.py source/tests/jax/test_training.py source/tests/pt_expt/test_entrypoint.py source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_from_single_task source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_no_change_model_params -q (53 passed, 2 subtests passed)
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat --finetune pretrain.jax --use-pretrain-script on a temporary 1-step water finetune smoke; completed on NVIDIA GeForce RTX 5090 and saved ft-model-1.jax.
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-stat on a temporary 2-step water smoke; completed on NVIDIA GeForce RTX 5090, saved ckpts/pt-model-2.pt, created stats/stat.hdf5, and verified ckpts/pt-model.pt -> pt-model-2.pt with old step checkpoint pruned by max_ckpt_keep=1.

Notes

  • Paddle-specific runtime tests were not run locally because paddle is not installed in this environment.
  • Plain PyTorch backend test collection is blocked in this environment by external deepmd_gnn/CUDA initialization, not by the shared finetune rule builder changes.

Summary by CodeRabbit

  • New Features

    • Introduced a unified, backend-independent training framework with consistent single-task and multi-task handling, learning-curve output, and structured training/validation steps.
    • Added a common training entrypoint abstraction that standardizes config normalization, neighbor-stat updates, and lifecycle teardown.
    • Implemented full-validation with best-checkpoint tracking, top-K selection, and val.log reporting (including backend-specific checkpoint suffixes).
  • Bug Fixes

    • Improved checkpoint save/restore and retention (including “latest” link updates and older checkpoint cleanup).
    • Improved task-weighting logic to better handle datasets with/without sizing information.
    • Fixed multi-task neighbor-stat updates and JAX full-validation error propagation across ranks.
  • Tests

    • Expanded unit and smoke tests for training orchestration, finetuning, validation, and checkpoint reconciliation.
  • Documentation

    • Updated validation-configuration help text to reflect broader backend support.

coderabbitai Bot commented Jun 28, 2026

Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 70d84d20-3fe9-4f28-b26b-bda2fe28ce96

📥 Commits

Reviewing files that changed from the base of the PR and between 7632fef and 418b4c8.

📒 Files selected for processing (8)
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/train/validation.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/train/validation.py
  • deepmd/pt/train/validation.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/jax/test_training.py
  • source/tests/pt/test_validation.py
🚧 Files skipped from review as they are similar to previous changes (7)
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • deepmd/jax/train/validation.py
  • deepmd/pt/train/validation.py
  • source/tests/jax/test_training.py
  • deepmd/dpmodel/train/validation.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/train/trainer.py

📝 Walkthrough

Adds shared training abstractions for entrypoints, trainers, task normalization, and validation. Refactors JAX and pt_expt training flows to use the shared pipeline. Consolidates finetune rule handling into shared utilities. Extends JAX checkpoint serialization for multitask models.

Changes

Shared Training Abstractions and Backend Integrations

| Layer / File(s) | Summary |
| --- | --- |
| Core training contracts: deepmd/dpmodel/train/__init__.py, deepmd/dpmodel/train/data.py, deepmd/dpmodel/train/entrypoint.py, deepmd/dpmodel/train/trainer.py, deepmd/dpmodel/utils/training_utils.py, source/tests/common/dpmodel/test_train_abstract_trainer.py, source/tests/common/dpmodel/test_train_data.py, source/tests/common/dpmodel/test_train_entrypoint.py, source/tests/common/dpmodel/test_training_utils.py | Adds shared task config, entrypoint, trainer, and learning-curve abstractions, plus task-size fallback logic and coverage for the common training pipeline. |
| Validation and checkpoint management: deepmd/dpmodel/train/validation.py, deepmd/pt/train/validation.py, deepmd/utils/argcheck.py, source/tests/pt/test_validation.py, source/tests/pt_expt/test_training.py | Adds shared full-validation machinery, updates PyTorch validation checkpoint naming and cleanup, and expands validating-argument docs. |
| Shared finetune rule builder: deepmd/utils/finetune.py, deepmd/pt/utils/finetune.py, deepmd/pd/utils/finetune.py, deepmd/pt_expt/utils/finetune.py, deepmd/jax/utils/finetune.py, source/tests/common/test_finetune_utils.py | Adds FinetuneRuleBuilder and shared finetune rule helpers, then reduces backend finetune modules to wrappers with tests for branch and alias behavior. |
| JAX training refactor: deepmd/jax/entrypoints/train.py, deepmd/jax/train/trainer.py, deepmd/jax/train/validation.py, deepmd/jax/utils/serialization.py, deepmd/jax/utils/finetune.py, source/tests/jax/test_training.py | Refactors JAX entrypoint and trainer flow onto the shared abstractions, adds JAX full validation, and extends serialization for composite multitask checkpoints. |
| pt_expt training refactor: deepmd/pt_expt/entrypoints/main.py, deepmd/pt_expt/train/training.py, deepmd/pt_expt/utils/finetune.py, source/tests/pt_expt/test_entrypoint.py | Refactors pt_expt entrypoint and trainer onto the shared abstractions, including distributed setup/teardown, task mapping, and checkpoint handling. |

Estimated code review effort: 5 (Critical) | ~120 minutes

Possibly related PRs

Suggested reviewers: wanghan-iapcm, iProzd

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 52.09% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: adding a backend-independent trainer abstraction for dpmodel. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 4

🧹 Nitpick comments (1)
source/tests/test_dpmodel_abstract_trainer.py (1)

99-109: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Avoid hard-coding the default task key in this test.

TrainingTaskCollection.single() owns the default key via DEFAULT_TASK_KEY, so indexing with "Default" makes this test fail on an internal rename without any behavior change. Pull the lone task from the collection API instead.

Proposed fix
-    task = tasks["Default"]
+    task = tasks.select()
     task.add_data_requirements()
 
     assert not tasks.is_multitask
     assert tasks.select() is task
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/test_dpmodel_abstract_trainer.py` around lines 99 - 109, The
test is hard-coding the default task name instead of using the collection API.
In test_dpmodel_abstract_trainer.py, update the TrainingTaskCollection.single()
usage so the lone task is retrieved without indexing by "Default", and reference
the collection’s default-task behavior through its API/select() rather than the
literal key. Keep the assertions on task.add_data_requirements(),
tasks.is_multitask, and tasks.select() intact, but bind task from the collection
in a way that won’t break if DEFAULT_TASK_KEY is renamed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 380-383: Keep the single-task table schema anchored to
train_results in the Trainer formatting path. The header logic and the row
emission in the same train/eval summary should use the exact same key sequence
from train_results, not valid_results, so the validation columns stay aligned
and missing or reordered validation metrics do not break train_results lookups.
Update the loop in the Trainer row-building code to derive the metric order from
train_results and only read matching validation values by key when present.
- Around line 148-161: The task normalization in the trainer constructor
silently overwrites duplicate entries when `tasks` is a sequence, so add an
explicit duplicate-key check before building `self._tasks` in
`Trainer.__init__`. Update the branch that handles non-Mapping `tasks` to
validate that each `task.key` appears only once, and raise a clear `ValueError`
if duplicates are found. Keep the existing key-matching validation and
`_normalize_probabilities` flow intact.

In `@deepmd/jax/train/trainer.py`:
- Around line 188-194: The training setup in
AbstractTrainer.on_train_begin()/train() is missing registration of loss label
requirements before creating the task collection. Update the training flow
around TrainingTaskCollection.single and self.run(tasks) so
DeepmdDataSystem.add_data_requirements() is called with
self.loss.label_requirement (or equivalent task data_requirements) before
batching starts, ensuring get_batch() includes the labels needed by the loss
path.

In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 53-64: The evaluate_training helper is consuming a batch for
inactive multitask entries by falling back to task.training_data.get_batch()
when step_result is None, which advances the cursor during display collection.
Update evaluate_training in test_dpmodel_abstract_trainer.py so it only reads
step_result.payload for the matching task and otherwise returns a non-consuming
placeholder/skip path for inactive tasks; keep the batch-order contract aligned
with the multitask loop in trainer.py and the TrainStepResult/task.key check.

---

Nitpick comments:
In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 99-109: The test is hard-coding the default task name instead of
using the collection API. In test_dpmodel_abstract_trainer.py, update the
TrainingTaskCollection.single() usage so the lone task is retrieved without
indexing by "Default", and reference the collection’s default-task behavior
through its API/select() rather than the literal key. Keep the assertions on
task.add_data_requirements(), tasks.is_multitask, and tasks.select() intact, but
bind task from the collection in a way that won’t break if DEFAULT_TASK_KEY is
renamed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ae088f30-bef5-4da4-ae02-40f16ee87975

📥 Commits

Reviewing files that changed from the base of the PR and between a9bcbc5 and d100673.

📒 Files selected for processing (4)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/train/trainer.py
  • source/tests/test_dpmodel_abstract_trainer.py

Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread source/tests/test_dpmodel_abstract_trainer.py Outdated
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from d100673 to 6e168b7 Compare June 28, 2026 18:38

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

♻️ Duplicate comments (2)
deepmd/dpmodel/train/trainer.py (2)

422-428: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Single-task row schema should anchor to train_results, not valid_results.

format_header iterates train_results (Line 382), but the row loop here iterates valid_results and indexes train_results[key]. If a backend returns validation metrics in a different order or omits a metric, the row desynchronizes from the header and train_results[key] can raise KeyError.

Suggested fix
             if valid_results is not None:
                 assert not self._is_multitask(valid_results)
-                for key in valid_results:
+                for key in train_results:
                     row += (
-                        f"   {float(valid_results[key]):11.2e}"
+                        f"   {float(valid_results.get(key, float('nan'))):11.2e}"
                         f" {float(train_results[key]):11.2e}"
                     )
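
For illustration, the row-building idea in the suggested fix can be sketched as a standalone function. This is only a sketch under assumptions: `format_row` is a hypothetical helper name (not a function in `trainer.py`), the metric names are made up, and the no-validation branch is a guess at the surrounding code. Only the column widths are taken from the diff above.

```python
import math


def format_row(train_results, valid_results=None):
    """Build one single-task row, using train_results as the only key order."""
    row = ""
    for key in train_results:
        train = float(train_results[key])
        if valid_results is not None:
            # A validation metric that is missing renders as nan instead of
            # raising KeyError; reordered validation keys no longer matter.
            valid = float(valid_results.get(key, math.nan))
            row += f"   {valid:11.2e} {train:11.2e}"
        else:
            row += f"   {train:11.2e}"
    return row


# Hypothetical metrics: validation is reordered and lacks "rmse_e".
train = {"rmse_e": 1.0e-3, "rmse_f": 2.0e-2}
valid = {"rmse_f": 3.0e-2}
row = format_row(train, valid)
```

Because the loop is driven by `train_results`, the row always has the same number of columns as a header built from the same mapping.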
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/train/trainer.py` around lines 422 - 428, The single-task row
assembly in `Trainer.format_*` is anchored to the wrong metric source: the
header follows `train_results`, but this row loop currently iterates
`valid_results` while reading `train_results[key]`, which can desynchronize
columns or fail when validation metrics differ. Update the row-building logic to
use the same key order as `train_results` (matching the existing `format_header`
contract) and only use `valid_results` for the optional validation value lookup,
keeping the schema consistent with the training metrics.

148-161: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Duplicate task keys are still silently dropped when tasks is a sequence.

task_dict = {task.key: task for task in tasks} overwrites earlier entries, so a misconfigured multitask run can lose a task with no error. Validate uniqueness before building _tasks.

Suggested fix
         if isinstance(tasks, Mapping):
             task_dict = dict(tasks)
         else:
-            task_dict = {task.key: task for task in tasks}
+            task_list = list(tasks)
+            task_dict = {task.key: task for task in task_list}
+            if len(task_dict) != len(task_list):
+                raise ValueError("Training task keys must be unique.")
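
A slightly more verbose variant of the same check, sketched here outside the trainer, reports which keys collide rather than only that a collision happened. `Task` and `normalize_tasks` are hypothetical stand-ins for the real task type and constructor logic; only the `key` attribute is assumed from the review.

```python
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass
class Task:
    """Hypothetical stand-in for a training task; only `key` matters here."""

    key: str


def normalize_tasks(tasks):
    """Return a key -> task dict, rejecting duplicate keys in a sequence."""
    if isinstance(tasks, Mapping):
        return dict(tasks)
    task_list = list(tasks)
    counts = Counter(task.key for task in task_list)
    duplicates = sorted(key for key, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(
            f"Training task keys must be unique; duplicated: {duplicates}"
        )
    return {task.key: task for task in task_list}
```

Naming the offending keys costs one extra pass over the sequence, which is negligible next to the debugging time saved on a misconfigured multitask input.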
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/train/trainer.py` around lines 148 - 161, The task
normalization in Trainer.__init__ still silently overwrites duplicate keys when
tasks is a sequence because task_dict is built directly from task.key; add an
explicit duplicate-key check before assigning to self._tasks. In the branch that
handles non-Mapping tasks, validate that each task.key is unique (raise a
ValueError on duplicates) before constructing the dict, while keeping the
existing key-vs-task.key consistency check and downstream
_normalize_probabilities flow unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 831: Wire the parsed retention setting into checkpoint cleanup so
`max_ckpt_keep` actually limits saved checkpoints: `TrainerConfig` already
carries the value and `AbstractTrainer.run()`/`save_checkpoint()` currently only
append new `self.save_ckpt-<step>.pt` files. Update the checkpoint-saving flow
to track existing checkpoints and prune the oldest ones after each save, using
`self.max_ckpt_keep` as the cap and keeping the newest checkpoints only.
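
One way the pruning step could look, as a sketch rather than the PR's implementation: `prune_checkpoints` is a hypothetical helper, the `<prefix>-<step>.pt` naming is taken from the comment above, and treating a non-positive `max_ckpt_keep` as "keep everything" is an assumption.

```python
import re
import tempfile
from pathlib import Path


def prune_checkpoints(save_ckpt, max_ckpt_keep):
    """Delete all but the newest `max_ckpt_keep` step checkpoints.

    Returns the list of removed paths, oldest step first.
    """
    if max_ckpt_keep <= 0:
        # Assumption: a non-positive cap disables pruning.
        return []
    prefix = Path(save_ckpt)
    pattern = re.compile(re.escape(prefix.name) + r"-(\d+)\.pt")
    found = []
    for path in prefix.parent.iterdir():
        match = pattern.fullmatch(path.name)
        if match:
            found.append((int(match.group(1)), path))
    # Sort by numeric step so that step 10 counts as newer than step 2.
    found.sort(key=lambda item: item[0])
    removed = [path for _, path in found[:-max_ckpt_keep]]
    for path in removed:
        path.unlink()
    return removed


with tempfile.TemporaryDirectory() as tmp:
    prefix = str(Path(tmp) / "pt-model")
    for step in (1, 2, 10):
        Path(f"{prefix}-{step}.pt").touch()
    removed_names = [p.name for p in prune_checkpoints(prefix, max_ckpt_keep=1)]
    remaining = sorted(p.name for p in Path(tmp).iterdir())
```

Sorting on the parsed step instead of the file name avoids deleting the wrong file when step numbers have different digit counts.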

---

Duplicate comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 422-428: The single-task row assembly in `Trainer.format_*` is
anchored to the wrong metric source: the header follows `train_results`, but
this row loop currently iterates `valid_results` while reading
`train_results[key]`, which can desynchronize columns or fail when validation
metrics differ. Update the row-building logic to use the same key order as
`train_results` (matching the existing `format_header` contract) and only use
`valid_results` for the optional validation value lookup, keeping the schema
consistent with the training metrics.
- Around line 148-161: The task normalization in Trainer.__init__ still silently
overwrites duplicate keys when tasks is a sequence because task_dict is built
directly from task.key; add an explicit duplicate-key check before assigning to
self._tasks. In the branch that handles non-Mapping tasks, validate that each
task.key is unique (raise a ValueError on duplicates) before constructing the
dict, while keeping the existing key-vs-task.key consistency check and
downstream _normalize_probabilities flow unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2ee2703c-74d9-48e0-8c5e-ce031e73c19a

📥 Commits

Reviewing files that changed from the base of the PR and between d100673 and 6e168b7.

📒 Files selected for processing (5)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py
  • source/tests/test_dpmodel_abstract_trainer.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/dpmodel/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/test_dpmodel_abstract_trainer.py
  • deepmd/jax/train/trainer.py

Comment thread deepmd/pt_expt/train/training.py Outdated
codecov Bot commented Jun 28, 2026

Codecov Report

❌ Patch coverage is 75.68047% with 411 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.77%. Comparing base (73de44b) to head (418b4c8).
⚠️ Report is 11 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/jax/train/trainer.py | 40.67% | 175 Missing ⚠️ |
| deepmd/dpmodel/train/validation.py | 78.64% | 63 Missing ⚠️ |
| deepmd/jax/train/validation.py | 36.23% | 44 Missing ⚠️ |
| deepmd/jax/entrypoints/train.py | 50.79% | 31 Missing ⚠️ |
| deepmd/dpmodel/train/trainer.py | 92.05% | 24 Missing ⚠️ |
| deepmd/pt_expt/entrypoints/main.py | 84.95% | 17 Missing ⚠️ |
| deepmd/pt_expt/train/training.py | 93.06% | 14 Missing ⚠️ |
| deepmd/pt/train/validation.py | 69.23% | 12 Missing ⚠️ |
| deepmd/utils/finetune.py | 90.38% | 10 Missing ⚠️ |
| deepmd/dpmodel/train/entrypoint.py | 89.70% | 7 Missing ⚠️ |
... and 3 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5603      +/-   ##
==========================================
- Coverage   81.98%   81.77%   -0.22%     
==========================================
  Files         959      973      +14     
  Lines      105430   107386    +1956     
  Branches     4071     4144      +73     
==========================================
+ Hits        86442    87811    +1369     
- Misses      17518    18057     +539     
- Partials     1470     1518      +48     

@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 6e168b7 to 78272a7 Compare June 29, 2026 06:35

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/entrypoint.py`:
- Around line 79-83: The cleanup guard in the entrypoint flow is starting too
late, so `teardown_run()` will not run if `setup_run()` fails after partially
initializing backend state. Move the `try/finally` in the main training
entrypoint so it wraps `setup_run()` itself, ensuring `teardown_run()` is always
invoked even when setup raises. Use the existing `setup_run()` and
`teardown_run()` flow in this module as the anchor when making the change.
- Line 48: The AbstractTrainEntrypoint contract is too implicit: it currently
has no abstract methods and the hook methods are docstring-only, which triggers
lint issues. Update AbstractTrainEntrypoint so run_training() is marked
abstract, and give the optional hook methods explicit no-op bodies with return
None unless they must be overridden by every backend. Use the class and method
names AbstractTrainEntrypoint and run_training() to locate the contract and hook
definitions.
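
Both entrypoint points can be sketched together. The class below is a hypothetical stand-in, not the real `AbstractTrainEntrypoint`; the method names `setup_run`, `teardown_run`, and `run_training` come from the comments above, while `run` and `FailingSetup` are made up for the demonstration.

```python
from abc import ABC, abstractmethod


class TrainEntrypointSketch(ABC):
    """Hypothetical stand-in for the entrypoint contract discussed above."""

    def setup_run(self) -> None:
        """Optional hook: initialize backend state."""
        return None

    def teardown_run(self) -> None:
        """Optional hook: release backend state."""
        return None

    @abstractmethod
    def run_training(self) -> None:
        """Required: every backend must implement the training call."""

    def run(self) -> None:
        # The guard starts before setup_run(), so teardown_run() also runs
        # when setup fails after partially initializing backend state.
        try:
            self.setup_run()
            self.run_training()
        finally:
            self.teardown_run()


class FailingSetup(TrainEntrypointSketch):
    """Demo backend whose setup raises after recording that it started."""

    def __init__(self) -> None:
        self.events = []

    def setup_run(self) -> None:
        self.events.append("setup")
        raise RuntimeError("setup failed")

    def teardown_run(self) -> None:
        self.events.append("teardown")

    def run_training(self) -> None:
        self.events.append("train")


entry = FailingSetup()
try:
    entry.run()
except RuntimeError:
    pass
```

With this shape, `teardown_run()` has to tolerate a partially completed setup, since it can now be called before every resource exists.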

In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 295-297: Replace the existing assert-based validation in the
multi-task config handling with an explicit runtime check that raises ValueError
when the config is "RANDOM"; this ensures the validation in the main entrypoint
is enforced even under optimized Python runs. Update the check in the entrypoint
logic that handles the multi-task configuration so the rejection of "RANDOM" is
always consistent.
- Around line 257-259: `setup_run` and `teardown_run` need ownership tracking
for the distributed process group: make `setup_run` skip `init_process_group()`
when a default group already exists, and record whether this entrypoint created
the group. Then update `teardown_run` to call `destroy_process_group()` only
when it is cleaning up a group created by `setup_run`, so it does not interfere
with a caller-managed distributed context.
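
The assert replacement is small enough to sketch in isolation. The comment does not spell out which value is compared with "RANDOM", so the `model_branch` parameter and the function name `check_model_branch` are assumptions made only for this example.

```python
def check_model_branch(model_branch):
    """Reject the reserved name "RANDOM" in every interpreter mode.

    An `assert model_branch != "RANDOM"` would be compiled away when Python
    runs with -O, so the explicit check below is used instead.
    """
    if model_branch == "RANDOM":
        raise ValueError('"RANDOM" is not allowed in this multi-task setting.')
    return model_branch
```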

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ec38d095-ecf4-4024-ac9f-7523f8d40e60

📥 Commits

Reviewing files that changed from the base of the PR and between 6e168b7 and 78272a7.

📒 Files selected for processing (11)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • source/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/dpmodel/train/__init__.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py

Comment thread deepmd/dpmodel/train/entrypoint.py
Comment thread deepmd/dpmodel/train/entrypoint.py Outdated
Comment thread deepmd/pt_expt/entrypoints/main.py
Comment thread deepmd/pt_expt/entrypoints/main.py Outdated
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 78272a7 to 5bebfdf Compare June 29, 2026 08:27

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/jax/utils/serialization.py (1)

189-207: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Preserve string task keys when converting restored state keys.

convert_str_to_int_key(state) also rewrites digit-only task names under state["models"]. A valid multi-task key like "0" becomes 0, then state_by_model[model_key] fails because model_def_script["model_dict"] still uses "0".

Proposed fix
-        convert_str_to_int_key(state)
-
         model_def_script = data.model_def_script
         if "model_dict" in model_def_script:
             state_by_model = state.get("models", state)
+            if "models" in state:
+                for model_state in state_by_model.values():
+                    convert_str_to_int_key(model_state)
+            else:
+                convert_str_to_int_key(state_by_model)
             model_dict = {"model_dict": {}}
             for model_key, model_params in model_def_script["model_dict"].items():
                 abstract_model = get_model(model_params)
                 graphdef, abstract_state = nnx.split(abstract_model)
                 abstract_state.replace_by_pure_dict(state_by_model[model_key])
                 model = nnx.merge(graphdef, abstract_state)
                 model_dict["model_dict"][model_key] = model.serialize()
         else:
+            convert_str_to_int_key(state)
             abstract_model = get_model(model_def_script)
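
The failure mode can be reproduced without JAX. In the sketch below, `convert_digit_keys` is a hypothetical stand-in that only assumes what the comment describes (digit-only string keys become integers, recursively and in place); it is not the real `convert_str_to_int_key`, and the nested state layout is invented for the demonstration.

```python
import copy


def convert_digit_keys(data):
    """Hypothetical stand-in: turn digit-only string keys into ints, in place."""
    for key in list(data):
        value = data.pop(key)
        if isinstance(value, dict):
            convert_digit_keys(value)
        new_key = int(key) if isinstance(key, str) and key.isdigit() else key
        data[new_key] = value


# Invented multi-task state with a numeric-looking task name "0".
state = {"models": {"0": {"layers": {"0": "w0"}}, "water": {"layers": {"1": "w1"}}}}

# Converting the whole state also rewrites the task key "0" to 0.
whole = copy.deepcopy(state)
convert_digit_keys(whole)

# Converting each model's state separately leaves the task keys as strings.
per_model = copy.deepcopy(state)
for model_state in per_model["models"].values():
    convert_digit_keys(model_state)
```

After the per-model conversion, a lookup by the original task name `"0"` still succeeds, while the inner index keys are integers as before.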
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/utils/serialization.py` around lines 189 - 207, The state
restoration logic in convert_str_to_int_key is too aggressive and converts
digit-only task names inside state["models"], which breaks later lookup in the
model merge path. Update the key conversion so it only normalizes keys that
represent numeric indices where needed, while preserving string task keys used
by model_def_script["model_dict"]; then ensure the state_by_model lookup in the
model merging section continues to use the original task-name strings.
deepmd/pt_expt/train/training.py (1)

1509-1515: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Make the latest symlink target relative to its own directory.

When save_ckpt includes a directory, e.g. out/model.ckpt, latest.symlink_to("out/model.ckpt-1.pt") creates out/model.ckpt.pt -> out/model.ckpt-1.pt, which resolves as out/out/model.ckpt-1.pt. Restarting from the prefix then follows a broken symlink.

Proposed fix
-        ckpt_path = f"{self.save_ckpt}-{step}.pt"
+        ckpt_path = Path(f"{self.save_ckpt}-{step}.pt")
         torch.save(state, ckpt_path)
         # symlink latest
         latest = Path(f"{self.save_ckpt}.pt")
         if latest.is_symlink() or latest.exists():
             latest.unlink()
-        latest.symlink_to(ckpt_path)
+        latest.symlink_to(ckpt_path.name)
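
The relative-target behavior can be checked in a temporary directory. This is a sketch under assumptions: `link_latest` is a hypothetical helper that mirrors only the symlink part of the proposed fix, and a plain text file stands in for the `torch.save` output.

```python
import os
import tempfile
from pathlib import Path


def link_latest(save_ckpt, step):
    """Point `<save_ckpt>.pt` at `<save_ckpt>-<step>.pt` via a relative target."""
    ckpt_path = Path(f"{save_ckpt}-{step}.pt")
    latest = Path(f"{save_ckpt}.pt")
    if latest.is_symlink() or latest.exists():
        latest.unlink()
    # A relative symlink target is resolved from the link's own directory,
    # so only the file name is stored, not the "out/" prefix.
    latest.symlink_to(ckpt_path.name)
    return latest


old_cwd = Path.cwd()
with tempfile.TemporaryDirectory() as tmp:
    os.chdir(tmp)
    try:
        Path("out").mkdir()
        Path("out/model.ckpt-1.pt").write_text("state")  # stand-in checkpoint
        latest = link_latest("out/model.ckpt", 1)
        target = os.readlink(latest)
        resolves = latest.exists()  # exists() follows the symlink
    finally:
        os.chdir(old_cwd)
```

Storing only the file name also keeps the link valid if the whole checkpoint directory is moved or copied elsewhere.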
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 1509 - 1515, The `latest`
symlink in `training.py` is created with a target that can become incorrectly
resolved when `save_ckpt` includes a directory. Update the checkpoint-saving
logic in the block that builds `ckpt_path`, `latest =
Path(f"{self.save_ckpt}.pt")`, and calls `latest.symlink_to(...)` so the symlink
target is computed relative to `latest`’s parent directory instead of using the
full path string. Make sure the existing `torch.save` flow and cleanup of any
prior `latest` link remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/data.py`:
- Around line 120-124: The fallback in _print_summary is too broad because it
catches every TypeError from data.print_summary and retries without prob, which
can hide real failures. Update _print_summary to detect the supported call
signature first (or only retry on an explicit argument-count mismatch) so only
old print_summary implementations use the no-prob path while genuine TypeError
exceptions still propagate.
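
One possible shape for the signature check uses `inspect.signature(...).bind`, which validates the arguments without calling the function. `print_summary_compat` and the three demo classes are hypothetical; the positional `(name, prob)` call form is also an assumption about how `_print_summary` invokes the method.

```python
import inspect


def print_summary_compat(data, name, prob):
    """Pass `prob` only when data.print_summary can accept it."""
    signature = inspect.signature(data.print_summary)
    try:
        # bind() only checks the arguments against the signature.
        signature.bind(name, prob)
    except TypeError:
        # Signature mismatch: an older print_summary(name) implementation.
        return data.print_summary(name)
    # A TypeError raised inside print_summary itself is not caught here.
    return data.print_summary(name, prob)


class OldData:
    def print_summary(self, name):
        return ("old", name)


class NewData:
    def print_summary(self, name, prob):
        return ("new", name, prob)


class BrokenData:
    def print_summary(self, name, prob):
        raise TypeError("real bug inside print_summary")
```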

In `@deepmd/dpmodel/train/trainer.py`:
- Around line 575-579: The optional ABC hook methods on Training hooks are
docstring-only and trigger Ruff B027; update the on_train_begin and on_train_end
methods in TrainingTask/Trainer-related code to use explicit no-op bodies by
adding return None so they remain optional without failing lint.
- Around line 472-479: The cleanup guard starts too late in the training setup
path, so resources created by on_train_begin() may not be released if
_open_learning_curve() fails. Move the try/finally in Trainer.train() (or the
surrounding training entry point) so it begins before on_train_begin(tasks), and
keep on_train_end(tasks) in the finally block to ensure backend state is always
cleaned up even when setup throws.
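
The ordering requested in the cleanup-guard item can be sketched as follows. `SketchTrainer` is a hypothetical stand-in, not the real `Trainer`; only the hook names `on_train_begin`, `on_train_end`, and `_open_learning_curve` come from the comment, and the failure is simulated.

```python
class SketchTrainer:
    """Minimal stand-in for a trainer with lifecycle hooks (hypothetical)."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, tasks):
        self.events.append("begin")  # e.g. acquire backend resources

    def on_train_end(self, tasks):
        self.events.append("end")  # e.g. release backend resources

    def _open_learning_curve(self):
        # Simulate a setup failure after on_train_begin has already run.
        raise OSError("cannot open learning-curve file")

    def train(self, tasks):
        # The guard starts before on_train_begin, so on_train_end runs even
        # when a later setup step raises.
        try:
            self.on_train_begin(tasks)
            self._open_learning_curve()
            self.events.append("loop")  # the training loop would go here
        finally:
            self.on_train_end(tasks)


trainer = SketchTrainer()
try:
    trainer.train(tasks=["task_a"])
except OSError as exc:
    failure = str(exc)
```

A consequence of this placement is that `on_train_end` also runs when `on_train_begin` itself fails partway, so the end hook has to tolerate partially initialized state.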

In `@deepmd/dpmodel/utils/training_utils.py`:
- Around line 124-130: The `_training_data_size` helper is swallowing all
`TypeError` from `len(training_data)` and turning broken `__len__`
implementations into a fallback size of 1. Update `_training_data_size` so it
only returns 1 for objects that truly do not support sizing, and allow real
`__len__` failures to propagate instead of masking them; keep the existing
`get_nsystems` path and the `resolve_model_prob()` callers in mind when
adjusting the behavior.
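
A possible shape for the narrower fallback is to test for sizing support before calling `len`, rather than catching `TypeError` afterwards. This is a sketch under assumptions, not the project's `_training_data_size`: the function name, the order of the `get_nsystems` check, and the demo classes are hypothetical.

```python
from collections.abc import Sized


def training_data_size(training_data):
    """Return the number of systems, falling back to 1 only for unsized data."""
    get_nsystems = getattr(training_data, "get_nsystems", None)
    if callable(get_nsystems):
        return int(get_nsystems())
    # Sized checks whether the type defines __len__ without calling it, so a
    # __len__ that raises is still invoked below and its error propagates.
    if not isinstance(training_data, Sized):
        return 1
    return len(training_data)


class WithSystems:
    """Hypothetical data object exposing get_nsystems."""

    def get_nsystems(self):
        return 4


class BrokenLen:
    """Hypothetical data object whose __len__ has a genuine bug."""

    def __len__(self):
        raise TypeError("bug inside __len__")
```

For example, `training_data_size(object())` takes the fallback, while `training_data_size(BrokenLen())` raises the original `TypeError`.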

In `@deepmd/jax/utils/finetune.py`:
- Around line 50-53: The single-task pretrained checkpoint path in finetune.py
silently ignores non-empty model_branch_from values other than RANDOM, which can
hide typos and fall back to Default unexpectedly. Update the branch-selection
logic in the finetuning flow around single_config_chosen/model_branch_from to
explicitly validate the branch name when from_multitask is false, and raise an
error for any unknown non-empty value instead of proceeding. Keep RANDOM as the
only special case that enables new_fitting, and ensure the check happens before
the code falls back to Default.
- Around line 152-157: The pre-check in finetune setup is rejecting valid
finetune_head aliases because it compares against raw pretrained_keys before
alias resolution. Update the validation around pretrained_key in the finetune
flow so it uses the same alias-aware mapping as _get_finetune_rule_single() via
get_model_dict(), and only raise the ValueError after checking the resolved
model dict entries rather than the unexpanded pretrained_keys list.

In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 167-177: The _ensure_stat_file_path helper currently creates HDF5
stat files with h5py.File before ensuring the parent directory exists, so paths
like stats/model_stat.h5 can fail during trainer setup. Update
_ensure_stat_file_path to create the parent directories first for file targets,
then open the HDF5 file, and keep the existing directory creation path unchanged
for non-HDF5 stat targets. Use the stat_file_path, Path, and h5py.File logic in
this function to apply the fix.

---

Outside diff comments:
In `@deepmd/jax/utils/serialization.py`:
- Around line 189-207: The state restoration logic in convert_str_to_int_key is
too aggressive and converts digit-only task names inside state["models"], which
breaks later lookup in the model merge path. Update the key conversion so it
only normalizes keys that represent numeric indices where needed, while
preserving string task keys used by model_def_script["model_dict"]; then ensure
the state_by_model lookup in the model merging section continues to use the
original task-name strings.

In `@deepmd/pt_expt/train/training.py`:
- Around line 1509-1515: The `latest` symlink in `training.py` is created with a
target that can become incorrectly resolved when `save_ckpt` includes a
directory. Update the checkpoint-saving logic in the block that builds
`ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls
`latest.symlink_to(...)` so the symlink target is computed relative to
`latest`’s parent directory instead of using the full path string. Make sure the
existing `torch.save` flow and cleanup of any prior `latest` link remain
unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 78180874-4d87-4c66-bd4e-fd0b4ee961bb

📥 Commits

Reviewing files that changed from the base of the PR and between 78272a7 and 5bebfdf.

📒 Files selected for processing (15)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • source/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/train/__init__.py

Comment thread deepmd/dpmodel/train/data.py Outdated
Comment thread deepmd/dpmodel/train/trainer.py Outdated
Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/dpmodel/utils/training_utils.py
Comment thread deepmd/jax/utils/finetune.py Outdated
Comment thread deepmd/jax/utils/finetune.py Outdated
Comment thread deepmd/pt_expt/entrypoints/main.py
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 5bebfdf to 9485193 Compare June 29, 2026 09:27
@njzjz njzjz requested review from iProzd and wanghan-iapcm June 29, 2026 09:30

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/utils/finetune.py`:
- Around line 21-25: The suffix check in _load_model_params is too restrictive
and rejects checkpoint directories or pointer paths that serialize_from_file can
already deserialize. Remove the hard-coded .jax validation in _load_model_params
and let serialize_from_file(finetune_model) handle the input format consistently
with the freeze() and init_model paths, while still extracting
"model_def_script" from the loaded checkpoint data.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: cd8a6c26-b867-4429-892e-f4aa7e5fbc5e

📥 Commits

Reviewing files that changed from the base of the PR and between 5bebfdf and 9485193.

📒 Files selected for processing (22)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt/utils/finetune.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/utils/finetune.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/dpmodel/train/__init__.py
  • deepmd/jax/utils/serialization.py
  • deepmd/dpmodel/train/data.py
  • deepmd/pt_expt/train/training.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/entrypoints/main.py

Comment thread deepmd/jax/utils/finetune.py
@njzjz njzjz linked an issue Jun 29, 2026 that may be closed by this pull request
Comment thread deepmd/jax/train/trainer.py
Comment thread deepmd/dpmodel/utils/training_utils.py
Comment thread deepmd/dpmodel/train/trainer.py
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 5767485 to 7280c14 Compare June 30, 2026 05:21
@njzjz njzjz requested a review from wanghan-iapcm June 30, 2026 11:27
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread deepmd/dpmodel/train/trainer.py
Copilot AI review requested due to automatic review settings July 1, 2026 05:52
@njzjz

njzjz commented Jul 1, 2026

Member Author

Pushed 7632feff4 with the requested updates.

Changes:

  • Fixed JAX init_model so it preserves the input model script when use_pretrain_script=False; only restart adopts checkpoint metadata/current step.
  • Added trainer-level full validation support for pt_expt and JAX, with shared logging/top-k/best-checkpoint bookkeeping in deepmd.dpmodel.train.validation.
  • Kept full validation out of model capability APIs.
  • Fixed pt_expt full-validation force evaluation by enabling gradients on validation coordinate inputs.

Validation:

  • ruff format .
  • ruff check .
  • git diff --check
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_training.py::TestTraining::test_full_validation_loop
  • JAX full-validation tests plus TestJAXTraining::test_train_entrypoint_runs_one_step_from_scratch
  • Live GPU CLI: srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-stat
  • Live GPU CLI: srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat

Both live GPU CLI runs produced val.log and best checkpoint artifacts.

@njzjz njzjz requested a review from iProzd July 1, 2026 05:54
@njzjz

njzjz commented Jul 1, 2026

Member Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jul 1, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 230-236: The mapping branch in trainer.py under the probability
handling logic currently validates missing keys but lets extra task probability
keys slip through. Update the probabilities check in the Mapping path to reject
unknown keys as well as missing ones, using self._keys as the only allowed set
before building the numpy array. Keep the fix localized to the probability
normalization/validation block so stale or mistyped entries raise a ValueError
instead of being silently ignored.
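
The stricter Mapping check described in the first item could look roughly like the sketch below. It is illustrative only: `normalize_task_probabilities` and its `keys` argument are hypothetical stand-ins for the trainer method and `self._keys`, and it returns a plain list so the example stays dependency-free, whereas the comment refers to a numpy array.

```python
from collections.abc import Mapping


def normalize_task_probabilities(keys, probabilities):
    """Return probabilities ordered like `keys` and normalized to sum to 1."""
    if isinstance(probabilities, Mapping):
        missing = [k for k in keys if k not in probabilities]
        unknown = [k for k in probabilities if k not in keys]
        # Reject both directions of mismatch so stale or mistyped task names
        # fail loudly instead of being dropped.
        if missing or unknown:
            raise ValueError(
                f"task probability keys mismatch: missing={missing}, "
                f"unknown={unknown}"
            )
        values = [float(probabilities[k]) for k in keys]
    else:
        values = [float(p) for p in probabilities]
        if len(values) != len(keys):
            raise ValueError("expected one probability per task")
    total = sum(values)
    if total <= 0.0:
        raise ValueError("task probabilities must sum to a positive value")
    return [v / total for v in values]


probs = normalize_task_probabilities(["a", "b"], {"a": 1, "b": 3})
```

Passing `{"a": 1, "b": 1, "c": 1}` for the same two keys raises `ValueError` because `"c"` is not a known task.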

In `@deepmd/dpmodel/train/validation.py`:
- Around line 282-326: The validation flow in run() only raises errors on rank
0, so peer JAX processes can keep running after a rank-0 failure. Add a
backend-specific error propagation hook in the validator path around
_raise_if_error and/or the rank-0 try blocks, then implement the JAX version
used by the trainer to broadcast the failure or terminate all processes
consistently. Make sure the hook is invoked for failures in _evaluate,
save_checkpoint/_reconcile_best_checkpoints, and _log_result so non-zero ranks
do not continue silently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9857057b-3e07-4295-8a4a-6f044d0b4b1e

📥 Commits

Reviewing files that changed from the base of the PR and between 5bebfdf and 7632fef.

📒 Files selected for processing (28)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/train/validation.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/train/validation.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt/train/validation.py
  • deepmd/pt/utils/finetune.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/finetune.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/pt_expt/test_training.py
💤 Files with no reviewable changes (9)
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_training.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (12)
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/pt/utils/finetune.py
  • deepmd/dpmodel/train/__init__.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/dpmodel/train/data.py
  • deepmd/jax/utils/serialization.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/utils/finetune.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/entrypoints/main.py

Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/dpmodel/train/validation.py

Copilot AI left a comment

Contributor

Pull request overview

This PR introduces a backend-independent training abstraction layer under deepmd.dpmodel.train, and migrates JAX and pt_expt training/finetune flows to reuse shared orchestration (tasks/ranks normalization, learning-curve output, checkpoint cadence, lifecycle hooks, and full-validation best-checkpoint management). It also consolidates fine-tuning rule generation into a single backend-agnostic builder in deepmd.utils.finetune, leaving backend modules to focus on checkpoint loading and state/tensor copying.

Changes:

  • Add backend-independent trainer/entrypoint/data/full-validation primitives in deepmd.dpmodel.train, with associated test coverage moved/added under source/tests/common/dpmodel/.
  • Centralize fine-tuning rule building in deepmd.utils.finetune and refactor PT/PT-exportable/Paddle/JAX finetune helpers to delegate to the shared builder.
  • Extend full-validation best-checkpoint management to support both file checkpoints (.pt) and directory checkpoints (.jax), and migrate JAX + pt_expt training entrypoints onto the shared pipeline.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.

Show a summary per file
| File | Description |
| --- | --- |
| source/tests/pt/test_validation.py | Adds coverage for best-checkpoint reconciliation when checkpoints are directories (e.g., .jax). |
| source/tests/pt_expt/test_training.py | Adds an integration-style test asserting pt_expt full validation writes/prunes best checkpoints and logs. |
| source/tests/pt_expt/test_entrypoint.py | New tests for pt_expt entrypoint option normalization, process-group ownership, checkpoint links/retention, and stat-file creation. |
| source/tests/jax/test_training.py | Adds tests for JAX optimizer LR scaling, finetune loading, state/key normalization, and entrypoint gating/multitask behavior. |
| source/tests/common/test_finetune_utils.py | Adds extensive tests for the shared finetune rule builder (single/multitask, aliases, random fitting, immutability). |
| source/tests/common/dpmodel/test_training_utils.py | New tests for model probability resolution and size fallbacks. |
| source/tests/common/dpmodel/test_train_entrypoint.py | New tests validating the shared entrypoint pipeline sequencing and teardown behavior on failures. |
| source/tests/common/dpmodel/test_train_data.py | New tests for shared data-summary printing compatibility/failure propagation. |
| source/tests/common/dpmodel/test_train_abstract_trainer.py | New tests for the backend-independent trainer loop, multitask sampling, lcurve formatting, checkpoint cadence, and full-validation ordering. |
| deepmd/utils/finetune.py | Introduces FinetuneRuleBuilder and shared finetune-rule construction APIs. |
| deepmd/utils/argcheck.py | Updates full-validation docs to indicate support across PT, pt_expt, and JAX. |
| deepmd/pt/utils/finetune.py | Refactors PyTorch finetune rules to delegate to shared finetune rule builder. |
| deepmd/pt/train/validation.py | Extends best-checkpoint handling to support configurable suffixes and directory checkpoints; improves validation-data iteration. |
| deepmd/pt_expt/utils/finetune.py | Refactors pt_expt finetune rules to delegate to shared finetune rule builder with backend-specific errors. |
| deepmd/pt_expt/train/training.py | Migrates pt_expt training loop to AbstractTrainer, adds task normalization, full validation hook, checkpoint retention, and relative latest-link handling. |
| deepmd/pt_expt/entrypoints/main.py | Migrates pt_expt train entrypoint to the shared entrypoint pipeline and shared data helpers; adds stat-file creation helper. |
| deepmd/pd/utils/finetune.py | Refactors Paddle finetune rules to delegate to shared finetune rule builder. |
| deepmd/jax/utils/serialization.py | Updates JAX serialization to support multitask state layout and avoid mis-normalizing numeric-looking task keys. |
| deepmd/jax/utils/finetune.py | New JAX finetune helper delegating to shared finetune rule builder. |
| deepmd/jax/train/validation.py | New JAX full-validation implementation built on backend-independent FullValidatorBase. |
| deepmd/jax/train/trainer.py | Migrates JAX trainer to AbstractTrainer, adds multitask support, shared finetune integration, full validation hook, and checkpoint writing changes. |
| deepmd/jax/entrypoints/train.py | Migrates JAX train entrypoint to the shared entrypoint pipeline and shared data helpers; adds multitask-aware neighbor-stat update. |
| deepmd/dpmodel/utils/training_utils.py | Improves model-probability resolution by handling get_nsystems() and non-sized data sources. |
| deepmd/dpmodel/train/validation.py | New backend-independent full-validation base with best-checkpoint management and val.log formatting. |
| deepmd/dpmodel/train/trainer.py | New backend-independent trainer abstraction handling task selection, display scheduling, lcurve writing, and checkpoint cadence. |
| deepmd/dpmodel/train/entrypoint.py | New backend-independent entrypoint orchestration pipeline used by backend-specific entrypoints. |
| deepmd/dpmodel/train/data.py | New shared data/task config normalization utilities and summary-printing compatibility helper. |
| deepmd/dpmodel/train/__init__.py | Exposes the new backend-independent training abstractions as a package API. |

Comment thread deepmd/pt/train/validation.py
Comment thread deepmd/jax/train/trainer.py Outdated
@njzjz

njzjz commented Jul 1, 2026

Member Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jul 1, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@njzjz njzjz added this pull request to the merge queue Jul 1, 2026

Development

Successfully merging this pull request may close these issues.

Abstract PyTorch Exportable Training Code into dpmodel

4 participants