model_train
Warning
The code reference is a work in progress and may contain inconsistencies.
- socialnet.model_train(trajectory_files: Sequence[Path | str] | Path | str, output_parent: Path | str | None = None, session_name: str | None = None, model: Literal['fc', 'attention', 'interaction'] = 'attention', future_steps: int = 32, history_steps: int = 0, input_types: list[list[Literal['n', 'natt', 'gvars', 'h', 'hvars', 'hbott', 'pts']]] | None = None, variables: list[list[Literal['fV', 'fvx', 'fvy', 'fv', 'fA', 'fax', 'fay', 'fa', 'nbL', 'nbS', 'nbx', 'nby', 'nbE', 'nbV', 'nbvx', 'nbvy', 'nbv', 'nbA', 'nbax', 'nbay', 'nban', 'nbat', 'nba']]] | None = None, history_variables: list[list[Literal['fst', 'nbst', 'fdt', 'nbdt']]] | None = None, num_neighbours: int = 4, blind: list[Literal['nbv', 'nba', 'nbat', 'nban', 'fv', 'fa', 'fat', 'fan']] | None = None, sigma: float = 1.0, remove_outer: float = 0.8, shuffle: Literal['trajectories', 'social_context', ''] = '', loader_fractions: list[float | None] | None = None, target: Literal['turn', 'location', 'acceleration', 'acceleration_gaussian'] = 'turn', **kwargs)
Train SocialNet model from a set of trajectory files.
- Parameters:
trajectory_files (Sequence[Path | str] | Path | str) – Trajectory files to train the model on. If a single file is passed, it is converted to a list.
output_parent (Path | str | None, optional) – Parent folder where the output will be saved. If None, defaults to the current working directory.
session_name (str | None, optional) – Name of the session. If None, defaults to the current date and time.
model (Literal["fc", "attention", "interaction"], optional) – Model type: “fc” for fully connected, “attention” for attention-based, “interaction” for interaction-based models.
future_steps (int, optional) – Number of future frames for which the model makes predictions, by default 32.
history_steps (int, optional) – Number of past frames the model takes into account when making predictions, by default 0.
input_types (list[list[INPUT_TYPE]], optional) – List of lists of input types for the model. Each sublist corresponds to a different input type, e.g., “natt”, “n”, “gvars”, “h”, “hvars”, “hbott”, “pts”.
variables (list[list[VARIABLES]], optional) –
List of lists of variable names for each input type “gvars”. The variables can include:
    - fV: focal velocity vector.
    - fvx, fvy: focal velocity x or y components.
    - fv: focal speed.
    - fA: focal acceleration vector.
    - fax, fay: focal acceleration x or y components.
    - fa: norm of focal acceleration vector.
    - nbL or nbS: neighbour position vector.
    - nbx, nby: neighbour position x or y components.
    - nbE: neighbour orientation vector.
    - nbV: neighbour velocity vector.
    - nbvx, nbvy: neighbour velocity x or y components.
    - nbv: neighbour speed.
    - nbA: neighbour acceleration vector.
    - nbax, nbay: neighbour acceleration x or y components.
    - nban, nbat: neighbour acceleration (normal or tangential).
    - nba: norm of neighbour acceleration vector.
history_variables (list[list[Literal["fst", "nbst", "fdt", "nbdt"]]], optional) – List of lists of variable names for each input type “hvars”. Each name ends in “st” (straightness) or “dt” (distance travelled); the “f” and “nb” prefixes presumably denote focal and neighbour, as in variables.
num_neighbours (int, optional) – Number of neighbours to consider in the model, by default 4.
blind (list[Literal["nbv", "nba", "nbat", "nban", "fv", "fa", "fat", "fan"]], optional) – List of variables to blind the model to (e.g., “nbv”, “nba”).
sigma (float, optional) – Data smoothing factor, by default 1.0.
remove_outer (float, optional) – Fraction of the radius to remove from the outer region, by default 0.8.
shuffle (Literal["trajectories", "social_context", ""], optional) – Shuffle mode: “trajectories”, “social_context”, or “” (no shuffle), by default “”.
loader_fractions (list[float | None], optional) – Fractions for train, validation, and test splits (e.g., [0.5, 0.01, 0.01]).
target (Literal["turn", "location", "acceleration", "acceleration_gaussian"], optional) – What to predict: “turn” (future turning side), “location” (future location), “acceleration”, or “acceleration_gaussian”, by default “turn”.
**kwargs – Additional keyword arguments for advanced configuration.
- Returns:
Dictionary with training results and statistics.
- Return type:
dict
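As an illustration of how the parameters above fit together, here is a minimal sketch of a call. The keyword names and literal values come from the signature. The folder layout, the "*.npy" file pattern, the "models" output folder, the session name, and the helper names build_train_kwargs and train are assumptions made for this example and are not part of socialnet.

```python
from pathlib import Path


def build_train_kwargs(data_dir, session_name="example_session"):
    """Assemble keyword arguments for socialnet.model_train.

    Keyword names and literal values follow the documented signature;
    the directory layout and the "*.npy" pattern are assumptions.
    """
    data_dir = Path(data_dir)
    # Assumption: trajectory files live in one folder and end in .npy.
    trajectory_files = sorted(data_dir.glob("*.npy"))
    return {
        "trajectory_files": trajectory_files,
        "output_parent": data_dir / "models",  # hypothetical output folder
        "session_name": session_name,          # hypothetical session name
        "model": "attention",
        "future_steps": 32,
        "history_steps": 0,
        "num_neighbours": 4,
        "sigma": 1.0,
        "remove_outer": 0.8,
        "shuffle": "",
        "loader_fractions": [0.5, 0.01, 0.01],  # example values from the docs
        "target": "turn",
    }


def train(data_dir):
    """Run training; requires the socialnet package to be installed."""
    import socialnet  # imported lazily so the helper above works without it

    results = socialnet.model_train(**build_train_kwargs(data_dir))
    # The reference only says a dict of results and statistics is returned,
    # so list the keys instead of assuming specific ones.
    print(sorted(results))
    return results
```

Building the arguments in a separate function keeps the configuration inspectable before any training starts. Since the reference is marked as a work in progress, it is worth checking the returned dictionary's keys against the installed version.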