# Driver section: build the model, tokenizer and distributed KV store,
# then launch pre-training unless we are in evaluation-only mode.
model, nsp_loss, mlm_loss, vocabulary = get_model(ctx)

# "uncased" BERT datasets require lower-casing at tokenization time.
lower = "uncased" in args.dataset_name
tokenizer = BERTTokenizer(vocabulary, lower=lower)

# KV store coordinates parameter updates across devices/workers.
store = mx.kv.create(args.kvstore)

if args.ckpt_dir:
    # NOTE(review): the original body of this branch was lost in this chunk
    # (replaced by an "After Change" diff marker). Minimal reconstruction:
    # make sure the checkpoint directory exists so later saves don't fail.
    # TODO: confirm against the upstream script.
    import os  # local import: the file's import block is outside this chunk
    os.makedirs(args.ckpt_dir, exist_ok=True)

if not args.eval_only:
    if args.data:
        logging.info("Using training data at {}".format(args.data))
    # NOTE(review): nesting of the two lines below reconstructed from context
    # (indentation was stripped in the source) — confirm against upstream.
    data_train = get_dataset(args.data, args.batch_size, len(ctx), True, store)
    train(data_train, model, nsp_loss, mlm_loss, len(vocabulary), ctx, store)